"""第 8 章 回归分析 / 8.2 简单线性回归 —— ISLP 式实操 lab + 配图。

生成图表:
    fig_8_2_2_r2_decomposition.png   R² 的方差分解: SS(mean) vs SS(fit)
    fig_8_2_3_pred_vs_conf_bands.png 均值置信带(窄) vs 预测带(宽)
    fig_8_2_4_residual_vs_fitted.png 残差-拟合诊断图

并打印 statsmodels 工作流输出(summary / conf_int / get_prediction),
正文 8.2.9 直接引用这些数字。数据为"拟真合成"(固定 seed, 可复现)。

运行:
    python docs/assets/scripts/ch08_regression/02_simple_linear_regression.py
"""
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf

# Switch stdout to UTF-8 when the stream supports it, so the non-ASCII text
# printed below survives consoles with a legacy default encoding
# (original note: the Windows console defaults to cp1252).
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# Put the ``_shared`` directory (a sibling of this script's directory) at the
# front of the import path; the ``plot_style`` import below presumably resolves
# from there, which is why it has to come after this line (hence the E402 waiver).
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "_shared"))
from plot_style import apply_style, figure, PALETTE  # noqa: E402

apply_style()
# Output directory for the generated figures: ``assets/figures/ch08_regression``
# under the directory three levels above this script's own directory.
OUT = Path(__file__).resolve().parents[3] / "assets" / "figures" / "ch08_regression"
OUT.mkdir(parents=True, exist_ok=True)  # create it (and any missing parents) on first run
rng = np.random.default_rng(8)  # fixed seed -> printed numbers and figures are reproducible


# ---------------------------------------------------------------------------
# 拟真广告数据 (TV / radio / newspaper -> sales), 系数呼应 ISLP 第 3 章
# ---------------------------------------------------------------------------
def make_advertising(n=200, generator=None):
    """Build a synthetic advertising data set (TV / radio / newspaper -> sales).

    The three budgets are drawn uniformly (TV from 0 to 300, radio from 0 to
    50, newspaper from 0 to 110) and ``sales`` is a linear combination of
    them plus Gaussian noise with standard deviation 1.7.

    Args:
        n: Number of rows to generate.
        generator: Optional ``numpy.random.Generator`` to draw from. When
            omitted, the module-level ``rng`` is used, so existing callers
            keep getting the same seeded numbers as before.

    Returns:
        A ``pandas.DataFrame`` with the columns ``TV``, ``radio``,
        ``newspaper`` and ``sales`` and ``n`` rows.
    """
    # The fallback is only evaluated when no generator is passed, so the
    # function can be used without the module-level RNG being present.
    source = rng if generator is None else generator

    # Draw order (TV, radio, newspaper, noise) determines the generated
    # numbers for a given seed; do not reorder these calls.
    tv = source.uniform(0, 300, n)
    radio = source.uniform(0, 50, n)
    newspaper = source.uniform(0, 110, n)
    noise = source.normal(0, 1.7, n)

    sales = 2.9 + 0.046 * tv + 0.19 * radio + 0.001 * newspaper + noise
    return pd.DataFrame(
        {"TV": tv, "radio": radio, "newspaper": newspaper, "sales": sales}
    )


# Module-level fixtures shared by the figure functions and the printed output:
# 200 synthetic rows, and an OLS fit of sales on TV alone.
df = make_advertising(n=200)
model = smf.ols(formula="sales ~ TV", data=df).fit()


# ---------------------------------------------------------------------------
# 图 8.2.2 —— R² 的方差分解 (StatQuest 直觉: 围着均值 vs 围着拟合线)
# 用同一份广告数据的 30 城子样本演示, 便于看清竖直残差
# ---------------------------------------------------------------------------
def _draw_points_with_residuals(ax, x, y, baseline):
    """Draw a vertical stick from each point to *baseline*, then the points on top."""
    ax.vlines(x, np.minimum(y, baseline), np.maximum(y, baseline),
              color=PALETTE["accent"], lw=1.1, alpha=0.7)
    ax.scatter(x, y, s=34, color=PALETTE["primary"], edgecolors="white",
               linewidths=0.7, zorder=3)


def fig_r2_decomposition():
    """Draw figure 8.2.2: R² as a comparison of two sums of squares.

    Uses the first 30 rows of the module-level ``df`` and fits a separate
    ``sales ~ TV`` OLS model on that subsample. The left panel shows the
    deviations from the sample mean (``SS(mean)``), the right panel the
    deviations from the fitted line (``SS(fit)``). The image is written to
    ``OUT / "fig_8_2_2_r2_decomposition.png"``.

    Returns:
        The subsample R², computed as ``1 - SS(fit) / SS(mean)``.
    """
    sub = df.iloc[:30]
    x, y = sub["TV"].to_numpy(), sub["sales"].to_numpy()
    ybar = y.mean()
    m = smf.ols("sales ~ TV", data=sub).fit()
    yhat = m.fittedvalues.to_numpy()
    ss_mean = float(((y - ybar) ** 2).sum())  # total sum of squares around the mean
    ss_fit = float(((y - yhat) ** 2).sum())   # residual sum of squares around the fit
    r2 = 1 - ss_fit / ss_mean

    apply_style()
    fig, (ax_mean, ax_fit) = plt.subplots(1, 2, figsize=(11, 4.2), sharey=True)
    xs = np.linspace(x.min(), x.max(), 100)

    # Left panel: total variation around the mean.
    ax_mean.axhline(ybar, color=PALETTE["muted"], lw=2, label=f"均值 ȳ = {ybar:.1f}")
    _draw_points_with_residuals(ax_mean, x, y, ybar)
    ax_mean.set_title(f"平庸基准: 只猜均值\nSS(mean) = Σ(y−ȳ)² = {ss_mean:.0f}")
    ax_mean.set_xlabel("电视广告投入 TV")
    ax_mean.set_ylabel("销售额 sales")
    ax_mean.legend(loc="upper left", fontsize=10)

    # Right panel: remaining variation around the fitted line. The y axis is
    # shared, so only the left panel carries the y label.
    ax_fit.plot(xs, m.params["Intercept"] + m.params["TV"] * xs,
                color=PALETTE["primary"], lw=2, label="OLS 拟合线")
    _draw_points_with_residuals(ax_fit, x, y, yhat)
    ax_fit.set_title(f"用上 TV 之后: 围着拟合线\nSS(fit) = Σ(y−ŷ)² = {ss_fit:.0f}")
    ax_fit.set_xlabel("电视广告投入 TV")
    ax_fit.legend(loc="upper left", fontsize=10)

    fig.suptitle(
        f"R² = (SS(mean) − SS(fit)) / SS(mean) = {r2:.3f}"
        f"  →  约 {r2 * 100:.0f}% 的销售变异被 TV 解释掉",
        fontsize=14, fontweight="bold",
    )
    # Leave the top 6% of the canvas free for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.94))
    # Save through the figure handle rather than pyplot's "current figure".
    fig.savefig(OUT / "fig_8_2_2_r2_decomposition.png")
    plt.close(fig)
    return r2


# ---------------------------------------------------------------------------
# 图 8.2.3 —— 均值置信带 vs 预测带 (ISLP get_prediction)
# ---------------------------------------------------------------------------
def fig_pred_vs_conf_bands():
    """Draw figure 8.2.3: mean confidence band versus prediction band.

    Evaluates the module-level ``model`` on 120 evenly spaced TV values
    spanning the observed range of ``df["TV"]``, takes the intervals from
    ``get_prediction(...).summary_frame(alpha=0.05)`` and overlays both
    bands and the fitted line on a scatter of the full data set. The image
    is written to ``OUT / "fig_8_2_3_pred_vs_conf_bands.png"``.

    Returns:
        None.
    """
    tv_grid = np.linspace(df["TV"].min(), df["TV"].max(), 120)
    bands = model.get_prediction(pd.DataFrame({"TV": tv_grid})).summary_frame(alpha=0.05)

    apply_style()
    fig, ax = figure(8.0, 4.4)
    ax.scatter(df["TV"], df["sales"], s=16, color=PALETTE["primary"],
               alpha=0.45, edgecolors="none", label="各城市观测")
    # The obs_ci_* band is drawn first with a lighter alpha so the mean_ci_*
    # band stays visible on top of it.
    ax.fill_between(tv_grid, bands["obs_ci_lower"], bands["obs_ci_upper"],
                    color=PALETTE["accent"], alpha=0.16, label="95% 预测带 (含 ε)")
    ax.fill_between(tv_grid, bands["mean_ci_lower"], bands["mean_ci_upper"],
                    color=PALETTE["accent"], alpha=0.42, label="95% 均值置信带")
    ax.plot(tv_grid, bands["mean"], color=PALETTE["accent"], lw=2.2, label="OLS 拟合线")
    ax.set_title("均值置信带(窄) vs 预测带(宽): 预测带额外吞下了误差 ε")
    ax.set_xlabel("电视广告投入 TV")
    ax.set_ylabel("销售额 sales")
    ax.legend(loc="upper left", fontsize=10)
    fig.tight_layout()
    # Save through the figure handle: this does not depend on the ``figure``
    # helper having registered its figure as pyplot's current one.
    fig.savefig(OUT / "fig_8_2_3_pred_vs_conf_bands.png")
    plt.close(fig)


# ---------------------------------------------------------------------------
# 图 8.2.4 —— 残差 vs 拟合值诊断图
# ---------------------------------------------------------------------------
def fig_residual_vs_fitted():
    """Draw figure 8.2.4: residuals against fitted values.

    Scatters the residuals of the module-level ``model`` against its fitted
    values, with a dashed reference line at zero. The image is written to
    ``OUT / "fig_8_2_4_residual_vs_fitted.png"``.

    Returns:
        None.
    """
    y_fitted = model.fittedvalues.to_numpy()
    residuals = model.resid.to_numpy()

    apply_style()
    fig, ax = figure(7.6, 4.2)
    ax.axhline(0, color=PALETTE["accent"], lw=1.6, ls="--")  # zero-residual reference
    ax.scatter(y_fitted, residuals, s=22, color=PALETTE["primary"], alpha=0.6,
               edgecolors="white", linewidths=0.5)
    ax.set_title("残差 vs 拟合值: 随机散布在 0 线两侧 → 线性 / 同方差大致成立")
    ax.set_xlabel("拟合值 ŷ")
    ax.set_ylabel("残差 e = y − ŷ")
    fig.tight_layout()
    # Save through the figure handle: this does not depend on the ``figure``
    # helper having registered its figure as pyplot's current one.
    fig.savefig(OUT / "fig_8_2_4_residual_vs_fitted.png")
    plt.close(fig)


_RULE = "=" * 64


def _banner(title, first=False):
    """Print a separator rule and a section title; every rule but the first follows a blank line."""
    print(_RULE if first else "\n" + _RULE)
    print(title)


def main():
    """Print the statsmodels workflow output, then generate the three figures.

    Prints, in order: a preview of ``df``, the model summary, the 95%
    coefficient confidence intervals, and the mean and prediction intervals
    at TV = 50 / 150 / 250. It then writes the figures and reports the
    subsample R² and the output directory.
    """
    _banner("数据预览 df.head():", first=True)
    print(df.head().round(2).to_string(index=False))

    _banner("model = smf.ols('sales ~ TV', data=df).fit()")
    print(model.summary())

    _banner("95% 置信区间 model.conf_int():")
    print(model.conf_int().round(4).to_string())

    _banner("get_prediction 对比 均值区间 vs 预测区间 (TV = 50/150/250):")
    new = pd.DataFrame({"TV": [50, 150, 250]})
    frame = model.get_prediction(new).summary_frame(alpha=0.05)
    # Put the TV values in the first column so each printed row shows its input.
    frame.insert(0, "TV", new["TV"].to_numpy())
    cols = ["TV", "mean", "mean_ci_lower", "mean_ci_upper",
            "obs_ci_lower", "obs_ci_upper"]
    print(frame[cols].round(3).to_string(index=False))

    r2 = fig_r2_decomposition()
    fig_pred_vs_conf_bands()
    fig_residual_vs_fitted()
    _banner(f"图 8.2.2 子样本 R² = {r2:.3f}")
    print(f"图片已生成至 {OUT}")


if __name__ == "__main__":
    main()
