初めに
plotlyはインタラクティブなグラフを描くことができ、データ解析において強力なツールです。今回は、おなじみirisデータセットを用い、plotlyで95%信頼区間を見栄えよく表示させる方法を紹介します。
グラフとコード
まずは実際のグラフとコードを示します。
# ライブラリインポート
from sklearn.datasets import load_iris
import polars as pl
import plotly.graph_objects as go
import plotly.express as px
import statsmodels.formula.api as smf
# データインポート
iris = load_iris()
# polarsのDataFrame化
df = (
pl.DataFrame(iris["data"], schema = iris["feature_names"])
.with_columns(target = iris["target"])
.with_columns(target = pl.when(pl.col("target") == 0)
.then(pl.lit("setosa"))
.when(pl.col("target") == 1)
.then(pl.lit("versicolor"))
.otherwise(pl.lit("virginica")))
# statsmodels.apiを使用するために必要なリネーム
.rename({"sepal width (cm)" : "sepal_width",
"sepal length (cm)" : "sepal_length"})
)
# statsmodels.apiによる信頼区間の計算
list_species = ["setosa", "versicolor", "virginica"]
list_pred = []
for species in list_species:
df_species = df.filter(pl.col("target") == species)
# sepal_widthを目的変数、sepal_lengthを説明変数に設定し、95%信頼区間を求める。
lm_model = smf.ols(formula = "sepal_width ~ sepal_length", data = df_species).fit()
list_pred += [pl.concat([
df_species,
pl.DataFrame(lm_model.get_prediction().summary_frame(alpha = 0.05))
], how = "horizontal").sort("sepal_length")]
df = pl.concat(list_pred)
# 信頼区間領域を塗りつぶす色
list_fillcolor = ["rgba(99, 110, 250, 0.2)", "rgba(239, 85, 59, 0.2)", "rgba(0, 204, 150, 0.2)"]
fig = go.Figure()
for i, species in enumerate(list_species):
# 種ごとでフィルタリング
plot = df.filter(pl.col("target") == species)
# 生データ
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["sepal_width"],
mode = "markers",
marker = dict(color = px.colors.qualitative.Plotly[i]),
name = species
))
# 線形回帰線
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["mean"],
mode = "lines",
line = dict(color = px.colors.qualitative.Plotly[i]),
name = species + "_mean"
))
# 信頼区間下側
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["mean_ci_lower"],
mode = "lines",
line = dict(
color = px.colors.qualitative.Plotly[i],
width = 0
),
showlegend = False,
legendgroup = species,
name = species + "_conf_interval",
))
# 信頼区間上側
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["mean_ci_upper"],
mode = "lines",
line = dict(
color = px.colors.qualitative.Plotly[i],
width = 0
),
legendgroup = species,
name = species + "_conf_interval",
fill = "tonexty",
fillcolor = list_fillcolor[i]
))
fig.update_xaxes(title = "sepal length (cm)")
fig.update_yaxes(title = "sepal width (cm)")
fig.show()
解説
データ準備
まず、以下コードでirisデータセットを取得します。
iris = load_iris()
続いて、以下コードでpolarsのDataFrameに変換します。
df = (
pl.DataFrame(iris["data"], schema = iris["feature_names"])
.with_columns(target = iris["target"])
.with_columns(target = pl.when(pl.col("target") == 0)
.then(pl.lit("setosa"))
.when(pl.col("target") == 1)
.then(pl.lit("versicolor"))
.otherwise(pl.lit("virginica")))
# statsmodels.apiを使用するために必要なリネーム
.rename({"sepal width (cm)" : "sepal_width",
"sepal length (cm)" : "sepal_length"})
)
ここで、1つ注意点があります。target
の0, 1, 2にそれぞれの種名を割り当てる際にpl.when()
を用いていますが、then
の中身をstr
型で渡す場合はpl.lit()
として渡す必要があります。
95%信頼区間の算出
今回はstatsmodels.api
を利用した線形回帰モデルにより、95%信頼区間を求めます。
list_species = ["setosa", "versicolor", "virginica"]
list_pred = []
for species in list_species:
df_species = df.filter(pl.col("target") == species)
# sepal_widthを目的変数、sepal_lengthを説明変数に設定し、95%信頼区間を求める。
lm_model = smf.ols(formula = "sepal_width ~ sepal_length", data = df_species).fit()
list_pred += [pl.concat([
df_species,
pl.DataFrame(lm_model.get_prediction().summary_frame(alpha = 0.05))
], how = "horizontal").sort("sepal_length")]
df = pl.concat(list_pred)
list_species
に品種名をリストとして格納します。そして、種ごとに線形回帰モデルを作成し95%信頼区間を求めます。
lm_model = smf.ols().fit()
でモデル構築を行い、lm_model.get_prediction().summary_frame()
で信頼区間を得ます。
同時に回帰線、予想区間も出力されます。たった2行で線形回帰計算が完結してしまうとは。
グラフ作成
plotly.graph_objects
を用いてグラフを作成していきます。
まず、塗りつぶしの色をリストに格納します。これは、今回透明度設定(alpha)をしたいためで、rgba形式の色コードをリスト化することにより実装しています。
なお、ベースの色はplotlyデフォルトの色に対応させています。
list_fillcolor = ["rgba(99, 110, 250, 0.2)", "rgba(239, 85, 59, 0.2)", "rgba(0, 204, 150, 0.2)"]
続いて、いよいよ描画です。今回は種ごとに色を分けたいので、種でループを回します。
plotlyのグラフオブジェクトを生成したのち、先ほど作成したリストを用い、プロットに用いる一次的なDataFrame(plot)を作成しています。
fig = go.Figure()
for i, species in enumerate(list_species):
# 種ごとでフィルタリング
plot = df.filter(pl.col("target") == species)
まず、生データを散布図で図示します。marker=dict(color=...)
でマーカーの色を指定しています。px.colors.qualitative.Plotly
にループ回数を渡すことで、種別の色分けを実装しています。
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["sepal_width"],
mode = "markers",
marker = dict(color = px.colors.qualitative.Plotly[i]),
name = species
))
同様に、線形回帰線も図示します。
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["mean"],
mode = "lines",
line = dict(color = px.colors.qualitative.Plotly[i]),
name = species + "_mean"
))
続いて、信頼区間を図示していきます。まずは下側です。
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["mean_ci_lower"],
mode = "lines",
line = dict(
color = px.colors.qualitative.Plotly[i],
width = 0
),
showlegend = False,
legendgroup = species,
name = species + "_conf_interval",
))
mode=lines
とした上でwidth=0
とし、実質的に見えなくしています。また、上側と1セットにしたいので、legendgroup=species
としています。
次に、上側です。
fig.add_trace(go.Scatter(
x = plot["sepal_length"],
y = plot["mean_ci_upper"],
mode = "lines",
line = dict(
color = px.colors.qualitative.Plotly[i],
width = 0
),
legendgroup = species,
name = species + "_conf_interval",
fill = "tonexty",
fillcolor = list_fillcolor[i]
))
fill
およびfillcolor
が追加されています。これが、plotlyで領域の塗りつぶしを実装するコマンドです。
今回は"tonexty"
を指定していますが、他にも"tozeroy"
などいろいろあります。詳しくは以下リンクを参照ください。

Filled
Over 9 examples of Filled Area Plots including changing color, size, log axes, and more in Python.
最後に軸ラベルを設定して完成です。
fig.update_xaxes(title = "sepal length (cm)")
fig.update_yaxes(title = "sepal width (cm)")
fig.show()
4. 最後に
いかがだったでしょうか。通常の散布図を図示するのとほとんど同じ手順で実装できるので(fill
とfillcolor
を追加するだけ)、ぜひ試してみてください。
コメント