15.5 循环神经网络¶
公交车一站站往前开,小率拿着路线图记录:起点、公园站、学校站、商场站……每到一站,他都会把新信息写进卡片,同时保留一部分旧信息。均哥说:“处理序列时,模型也要一边看当前输入,一边记住过去。”
循环神经网络(Recurrent Neural Network, RNN)就是为序列数据设计的模型。文本、语音、股票价格、心电信号,都有“前后顺序”。
15.5.1 RNN 把过去压进隐藏状态¶
RNN 在每个时间步 \(t\) 接收当前输入 \(\mathbf{x}_t\),并更新隐藏状态(Hidden State)\(\mathbf{h}_t\):
\[
\mathbf{h}_t = \tanh(\mathbf{W}_x\mathbf{x}_t + \mathbf{W}_h\mathbf{h}_{t-1} + \mathbf{b})
\]
输出可以由隐藏状态得到:
\[
\hat{\mathbf{y}}_t = f(\mathbf{W}_y\mathbf{h}_t + \mathbf{c})
\]
隐藏状态像一张随车更新的路线卡:它不能保存全部历史,只能保存模型认为有用的摘要。
15.5.2 LSTM 用门来决定记住什么¶
普通 RNN 处理长序列时,早期信息容易被冲淡。长短期记忆网络(Long Short-Term Memory, LSTM)加入细胞状态(Cell State)和三个门:
| 门 | 作用 | 公交路线类比 |
|---|---|---|
| 遗忘门 | 丢掉不重要旧信息 | 不再记经过的小路口 |
| 输入门 | 写入新的重要信息 | 记录关键换乘站 |
| 输出门 | 决定当前要拿出什么 | 给下一步导航的摘要 |
门不是人工规则,而是可学习的函数。模型会从数据中学会哪些信息该留、哪些该丢。
15.5.3 看一眼 RNN 的张量形状¶
import torch
import torch.nn as nn
batch_size = 4
seq_len = 6
input_dim = 3
hidden_dim = 8
x = torch.randn(batch_size, seq_len, input_dim)
rnn = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
out, h_last = rnn(x)
print(out.shape) # [4, 6, 8]
print(h_last.shape) # [1, 4, 8]
out 保存每个时间步的隐藏状态;h_last 是最后一个时间步的摘要。
序列模型的边界
RNN/LSTM 按时间步逐个处理,天然适合短到中等长度序列。但当序列很长、需要任意位置互相关注时,Transformer 往往更高效。
小率的笔记本
RNN 的核心是“当前输入 + 过去摘要”。LSTM 在这个基础上加入门控机制,让模型更可靠地保存长期信息。只要数据有顺序,先问自己:过去的信息是否会影响现在?
