plotlyによる95%信頼区間の表示

初めに

plotlyはインタラクティブなグラフを描くことができ、データ解析において強力なツールです。今回は、おなじみirisデータセットを用い、plotlyで95%信頼区間を見栄えよく表示させる方法を紹介します。

グラフとコード

まずは実際のグラフとコードを示します。

# ライブラリインポート
from sklearn.datasets import load_iris
import polars as pl
import plotly.graph_objects as go
import plotly.express as px
import statsmodels.formula.api as smf

# データインポート
iris = load_iris()

# polarsのDataFrame化
df = (
    pl.DataFrame(iris["data"], schema = iris["feature_names"])
    .with_columns(target = iris["target"])
    .with_columns(target = pl.when(pl.col("target") == 0)
                  .then(pl.lit("setosa"))
                  .when(pl.col("target") == 1)
                  .then(pl.lit("versicolor"))
                  .otherwise(pl.lit("virginica")))

    # statsmodels.apiを使用するために必要なリネーム
    .rename({"sepal width (cm)" : "sepal_width",
             "sepal length (cm)" : "sepal_length"})
)


# statsmodels.apiによる信頼区間の計算
list_species = ["setosa", "versicolor", "virginica"]
list_pred = []
for species in list_species:
    df_species = df.filter(pl.col("target") == species)

    # sepal_widthを目的変数、sepal_lengthを説明変数に設定し、95%信頼区間を求める。
    lm_model = smf.ols(formula = "sepal_width ~ sepal_length", data = df_species).fit()
    list_pred += [pl.concat([
        df_species,
        pl.DataFrame(lm_model.get_prediction().summary_frame(alpha = 0.05))
        ], how = "horizontal").sort("sepal_length")]

df = pl.concat(list_pred)


# 信頼区間領域を塗りつぶす色
list_fillcolor = ["rgba(99, 110, 250, 0.2)", "rgba(239, 85, 59, 0.2)", "rgba(0, 204, 150, 0.2)"]

fig = go.Figure()
for i, species in enumerate(list_species):
    # 種ごとでフィルタリング
    plot = df.filter(pl.col("target") == species)

    # 生データ
    fig.add_trace(go.Scatter(
        x = plot["sepal_length"],
        y = plot["sepal_width"],
        mode = "markers",
        marker = dict(color = px.colors.qualitative.Plotly[i]),
        name = species
    ))

    # 線形回帰線
    fig.add_trace(go.Scatter(
        x = plot["sepal_length"],
        y = plot["mean"],
        mode = "lines",
        line = dict(color = px.colors.qualitative.Plotly[i]),
        name = species + "_mean"
    ))

    # 信頼区間下側
    fig.add_trace(go.Scatter(
        x = plot["sepal_length"],
        y = plot["mean_ci_lower"],
        mode = "lines",
        line = dict(
            color = px.colors.qualitative.Plotly[i],
            width = 0
        ),
        showlegend = False,
        legendgroup = species,
        name = species + "_conf_interval",
    ))

    # 信頼区間上側
    fig.add_trace(go.Scatter(
        x = plot["sepal_length"],
        y = plot["mean_ci_upper"],
        mode = "lines",
        line = dict(
            color = px.colors.qualitative.Plotly[i],
            width = 0
        ),
        legendgroup = species,
        name = species + "_conf_interval",
        fill = "tonexty",
        fillcolor = list_fillcolor[i]
    ))

fig.update_xaxes(title = "sepal length (cm)")
fig.update_yaxes(title = "sepal width (cm)")
fig.show()

解説

データ準備

まず、以下コードでirisデータセットを取得します。  

iris = load_iris()

続いて、以下コードでpolarsのDataFrameに変換します。  

df = (
    pl.DataFrame(iris["data"], schema = iris["feature_names"])
    .with_columns(target = iris["target"])
    .with_columns(target = pl.when(pl.col("target") == 0)
                  .then(pl.lit("setosa"))
                  .when(pl.col("target") == 1)
                  .then(pl.lit("versicolor"))
                  .otherwise(pl.lit("virginica")))

    # statsmodels.apiを使用するために必要なリネーム
    .rename({"sepal width (cm)" : "sepal_width",
             "sepal length (cm)" : "sepal_length"})
)

ここで、1つ注意点があります。targetの0, 1, 2にそれぞれの種名を割り当てる際にpl.when()を用いていますが、thenの中身をstr型で渡す場合はpl.lit()として渡す必要があります。

95%信頼区間の算出

今回はstatsmodels.apiを利用した線形回帰モデルにより、95%信頼区間を求めます。
list_species = ["setosa", "versicolor", "virginica"]
list_pred = []
for species in list_species:
    df_species = df.filter(pl.col("target") == species)

    # sepal_widthを目的変数、sepal_lengthを説明変数に設定し、95%信頼区間を求める。
    lm_model = smf.ols(formula = "sepal_width ~ sepal_length", data = df_species).fit()
    list_pred += [pl.concat([
        df_species,
        pl.DataFrame(lm_model.get_prediction().summary_frame(alpha = 0.05))
        ], how = "horizontal").sort("sepal_length")]

df = pl.concat(list_pred)
まず、list_speciesに品種名をリストとして格納します。そして、種ごとに線形回帰モデルを作成し95%信頼区間を求めます。 lm_model = smf.ols().fit()でモデル構築を行い、lm_model.get_prediction().summary_frame()で信頼区間を得ます。 同時に回帰線、予想区間も出力されます。たった2行で線形回帰計算が完結してしまうとは。

グラフ作成

plotly.graph_objectsを用いてグラフを作成していきます。
まず、塗りつぶしの色をリストに格納します。これは、今回透明度設定(alpha)をしたいためで、rgba形式の色コードをリスト化することにより実装しています。
なお、ベースの色はplotlyデフォルトの色に対応させています。

list_fillcolor = ["rgba(99, 110, 250, 0.2)", "rgba(239, 85, 59, 0.2)", "rgba(0, 204, 150, 0.2)"]

続いて、いよいよ描画です。今回は種ごとに色を分けたいので、種でループを回します。        

plotlyのグラフオブジェクトを生成したのち、先ほど作成したリストを用い、プロットに用いる一次的なDataFrame(plot)を作成しています。

fig = go.Figure()
for i, species in enumerate(list_species):
    # 種ごとでフィルタリング
    plot = df.filter(pl.col("target") == species)

まず、生データを散布図で図示します。marker=dict(color=...)でマーカーの色を指定しています。px.colors.qualitative.Plotlyにループ回数を渡すことで、種別の色分けを実装しています。

fig.add_trace(go.Scatter(
    x = plot["sepal_length"],
    y = plot["sepal_width"],
    mode = "markers",
    marker = dict(color = px.colors.qualitative.Plotly[i]),
    name = species
))

同様に、線形回帰線も図示します。

fig.add_trace(go.Scatter(
    x = plot["sepal_length"],
    y = plot["mean"],
    mode = "lines",
    line = dict(color = px.colors.qualitative.Plotly[i]),
    name = species + "_mean"
))

続いて、信頼区間を図示していきます。まずは下側です。

fig.add_trace(go.Scatter(
    x = plot["sepal_length"],
    y = plot["mean_ci_lower"],
    mode = "lines",
    line = dict(
        color = px.colors.qualitative.Plotly[i],
        width = 0
    ),
    showlegend = False,
    legendgroup = species,
    name = species + "_conf_interval",
))
領域として示したいので、mode=linesとした上でwidth=0とし、実質的に見えなくしています。また、上側と1セットにしたいので、legendgroup=speciesとしています。

次に、上側です。

fig.add_trace(go.Scatter(
    x = plot["sepal_length"],
    y = plot["mean_ci_upper"],
    mode = "lines",
    line = dict(
        color = px.colors.qualitative.Plotly[i],
        width = 0
    ),
    legendgroup = species,
    name = species + "_conf_interval",
    fill = "tonexty",
    fillcolor = list_fillcolor[i]
))
下側の図示とほぼ同じですが、fillおよびfillcolorが追加されています。これが、plotlyで領域の塗りつぶしを実装するコマンドです。 今回は"tonexty"を指定していますが、他にも"tozeroy"などいろいろあります。詳しくは以下リンクを参照ください。
Filled
Over 9 examples of Filled Area Plots including changing color, size, log axes, and more in Python.

最後に軸ラベルを設定して完成です。

fig.update_xaxes(title = "sepal length (cm)")
fig.update_yaxes(title = "sepal width (cm)")
fig.show()

4. 最後に

いかがだったでしょうか。通常の散布図を図示するのとほとんど同じ手順で実装できるので(fillfillcolorを追加するだけ)、ぜひ試してみてください。

コメント

タイトルとURLをコピーしました