Generalized Additive Models (GAM)




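The following example shows a Gaussian and a Poisson regression in which the categorical variables are treated as linear terms and the effects of two explanatory variables are captured by penalized B-splines. The data come from the automobile dataset, https://archive.ics.uci.edu/ml/datasets/automobile. We can load a dataframe with selected columns from the unit test module.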
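Schematically, each model below has the following form. This is a sketch, not statsmodels' exact internals: y is city_mpg, the d terms are the 0/1 dummies for the categorical levels that appear in the summaries (fuel[T.gas], drive[T.fwd], drive[T.rwd]), B_jk are the B-spline basis functions built by BSplines, γ_j are their coefficients, α_j are the two entries of alpha, and S_j is a roughness penalty matrix (assumed here to penalize curvature). The link g is the identity for the Gaussian model and the log for the Poisson model, and ℓ is the log-likelihood. How statsmodels scales the penalty term (for example by a factor of 1/2 or by the number of observations) is not shown here and may differ.

```latex
\begin{aligned}
g\bigl(\mathrm{E}[y_i]\bigr) &= \beta_0
  + \beta_{\text{gas}}\, d^{\text{gas}}_i
  + \beta_{\text{fwd}}\, d^{\text{fwd}}_i
  + \beta_{\text{rwd}}\, d^{\text{rwd}}_i
  + f_1(\text{weight}_i) + f_2(\text{hp}_i),
  \qquad f_j(x) = \sum_{k} B_{jk}(x)\,\gamma_{jk}, \\
(\hat\beta, \hat\gamma) &= \operatorname*{arg\,max}_{\beta,\,\gamma}\;
  \Bigl[\, \ell(\beta, \gamma) - \sum_{j=1}^{2} \alpha_j\, \gamma_j^{\top} S_j\, \gamma_j \Bigr]
\end{aligned}
```

A larger α_j forces f_j closer to a function with no curvature penalty; α_j near zero leaves f_j almost unpenalized.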

In [1]: import numpy as np
   ...: import statsmodels.api as sm

In [2]: from statsmodels.gam.api import GLMGam, BSplines

# import data
In [3]: from statsmodels.gam.tests.test_penalized import df_autos

# create spline basis for weight and hp
In [4]: x_spline = df_autos[['weight', 'hp']]

In [5]: bs = BSplines(x_spline, df=[12, 10], degree=[3, 3])

# penalization weights, one per smooth term (weight, hp)
In [6]: alpha = np.array([21833888.8, 6460.38479])

In [7]: gam_bs = GLMGam.from_formula('city_mpg ~ fuel + drive', data=df_autos,
   ...:                              smoother=bs, alpha=alpha)

In [8]: res_bs = gam_bs.fit()

In [9]: print(res_bs.summary())
                 Generalized Linear Model Regression Results                  
Dep. Variable:               city_mpg   No. Observations:                  203
Model:                         GLMGam   Df Residuals:                   189.13
Model Family:                Gaussian   Df Model:                        12.87
Link Function:               Identity   Scale:                          4.8825
Method:                         PIRLS   Log-Likelihood:                -441.81
Date:                Thu, 03 Oct 2024   Deviance:                       923.45
Time:                        16:09:46   Pearson chi2:                     923.
No. Iterations:                     3   Pseudo R-squ. (CS):             0.9996
Covariance Type:            nonrobust                                         
                   coef    std err          z      P>|z|      [0.025      0.975]
Intercept       51.9923      1.997     26.034      0.000      48.078      55.906
fuel[T.gas]     -5.8099      0.727     -7.989      0.000      -7.235      -4.385
drive[T.fwd]     1.3910      0.819      1.699      0.089      -0.213       2.995
drive[T.rwd]     1.0638      0.842      1.263      0.207      -0.587       2.715
weight_s0       -3.5556      0.959     -3.707      0.000      -5.436      -1.676
weight_s1       -9.0876      1.750     -5.193      0.000     -12.518      -5.658
weight_s2      -13.0303      1.827     -7.132      0.000     -16.611      -9.450
weight_s3      -14.2641      1.854     -7.695      0.000     -17.897     -10.631
weight_s4      -15.1805      1.892     -8.024      0.000     -18.889     -11.472
weight_s5      -15.9557      1.963     -8.128      0.000     -19.803     -12.108
weight_s6      -16.6297      2.038     -8.161      0.000     -20.624     -12.636
weight_s7      -16.9928      2.045     -8.308      0.000     -21.002     -12.984
weight_s8      -19.3480      2.367     -8.174      0.000     -23.987     -14.709
weight_s9      -20.7978      2.455     -8.472      0.000     -25.609     -15.986
weight_s10     -20.8062      2.443     -8.517      0.000     -25.594     -16.018
hp_s0           -1.4473      0.558     -2.592      0.010      -2.542      -0.353
hp_s1           -3.4228      1.012     -3.381      0.001      -5.407      -1.438
hp_s2           -5.9026      1.251     -4.717      0.000      -8.355      -3.450
hp_s3           -7.2389      1.352     -5.354      0.000      -9.889      -4.589
hp_s4           -9.1052      1.384     -6.581      0.000     -11.817      -6.393
hp_s5           -9.9865      1.525     -6.547      0.000     -12.976      -6.997
hp_s6          -13.3639      2.228     -5.998      0.000     -17.731      -8.997
hp_s7          -13.8902      3.194     -4.349      0.000     -20.150      -7.630
hp_s8          -11.9752      2.556     -4.685      0.000     -16.985      -6.965
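The summary lists 11 weight columns (weight_s0 to weight_s10) for df=12 and 9 hp columns (hp_s0 to hp_s8) for df=10, one fewer than df in each case. This is consistent with BSplines dropping one basis column when no intercept is included (include_intercept is an optional argument in the module reference): a full B-spline basis sums to one and would otherwise duplicate the model Intercept. The pure-Python sketch below illustrates this with the Cox-de Boor recursion. The helper names, the equally spaced interior knots, and the bounds are assumptions for illustration only; statsmodels may place knots differently.

```python
def clamped_knots(lo, hi, df, degree):
    """Knot vector giving `df` B-spline basis functions on [lo, hi].

    Assumption: interior knots are equally spaced (illustration only).
    """
    n_interior = df - degree - 1
    interior = [lo + (hi - lo) * k / (n_interior + 1) for k in range(1, n_interior + 1)]
    return [lo] * (degree + 1) + interior + [hi] * (degree + 1)


def bspline_basis(x, knots, degree):
    """All B-spline basis functions of the given degree at x (Cox-de Boor recursion)."""
    # Degree 0: indicator of the half-open knot interval that contains x.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(len(knots) - 1)]
    if x == knots[-1]:
        # Close the last non-empty interval so the right boundary is covered.
        last = max(i for i in range(len(knots) - 1) if knots[i] < knots[i + 1])
        basis[last] = 1.0
    for d in range(1, degree + 1):
        nxt = []
        for i in range(len(knots) - d - 1):
            value = 0.0
            if knots[i + d] > knots[i]:
                value += (x - knots[i]) / (knots[i + d] - knots[i]) * basis[i]
            if knots[i + d + 1] > knots[i + 1]:
                value += (knots[i + d + 1] - x) / (knots[i + d + 1] - knots[i + 1]) * basis[i + 1]
            nxt.append(value)
        basis = nxt
    return basis


def spline_row(x, lo, hi, df, degree=3, include_intercept=False):
    """Hypothetical helper: one design-matrix row for a single smooth term."""
    row = bspline_basis(x, clamped_knots(lo, hi, df, degree), degree)
    # The full basis sums to 1 and would duplicate the model Intercept,
    # so drop the first column unless an intercept is requested.
    return row if include_intercept else row[1:]


# Illustrative bounds, not taken from the dataset.
full = bspline_basis(2500.0, clamped_knots(1500.0, 4100.0, 12, 3), 3)
print(len(full), round(sum(full), 12))                  # 12 1.0
print(len(spline_row(2500.0, 1500.0, 4100.0, df=12)))  # 11, like weight_s0..weight_s10
print(len(spline_row(120.0, 48.0, 288.0, df=10)))      # 9, like hp_s0..hp_s8
```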

# plot smooth components
In [10]: res_bs.plot_partial(0, cpr=True)
Out[10]: <Figure size 640x480 with 1 Axes>

In [11]: res_bs.plot_partial(1, cpr=True)
Out[11]: <Figure size 640x480 with 1 Axes>
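plot_partial(i, cpr=True) draws the fitted smooth term for the i-th spline variable, in the column order of x_spline, so index 0 is weight and index 1 is hp. With cpr=True it also scatters component-plus-residual (partial residual) points around the curve. The sketch below shows how such points can be formed, assuming the Gaussian identity-link case, where the residual is simply the observed minus the fitted value; with other links a working residual is the usual choice. The helper names and toy numbers are hypothetical.

```python
def partial_effect(basis_rows, gamma):
    """Smooth term f_j(x_i) = sum_k B_jk(x_i) * gamma_jk for every observation i."""
    return [sum(b * g for b, g in zip(row, gamma)) for row in basis_rows]


def component_plus_residual(basis_rows, gamma, y, fitted):
    """Partial residuals: smooth term plus the response residual y_i - fitted_i."""
    effect = partial_effect(basis_rows, gamma)
    return [f + (yi - mi) for f, yi, mi in zip(effect, y, fitted)]


# Hypothetical toy values: 3 observations, 2 basis columns.
basis_rows = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
gamma = [-2.0, -4.0]
y = [30.0, 25.0, 20.0]
fitted = [29.0, 26.0, 20.5]

print(partial_effect(basis_rows, gamma))                      # [-2.0, -3.0, -4.0]
print(component_plus_residual(basis_rows, gamma, y, fitted))  # [-1.0, -4.0, -4.5]
```

Points that scatter evenly around the plotted curve suggest the smooth term captures the shape of the relationship; systematic patterns suggest it does not.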

# Poisson model (log link) with its own penalization weights
In [12]: alpha = np.array([8283989284.5829611, 14628207.58927821])

In [13]: gam_bs = GLMGam.from_formula('city_mpg ~ fuel + drive', data=df_autos,
   ....:                              smoother=bs, alpha=alpha,
   ....:                              family=sm.families.Poisson())

In [14]: res_bs = gam_bs.fit()

In [15]: print(res_bs.summary())
                 Generalized Linear Model Regression Results                  
Dep. Variable:               city_mpg   No. Observations:                  203
Model:                         GLMGam   Df Residuals:                   194.75
Model Family:                 Poisson   Df Model:                         7.25
Link Function:                    Log   Scale:                          1.0000
Method:                         PIRLS   Log-Likelihood:                -530.38
Date:                Thu, 03 Oct 2024   Deviance:                       37.569
Time:                        16:09:46   Pearson chi2:                     37.4
No. Iterations:                     6   Pseudo R-squ. (CS):             0.7715
Covariance Type:            nonrobust                                         
                   coef    std err          z      P>|z|      [0.025      0.975]
Intercept        3.9960      0.130     30.844      0.000       3.742       4.250
fuel[T.gas]     -0.2398      0.057     -4.222      0.000      -0.351      -0.128
drive[T.fwd]     0.0386      0.075      0.513      0.608      -0.109       0.186
drive[T.rwd]     0.0309      0.078      0.395      0.693      -0.122       0.184
weight_s0       -0.0811      0.030     -2.689      0.007      -0.140      -0.022
weight_s1       -0.1938      0.063     -3.067      0.002      -0.318      -0.070
weight_s2       -0.3160      0.082     -3.864      0.000      -0.476      -0.156
weight_s3       -0.3735      0.090     -4.160      0.000      -0.549      -0.198
weight_s4       -0.4187      0.096     -4.360      0.000      -0.607      -0.230
weight_s5       -0.4645      0.103     -4.495      0.000      -0.667      -0.262
weight_s6       -0.5092      0.112     -4.555      0.000      -0.728      -0.290
weight_s7       -0.5469      0.119     -4.598      0.000      -0.780      -0.314
weight_s8       -0.6211      0.137     -4.528      0.000      -0.890      -0.352
weight_s9       -0.6866      0.153     -4.486      0.000      -0.987      -0.387
weight_s10      -0.7370      0.174     -4.228      0.000      -1.079      -0.395
hp_s0           -0.0247      0.010     -2.378      0.017      -0.045      -0.004
hp_s1           -0.0557      0.022     -2.479      0.013      -0.100      -0.012
hp_s2           -0.1046      0.038     -2.719      0.007      -0.180      -0.029
hp_s3           -0.1438      0.050     -2.857      0.004      -0.242      -0.045
hp_s4           -0.1919      0.063     -3.047      0.002      -0.315      -0.068
hp_s5           -0.2567      0.079     -3.231      0.001      -0.412      -0.101
hp_s6           -0.4152      0.120     -3.455      0.001      -0.651      -0.180
hp_s7           -0.4889      0.152     -3.214      0.001      -0.787      -0.191
hp_s8           -0.5470      0.195     -2.810      0.005      -0.928      -0.166
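Because the Poisson model uses a log link, its coefficients act on the log of the expected city_mpg, so exponentiating a coefficient gives a multiplicative effect with all other terms held fixed. The snippet below applies this to the fuel[T.gas] row of the summary above (diesel is the reference level). It only transforms the printed, rounded numbers, so the results are approximate.

```python
import math

# fuel[T.gas] row of the Poisson summary: coef and 95% confidence limits (log scale)
coef, lower, upper = -0.2398, -0.351, -0.128

# exp() maps log-scale effects to multiplicative effects on expected city_mpg
ratio = math.exp(coef)
ci_low, ci_high = math.exp(lower), math.exp(upper)

print(f"gas vs. diesel: x{ratio:.3f} (95% CI x{ci_low:.3f} to x{ci_high:.3f})")
# gas vs. diesel: x0.787 (95% CI x0.704 to x0.880)
```

That is, gas cars are estimated to have roughly 21% lower expected city mileage than diesel cars, the same direction as the -5.81 mpg additive estimate from the Gaussian model.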

# Optimal penalization weights alpha can be obtained through generalized
# cross-validation or k-fold cross-validation.
# The alpha values used above were taken from the unit tests, which compare
# the results against the R package mgcv.
In [16]: gam_bs.select_penweight()[0]
Out[16]: array([8.2839e+09, 1.4628e+07])

In [17]: gam_bs.select_penweight_kfold()[0]
Out[17]: (np.float64(10000000.0), np.float64(15848.931924611108))
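To illustrate the generalized cross-validation criterion mentioned above (not statsmodels' implementation of select_penweight), the sketch below fits a Gaussian penalized smoother over a grid of alpha values and scores each fit with GCV(α) = n·RSS(α) / (n − tr H(α))², where H is the hat matrix. Simplifying assumptions: one coefficient per observation with a second-difference penalty (a Whittaker or P-spline style smoother standing in for the B-spline roughness penalty), synthetic deterministic data, and a plain grid search. The trace of H is the effective degrees of freedom, which shrinks as α grows; this is why the summaries above report non-integer Df Model values.

```python
import math


def inverse(a):
    """Invert a square matrix by Gauss-Jordan elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [1.0 if i == j else 0.0 for j in range(n)] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        p = m[col][col]
        m[col] = [v / p for v in m[col]]
        for r in range(n):
            f = m[r][col]
            if r != col and f != 0.0:
                m[r] = [vr - f * vc for vr, vc in zip(m[r], m[col])]
    return [row[n:] for row in m]


def second_difference_penalty(n):
    """P = D'D, where each row of D is a second difference (1, -2, 1)."""
    p = [[0.0] * n for _ in range(n)]
    for k in range(n - 2):
        row = {k: 1.0, k + 1: -2.0, k + 2: 1.0}
        for i, di in row.items():
            for j, dj in row.items():
                p[i][j] += di * dj
    return p


def gcv(y, alpha, penalty):
    """Return (GCV score, effective df) for the fit (I + alpha * P)^-1 y."""
    n = len(y)
    hat = inverse([[(1.0 if i == j else 0.0) + alpha * penalty[i][j] for j in range(n)]
                   for i in range(n)])
    fitted = [sum(hat[i][j] * y[j] for j in range(n)) for i in range(n)]
    rss = sum((yi - fi) ** 2 for yi, fi in zip(y, fitted))
    edf = sum(hat[i][i] for i in range(n))  # trace of the hat matrix
    return n * rss / (n - edf) ** 2, edf


# Synthetic data: a smooth curve plus deterministic pseudo-noise.
n = 20
y = [math.sin(2 * math.pi * i / (n - 1)) + 0.2 * math.sin(2.3 * i * i) for i in range(n)]
penalty = second_difference_penalty(n)

alphas = [10.0 ** k for k in range(-3, 7)]
results = {a: gcv(y, a, penalty) for a in alphas}
best_alpha = min(alphas, key=lambda a: results[a][0])
for a in alphas:
    print(f"alpha={a:>9.3g}  edf={results[a][1]:6.2f}  gcv={results[a][0]:.4f}")
print("selected alpha:", best_alpha)
```

A k-fold variant would instead refit on each training fold and average the squared prediction error on the held-out fold for every candidate alpha.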



Module Reference

Model Classes

GLMGam(endog[, exog, smoother, alpha, ...])

Generalized Additive Models (GAM)

LogitGam(endog, smoother, alpha, *args, **kwargs)

Generalized additive model for discrete Logit

Results Class

GLMGamResults(model, params, ...)

Results class for generalized additive models (GAM).

Smooth Basis Functions

BSplines(x, df, degree[, include_intercept, ...])

Additive smooth components using B-splines

CyclicCubicSplines(x, df[, constraints, ...])

Additive smooth components using cyclic cubic regression splines

statsmodels.gam.smooth_basis includes additional splines and a (global) polynomial smoother basis, but these have not yet been verified.

Families and Link Functions

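The distribution families in GLMGam are the same as in GLM, and so are the corresponding link functions. The current unit tests cover only the Gaussian and Poisson families, and GLMGam might not work with all of the options that are available in GLM.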

