大模型持续学习:解决灾难性遗忘的挑战

训练优化大模型持续学习openstarry.com

大模型持续学习:解决灾难性遗忘的挑战

大语言模型一旦训练完成,如果直接在新任务上微调,就会"忘记"之前学到的知识。这个现象被称为灾难性遗忘(Catastrophic Forgetting),是持续学习领域最核心的挑战。

什么是灾难性遗忘?

灾难性遗忘是指神经网络在学习新任务后,对旧任务的性能急剧下降的现象。

灾难性遗忘的典型场景:

阶段 1:模型在英文数据上预训练
  英文理解准确率:95%
  中文理解准确率:0%(未见过)

阶段 2:在中文数据上微调
  英文理解准确率:35%  ← 暴跌 60%!
  中文理解准确率:90%

问题:模型在学习新知识时,遗忘了旧知识

这就像一个人学了中文后突然忘了英文——但人类不会这样。大脑天然支持持续学习,而神经网络却做不到。


为什么会发生灾难性遗忘?

根本原因:神经网络的参数是全局共享的。

假设网络有 3 层,每层权重 W1, W2, W3

任务 A 训练后:
  W1* = 最优于任务 A 的权重
  W2* = 最优于任务 A 的权重
  W3* = 最优于任务 A 的权重

任务 B 训练时:
  梯度更新方向只考虑任务 B 的损失
  W1 = W1* + Δ1(Δ1 为任务 B 的梯度)
  
  问题:Δ1 可能破坏 W1* 中对任务 A 重要的结构

主流解决方案

1. 弹性权重巩固(EWC)

核心思想:重要参数应该少改,不重要参数可以多改

# EWC 的核心公式
def ewc_loss(model, old_data, importance_weights):
    """
    importance_weights: 每个参数对旧任务的重要性
    """
    total_loss = new_task_loss
    
    for param_name, param in model.named_parameters():
        # 重要性高的参数偏离旧值越多,惩罚越大
        penalty = importance_weights[param_name] * (param - old_params[param_name])**2
        total_loss += lambda_reg * penalty
    
    return total_loss

工作流程

  1. 训练任务 A 后,计算每个参数的重要性(用 Fisher 信息矩阵估计)
  2. 训练任务 B 时,损失函数增加正则项:参数偏离旧值的加权平方差
  3. 重要参数被"锁定",新任务主要更新不重要的参数

2. 渐进网络(Progressive Networks)

核心思想:新任务增加新模块,旧模块完全冻结

任务 A 训练完成后:
  [Column A: W_a1, W_a2, W_a3]  ← 冻结

任务 B 开始训练:
  [Column A: W_a1, W_a2, W_a3]  ← 冻结,不更新
  [Column B: W_b1, W_b2, W_b3]  ← 新建,可训练
  [横向连接: h_ab1, h_ab2]       ← 连接 A 和 B 的信息

优势:完全消除遗忘(旧参数不变)
劣势:模型随任务数线性增长

3. 经验回放(Experience Replay)

核心思想:保留旧任务的部分数据,新旧数据混合训练

经验回放的实现方式:

方法 1:存储原始数据
  保留 5% 的旧任务数据
  新任务训练时,batch 中混合 50% 旧数据 + 50% 新数据

方法 2:生成伪数据(GEM、A-GEM)
  不存储原始数据
  用生成模型或梯度约束模拟旧数据

方法 3:提示微调(Prompt Tuning)
  为每个任务学习一个小的提示向量
  推理时根据任务切换提示

大模型时代的持续学习

大语言模型(如 GPT、LLaMA)的持续学习面临特殊挑战:

挑战 具体表现 应对策略
模型规模数百亿参数,无法冻结所有旧层参数高效微调(LoRA、Adapter)
多任务混合对话、代码、推理等多种能力指令微调 + 人类反馈
实时更新需要持续吸收新知识RAG(检索增强生成)
安全对齐新能力不能破坏安全约束RLHF + 约束微调

LoRA:大模型持续学习的利器

Low-Rank Adaptation(LoRA)是当前大模型持续学习的主流方案:

# LoRA 的核心思想
# 原始权重 W 冻结,只训练低秩分解矩阵 A 和 B

# 冻结的原始权重
W_frozen = pretrained_weight  # 不更新

# 可训练的低秩矩阵
A = nn.Parameter(torch.randn(d, r))  # r << d
B = nn.Parameter(torch.randn(r, d))

# 前向传播
output = x @ W_frozen + x @ A @ B  # A @ B 是增量

# 优势:
# - 每个任务只需存储 A 和 B(通常只有原始参数的 0.1%)
# - 多个任务的 LoRA 可以热切换
# - 完全消除灾难性遗忘

实际应用案例

1. ChatGPT 的持续进化

ChatGPT 通过持续微调不断获得新能力:从 GPT-3.5 到 GPT-4,从纯文本到多模态,每次更新都基于前一版本继续训练。

2. 医疗 AI 的知识更新

医疗知识不断更新,模型需要持续学习新的诊疗指南,同时不忘记已掌握的医学知识。

3. 个性化推荐系统

用户兴趣随时间变化,模型需要在保持通用推荐能力的同时,持续适应用户的新偏好。


当前研究前沿


总结

灾难性遗忘是神经网络与人类学习能力之间的关键差距。从 EWC 到 LoRA,研究者正在逐步缩小这个差距。对于大语言模型,持续学习不仅是一个技术问题,更是实现真正通用 AI 的必经之路。理解这些方法的原理,能帮助你在实际项目中设计更智能的模型更新策略。

以 AI 之力,筑未来之境

现在注册,立即免费获赠 200 次大模型调用权益

免费注册 →