Attention 自测题
完成以下题目检验你的理解程度
🎮 交互式自测(推荐)
Q1
Self-Attention 中为什么要除以 √d_k?
Q2
Q、K、V 三个矩阵分别代表什么含义?
Q3
Multi-Head Attention 的主要优势是什么?
Q4
GQA(Grouped Query Attention)的作用是什么?
Q5
因果掩码(Causal Mask)的作用是什么?
Q6
repeat_kv 函数在 GQA 中的作用是什么?
Q7
Flash Attention 相比标准实现的优势是什么?
🎯 综合问答题
Q8: 实战问题
假设你在调试一个 Attention 模块,发现所有 token 的注意力权重几乎均匀分布(每个位置都是 ~1/seq_len),这可能是什么问题?如何解决?
点击查看参考答案
可能的原因:
Q 和 K 没有正确初始化:
- 投影矩阵初始值太小
- 导致 Q·K 分数接近 0
- softmax(0, 0, ..., 0) ≈ 均匀分布
缩放因子问题:
- 除以了过大的值
- 或忘记开根号(除以 d_k 而不是 √d_k)
head_dim 设置错误:
- head_dim 过大导致点积方差过大
- 但这通常会导致极端分布,不是均匀分布
没有学习到有意义的模式:
- 训练数据问题
- 模型容量不足
诊断方法:
python
# 检查 Q·K 分数(softmax 之前)
scores = torch.matmul(xq, xk.transpose(-2, -1)) / math.sqrt(head_dim)
print(f"scores mean: {scores.mean()}, std: {scores.std()}")
# 正常情况:mean ≈ 0, std ≈ 1
# 问题情况:std 太小(接近 0)解决方案:
检查投影矩阵初始化:
python# 使用 Xavier 或 Kaiming 初始化 nn.init.xavier_uniform_(self.wq.weight)验证缩放因子:
python# 确保是 sqrt(head_dim),不是 head_dim scale = math.sqrt(self.head_dim)可视化注意力:
pythonplt.imshow(attn_weights[0, 0].detach().numpy()) plt.title("Attention weights") plt.colorbar()
Q9: 概念理解
为什么 Self-Attention 中 Q、K、V 都来自同一个输入 x,但还需要三个不同的投影矩阵?直接用 x 做 Q、K、V 不行吗?
点击查看参考答案
直接用 x 的问题:
如果 Q = K = V = x,则:
python
scores = x @ x.T # 自己和自己的点积这相当于计算每个 token 与其他 token 的"余弦相似度"(内积)。
问题:
对称性:token_i 对 token_j 的注意力 = token_j 对 token_i 的注意力
- 但语言中关系往往是不对称的
- "猫吃鱼":猫应该注意鱼,但鱼不一定要注意猫
表达能力有限:
- 只能表达"相似度"这一种关系
- 无法学习"主谓关系"、"修饰关系"等
三个投影的意义:
python
Q = x @ W_Q # "作为查询者,我关注什么特征?"
K = x @ W_K # "作为被查询者,我展示什么特征?"
V = x @ W_V # "我实际要传递什么内容?"优势:
- 非对称性:Q 和 K 不同,允许非对称关系
- 角色分离:查询角度、被查询角度、内容传递可以不同
- 表达能力:可以学习任意复杂的关系模式
类比:
- 图书馆场景:
- 你的问题(Q):用自然语言描述需求
- 书的索引(K):用关键词标签
- 书的内容(V):实际文字
- 三者用不同的"语言",通过匹配找到正确的内容
Q10: 代码理解
解释以下代码中 contiguous() 的必要性:
python
output = output.transpose(1, 2).contiguous().view(batch, seq_len, -1)点击查看参考答案
背景:
PyTorch tensor 有两个概念:
- 存储(Storage):实际的内存布局
- 视图(View):如何解释这块内存
transpose 的行为:
python
# 假设 output 形状是 [batch, n_heads, seq, head_dim]
# 内存布局是按这个顺序排列的
output = output.transpose(1, 2)
# 现在形状是 [batch, seq, n_heads, head_dim]
# 但内存布局没变!只是改变了"视图"问题:
view() 要求 tensor 在内存中是连续的,但 transpose 后内存不连续:
python
# 原始内存顺序(简化示例):
# [head0_pos0, head0_pos1, head1_pos0, head1_pos1, ...]
# transpose 后逻辑顺序:
# [head0_pos0, head1_pos0, head0_pos1, head1_pos1, ...]
# 内存不连续 → view 会报错contiguous() 的作用:
python
output = output.transpose(1, 2).contiguous()
# 1. 重新分配内存
# 2. 按新的逻辑顺序排列数据
# 3. 现在可以安全使用 view性能考虑:
contiguous()需要拷贝内存,有开销- 但这是必要的开销
- 替代方案:
reshape()会自动处理,但不够显式
最佳实践:
python
# 明确知道需要连续内存时,显式调用
output = output.transpose(1, 2).contiguous().view(...)
# 或使用 reshape(隐式处理)
output = output.transpose(1, 2).reshape(...)✅ 完成检查
完成所有题目后,检查你是否达到:
- [ ] Q1-Q7 全对:基础知识扎实
- [ ] Q8 能提出 2+ 诊断方法:具备调试能力
- [ ] Q9 能解释投影矩阵的意义:深刻理解设计原则
- [ ] Q10 能解释 contiguous 的必要性:理解 PyTorch 内存模型
如果还有不清楚的地方,回到 teaching.md 复习,或重新运行实验代码。
🎓 进阶挑战
想要更深入理解?尝试:
修改实验代码:
- 实现一个没有缩放因子的 Attention,观察 softmax 输出
- 实现 MQA(Multi-Query Attention),对比 GQA
- 可视化不同头学到的注意力模式
阅读论文:
- Attention Is All You Need - Transformer 原始论文
- GQA Paper - Grouped Query Attention
- Flash Attention - 高效 Attention 实现
实现变体:
- 实现 Cross-Attention(Q 和 KV 来自不同输入)
- 实现 Sliding Window Attention
- 实现 Sparse Attention
下一步:前往 04. FeedForward 学习前馈网络!