"""第 2 章 描述统计 / 2.4 均值

生成图表:
    fig_2_4_1_balance.png        均值作为重心示意
    fig_2_4_2_means_compare.png  四种均值在偏态数据上的对比
    fig_2_4_3_geom_path.png      投资回报: 几何 vs 算术

运行:
    python docs/assets/scripts/ch02_descriptive/04_mean.py
"""
from __future__ import annotations
import sys, pathlib
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parent.parent))
from _shared.plot_style import apply_style, PALETTE, figure

import numpy as np
import matplotlib.pyplot as plt

# Output directory: "figures/ch02_descriptive" under the third ancestor
# directory of this script (with the run path given in the module docstring
# that is docs/assets).  Created eagerly so the savefig calls below never hit
# a missing directory.
OUT = pathlib.Path(__file__).resolve().parents[2].joinpath(
    "figures", "ch02_descriptive"
)
OUT.mkdir(parents=True, exist_ok=True)


# ---------------------------------------------------------------------------
# 图 2.4.1  均值作为重心
# ---------------------------------------------------------------------------
def fig_balance() -> None:
    """Draw figure 2.4.1: the arithmetic mean as the balance point of the data.

    Five sample values are drawn as numbered balls above a number line, a
    triangular fulcrum sits under the line at the sample mean, and one arrow
    per observation points from that observation to the mean.  The image is
    written to ``OUT / "fig_2_4_1_balance.png"``.
    """
    data = np.array([2, 4, 6, 8, 15], dtype=float)
    mean = data.mean()
    fig, ax = figure(7.2, 2.4)

    # Number line: only the bottom spine stays visible, no y ticks, no grid.
    ax.axhline(0, color=PALETTE["axis"], lw=1.2)
    ax.set_xlim(-1, 17)
    ax.set_ylim(-1.4, 1.6)
    ax.set_yticks([])
    ax.spines[["left", "right", "top"]].set_visible(False)
    ax.grid(False)
    ax.set_xticks(range(0, 17, 2))

    # Data points (balls), labelled with their 1-based position in `data`
    # and tied to the number line by a faint stem.
    for i, x in enumerate(data, 1):
        ax.scatter(x, 0.35, s=420, color=PALETTE["primary"], zorder=3,
                   edgecolors="white", linewidths=1.5)
        ax.text(x, 0.35, str(i), color="white", ha="center", va="center",
                fontsize=11, fontweight="bold", zorder=4)
        ax.plot([x, x], [0, 0.35], color=PALETTE["primary"], lw=1, alpha=0.4)

    # Triangular fulcrum: apex just below the number line at the mean.
    apex_y, base_y, half_width = -0.05, -0.7, 0.6
    ax.plot([mean, mean - half_width, mean + half_width, mean],
            [apex_y, base_y, base_y, apex_y],
            color=PALETTE["accent"], lw=2)
    # Fill the triangle polygon itself.  fill_between with two x values and
    # constant bounds would shade the whole bounding rectangle instead.
    ax.fill([mean, mean - half_width, mean + half_width],
            [apex_y, base_y, base_y],
            color=PALETTE["accent"], alpha=0.85)
    ax.text(mean, -1.1, f"x̄ = {mean:.1f}", color=PALETTE["accent"],
            ha="center", fontsize=12, fontweight="bold")

    # Deviation arrows: each one runs from an observation to the mean.
    for x in data:
        ax.annotate("", xy=(mean, 0.9), xytext=(x, 0.9),
                    arrowprops=dict(arrowstyle="->", color=PALETTE["muted"], lw=1))
    ax.text(8.5, 1.25, "所有偏差之和 = 0", color=PALETTE["text"],
            ha="center", fontsize=11)

    ax.set_title("均值 = 数据的「物理重心」")
    plt.savefig(OUT / "fig_2_4_1_balance.png")
    plt.close(fig)


# ---------------------------------------------------------------------------
# 图 2.4.2  几种均值在偏态数据上的对比
# ---------------------------------------------------------------------------
def fig_means_compare() -> None:
    """Draw figure 2.4.2: four "averages" of one right-skewed salary sample.

    The synthetic sample has 95 draws centred on 7 and 5 draws centred on 80
    (thousand yuan), floored at 3.  It is summarised by the arithmetic mean,
    the median, a 5 % trimmed mean and the geometric mean.  The left panel
    shows the bulk of the distribution with the four statistics as vertical
    lines; the right panel shows the high-salary tail.  The image is written
    to ``OUT / "fig_2_4_2_means_compare.png"``.
    """
    rng = np.random.default_rng(42)
    salaries = np.r_[rng.normal(7, 1.2, 95), rng.normal(80, 15, 5)]  # unit: thousand yuan
    salaries = np.clip(salaries, 3, None)

    arith = salaries.mean()
    median = np.median(salaries)
    # Trimmed mean: drop one twentieth (5 %) of the observations at each end.
    # The explicit end index keeps the slice valid even when n_cut is 0.
    n_cut = salaries.size // 20
    trimmed = np.mean(np.sort(salaries)[n_cut:salaries.size - n_cut])
    geom = np.exp(np.mean(np.log(salaries)))

    fig, (ax1, ax2) = plt.subplots(
        1, 2, figsize=(9.6, 4.2),
        gridspec_kw={"width_ratios": [3, 1], "wspace": 0.05},
    )
    # Kept from the original: re-apply the shared style after the axes exist.
    # NOTE(review): what this changes is defined in _shared.plot_style.
    apply_style()

    # Main panel: the bulk of the sample (3 to 15 thousand yuan), fine bins.
    split = 15
    main_mask = salaries <= split
    ax1.hist(salaries[main_mask], bins=np.arange(3, split + 0.5, 0.5),
             color=PALETTE["primary"], alpha=0.7, edgecolor="white")
    for v, label, color in [
        (arith,   f"算术平均  {arith:.1f}",  PALETTE["series"][0]),
        (median,  f"中位数    {median:.1f}", PALETTE["series"][1]),
        (trimmed, f"截尾均值  {trimmed:.1f}", PALETTE["series"][2]),
        (geom,    f"几何平均  {geom:.1f}",    PALETTE["series"][3]),
    ]:
        # A statistic beyond the panel is pinned to the right edge and its
        # legend entry gets an arrow to show that it lies further right.
        if v <= split:
            ax1.axvline(v, color=color, lw=2.2, ls="--", label=label)
        else:
            ax1.axvline(split, color=color, lw=2.2, ls="--",
                        label=label + " →")
    ax1.set_xlim(3, split)
    ax1.set_xlabel("月薪（千元）")
    ax1.set_ylabel("人数")
    # Head count comes from the mask, the same way the tail title is built.
    ax1.set_title(f"主体分布（{int(main_mask.sum())} 人）")
    ax1.legend(frameon=False, loc="upper right", fontsize=10)
    ax1.grid(True, axis="y", alpha=0.3)

    # Side panel: the high-salary tail.
    tail = salaries[~main_mask]
    tail_lo, tail_hi, tail_step = 40, 110, 10
    # np.arange excludes its stop value, so extend it by one step; otherwise
    # the last edge is 100 and samples in (100, 110] are left out of the bars.
    ax2.hist(tail, bins=np.arange(tail_lo, tail_hi + tail_step, tail_step),
             color=PALETTE["accent"], alpha=0.75, edgecolor="white")
    ax2.set_xlim(tail_lo, tail_hi)
    ax2.set_xlabel("月薪（千元）")
    ax2.set_title(f"高薪尾巴（{len(tail)} 人）")
    ax2.set_yticks(range(0, max(3, len(tail) + 1)))
    ax2.grid(True, axis="y", alpha=0.3)

    fig.suptitle("同一份工资数据上的四种「平均」", y=1.0)
    # Work on `fig` directly instead of relying on pyplot's current figure.
    fig.tight_layout()
    fig.savefig(OUT / "fig_2_4_2_means_compare.png", bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------------
# 图 2.4.3  投资回报: 几何 vs 算术
# ---------------------------------------------------------------------------
def fig_geom_path() -> None:
    """Draw figure 2.4.3: compounding with the geometric vs arithmetic mean.

    Starting from 100, three yearly returns (+10 %, -5 %, +20 %) give the
    actual balance path.  Two constant-rate paths are overlaid: one growing
    at the arithmetic mean of the returns, one at their geometric mean.  Each
    point of the actual path is annotated with its value.  The image is
    written to ``OUT / "fig_2_4_3_geom_path.png"``.
    """
    yearly_returns = np.array([0.10, -0.05, 0.20])
    growth = 1 + yearly_returns
    t = np.arange(0, 4)

    # Average yearly rate, computed both ways.
    mean_arith = yearly_returns.mean()
    mean_geom = np.prod(growth) ** (1 / len(growth)) - 1

    # Balance at t = 0..3: the real compounding path first.
    balance = 100 * np.r_[1, np.cumprod(growth)]

    fig, ax = figure(7.0, 4.0)

    # (values, format string, colour, line width, marker size, legend label)
    curves = [
        (balance, "o-", PALETTE["primary"], 2.4, 8,
         "实际本金路径"),
        (100 * (1 + mean_arith) ** t, "s--", PALETTE["accent"], 1.8, 7,
         f"用算术平均 {mean_arith*100:.2f}% 推算（错）"),
        (100 * (1 + mean_geom) ** t, "^:", PALETTE["series"][2], 1.8, 7,
         f"用几何平均 {mean_geom*100:.2f}% 推算（对）"),
    ]
    for values, fmt, colour, width, marker_size, legend_label in curves:
        ax.plot(t, values, fmt, color=colour, lw=width,
                markersize=marker_size, label=legend_label)

    # Print the actual balance just above each of its markers.
    for year, amount in zip(t, balance):
        ax.annotate(
            f"{amount:.1f}",
            xy=(year, amount),
            xytext=(0, 10),
            textcoords="offset points",
            ha="center",
            fontsize=10,
            color=PALETTE["text"],
        )

    ax.set_xticks(t)
    ax.set_xlabel("年")
    ax.set_ylabel("本金（元）")
    ax.set_title("100 元起步, 三年收益 +10% / -5% / +20%")
    ax.legend(frameon=False, loc="upper left")
    plt.savefig(OUT / "fig_2_4_3_geom_path.png")
    plt.close(fig)


if __name__ == "__main__":
    # Apply the shared style once, then render the figures in section order.
    apply_style()
    for render in (fig_balance, fig_means_compare, fig_geom_path):
        render()
    print(f"图片已生成至 {OUT}")
