Bootstrap

PPO算法学习

1.1 强化学习

图1:强化学习的流程
在这里插入图片描述

  • 强化学习的两个实体:智能体(Agent)环境(Environment)
  • 强化学习中智能体环境之间的交互:
    • 状态空间(State): 指的是环境中所有的状态集合
    • 动作空间(Action): 指的是智能体中所有可能的动作集合
    • 奖励(Reward): 指的是智能体在某一 State 下获得的奖励

如图一所示,智能体与环境的交互过程如下:

  1. 在 t 时刻,环境的状态为 St,在这一状态下得到的奖励为 Rt
  2. 智能体根据 St 和 Rt,采取相应动作 At,在文本续写方面就是输出的第 t+1 个 Token
  3. 当智能体采取 At 后,环境状态变为 St+1,得到相应奖励 Rt+1

目的: 智能体在与环境交互的过程中不断学习,最终找到一个策略,能够根据当前的 State 环境状态和 Reward 奖励反馈,来选择最佳的 Action。

1.2 价值函数

在1.1中,我们谈到了奖励值 Rt ,它表示环境进入状态 St 下的即时奖励。
但如果只考虑即时奖励,目光似乎太短浅了:当下的状态和动作会影响到未来的状态和动作,进而影响到未来的整体收益。
所以,一种更好的设计方式是:**t 时刻状态 s 的总收益 = 身处状态 s 能带来的即时收益 + 从状态s出发后能带来的未来收益。**写成表达式就是:
价值函数图片
解释以上公式:

  • V_t:t 时刻的总收益,注意这个收益蕴涵了“即时”和“未来”的概念
  • R_t : t 时刻的即时收益
  • V_{t+1} :t+1 时刻的总收益,注意这个收益蕴涵了“即时”和“未来”的概念。而 V_{t+1} 对 V_t 来说就是“未来”。
  • γ:折扣因子。它决定了我们在多大程度上考虑将“未来收益”纳入“当下收益”。

2.NLP中的强化学习

我们如何将第一部分强化学习的流程有用到 NLP 任务中呢?换而言之,NLP任务中的智能体和环境(实体)、状态、动作、奖励等等,都是指什么呢?
NLP任务
回想对 NLP 任务做强化学习(RLHF)的目的:
我们希望给模型一个 Prompt,生成符合人类偏好的 Response。再回想模型推理过程:每个时刻 t 只产生一个 Token,先有上一个 Token 才有下一个 Token。

解读上述图像:

  • 1.首先喂给模型一个 Prompt,期望生成符合人类偏好的 Response。
  • 2.在 t 时刻,模型根据上下文(St)生成一个 token,这个 token 对应着强化学习中的动作,即At ,因此不难理解,在NLP语境下,强化学习任务的动作空间就对应着词表。
  • 3.在 t 时刻,模型生成的 At,即生成的一个 token,对应着即时收益 Rt,总收益为 Vt 。这个收益即可以理解为**“对人类喜好的衡量”**。此刻,模型的状态从 “St”——>“St+1”,也就是从“上下文”变为“上下文+新生成的token”

3. RLHF中的四个重要角色

我们从第二部分中已经知道:生成 token At 和对应的即时收益 Rt总收益 Vt 并不是一个模型。那么在RLHF中到底有几个模型?他们是怎么配合做训练的?而我们最终要的是哪个模型?
RLHF中的四个角色
如图所示,在 RLHF-PPO 阶段,一共有四个主要的模型 ,分别是:

  1. Actor Model: 想要训练的目标模型。
  2. Critic Model: 它的作用是评估总收益 Vt。
  3. Reward Model: 作用是计算即时收益 Rt。
  4. Reference Model: 作用是给目标模型增加一些“约束”,防止目标模型训歪(朝不受控制的方向更新,效果可能越来越差)。

其中:

  • Actor/Critic ModelRLHF 阶段是需要训练的(图中给这两个模型加了粗边,就是表示这个含义);而Reward/Reference Model是参数冻结的。
  • Critic/Reward/Reference Model 共同组成了一个“奖励-loss”计算体系(我自己命名的,为了方便理解),我们综合它们的结果计算loss,用于更新Actor和Critic Model
3.1 Actor Model

正如前文所说,Actor就是我们想要训练的目标语言模型。我们一般用SFT阶段产出的SFT模型来对它做初始化。
在这里插入图片描述
我们的最终目的是让 Acto r模型能产生符合人类喜好的 response 。所以我们的策略是,先喂给 Actor 一条 prompt (这里假设 batch_size = 1,所以是1条 prompt ),让它生成对应的response。然后,我们再将 “prompt + response" 送入我们的 “奖励-loss” 计算体系中去算得最后的loss,用于更新 actor 。

3.2 Reference Model

Reference Model(以下简称Ref模型)一般也用SFT阶段得到的SFT模型做初始化,在训练过程中,Ref 模型的参数是冻结的。主要作用是防止 Actor “训歪”,那么它具体是怎么做到这一点的呢?
在这里插入图片描述
“Ref” 模型的目的是 “防止 Actor 训歪”,简而言之,我们希望两个模型的输出分布尽量相似 ,即用KL散度 去衡量分布的相似度。

  • 对于Actor模型: 我们喂给它一个prompt,它正常输出对应的 response。那么 response中每一个 token 肯定有它对应的 log_prob 结果呀,我们把这样的结果记为 log_probs。
  • 对于Ref模型: 我们把Actor生成的"prompt + response"喂给它,那么它同样能给出每个token的log_prob结果,我们记其为ref_log_probs。
  • 那么这两个模型的输出分布相似度就可以用 ref_log_probs - log_probs 来衡量,我们可以从两个方面来理解这个公式:
    • 从直觉上理解: ref_log_probs 越高,说明 Ref 模型对 Actor 模型输出的肯定性越大。即对于某个 St(上下文),输出某个 At (输出的下一个 Token)的概率也很高。这时可以认为Actor模型较Ref模型没有训歪。
      在这里插入图片描述
3.3 Critic Model

Critic Model 用于预测期望总收益 Vt,在训练过程中,Critic Model 和 Actor 一样需要对参数进行更新。 实践中,Critic Model的设计和初始化方式也有很多种,例如和Actor共享部分参数、从RW阶段的Reward Model初始化而来等等。

问题一:训练Actor模型我能理解,但我还是不明白,为什么要单独训练一个Critic模型用于预测收益呢?

因为总收益 Vt 为即时收益和未来收益,在目标模型的训练时,是没有上帝视角来得到未来收益的。所以在 t 时刻,我们不给出客观存在的总收益 Vt,只能训练一个模型去预测它
所以总结来说,在RLHF中,我们不仅要训练模型生成符合人类喜好的内容的能力(Actor),也要提升模型对人类喜好量化判断的能力(Critic)。这就是Critic模型存在的意义。
Critic 模型架构:
在这里插入图片描述

deepspeed-chat 采用了 Reward 模型作为它的初始化(即 Critic Model 的起点就是 Reward Model),所以这里我们也按Reward模型的架构来简单画画它。你可以简单理解成,Reward/Critic模型和Actor模型的架构是很相似的(毕竟输入都一样),同时,在最后一层加了一个 Value Head 层,这是一个简单的线性层,目的是将原始输出映射为单一的 Vt 值(代表Critic模型对 t 时刻以及未来的 response 的收益估计)。

3.4 Reward Model

目的: 用于计算 token At 的即时收益,它就是 RW 阶段所训练的奖励模型。但是在 RLHF 中,参数是冻结的。(Rt 即时收益,其实就是一种事实数据,因为 Token At 以及得到了)

问题一:为什么Critic模型要参与训练,而同样是和收益相关的Reward模型的参数就可以冻结呢?

这是因为,Critic 模型是站在上帝视角的。这个上帝视角有两层含义:

  • 第一点,Reward 模型是经过和“估算收益”相关的训练的,因此在 RLHF 阶段它可以直接被当作一个能产生客观值的模型。
  • 第二点,Reward 模型代表的含义就是“即时收益”,你的 token At 已经产生,因此即时收益自然可以立刻算出

问题二:既然已经用 Critic Model 求得 Vt (总收益)了,为什么还需要用 Reward Model 求得 Rt(即时收益)呢?

因为价值函数为 Vt = Rt + γVt+1,我们不能一味地依靠上帝视角:

  • Critic 预测的 Vt
  • Reward 预测的 Rt 和 Critic 预测的 Vt+1
    结果1全靠预测,而结果2中的 Rt 才是事实依据。

我们知道 Critic 模型也是参与参数更新的,在训练Critic 模型时,我们可以使用均方误差(Mean Squared Error, MSE)作为损失函数。MSE计算的是模型预测的收益与真实收益之间的差异(上帝视角的客观收益-Critic模型预测的收益)。这里的“真实收益”指的是从上帝视角看到的客观收益,即如果模型能够完美地了解环境和所有可能的未来状态,它能够计算出的收益。但是上帝视角的客观收益我们是不知道的,只能用已知事实数据去逼近它,所以我们就用 Rt + γ*Vt+1 来做近似,这就是 Rt 和 Vt 同时存在的意义。

4. RLHF 中 loss 的计算

到目前为止,我们已经基本了解了RLHF的训练框架,以及其中的四个重要角色(训练一个RLHF,有4个模型在硬件上跑,可想而知对存储的压力)。在本节中,我们一起来解读RLHF的loss计算方式。在解读中,我们会再一次理一遍RLHF的整体训练过程,填补相关细节。在这之后,我们就可以来看代码解析了。

在第三部分的讲解中,我们知道Actor和Critic模型都会做参数更新,所以我们的loss也分成2个

  • Actor Loss: 用于评估Actor是否产生了符合人类喜好的结果,将作用于Actor的 BWD上。
  • Critic Loss: 用于评估Critic是否正确预测了人类的喜好,将作用于Critic的BWD上。
4.1 Actor Loss 的设计
  • Actor 的输入是当前上下文 St,输出是 Token At,即第 t+1 个 Token【P(At|St)】。
  • Critic 模型根据 St,At,输出总收益的预测 Vt。

因此我们可以将 Loss 设计为
我们希望最小化 actor_loss
在这里插入图片描述
这个设计的直观解释:

  • 当 Vt > 0 时,意味着 Critic 对 Actor 当前采取的动作给了正向反馈,因此我们就需要在训练迭代中提高 P(At|St),起到减小 loss 的作用。
  • 当 Vt < 0 时,意味着 Critic 对 Actor 当前采取的动作给了负向反馈,因此我们需要在训练中降低 P(At|St)。

一句话总结:这个 loss 设计的含义是,对上下文 St 而言,如果 token At 产生的收益较高,那就增大它出现的概率,否则降低它出现的概率。

4.2 引入优势(Advantage)

在开始讲解之前,我们举个小例子:
假设在王者中,中路想支援发育路,这时中路有两种选择:1. 走自家野区。2. 走大龙路。
中路选择走大龙路,当她做出这个决定后,Critic告诉她可以收1个人头。结果,此刻对面打野正在自家采灵芝,对面也没有什么苟草英雄,中路一路直上,最终收割2个人头。
因为实际收割的人头比预期要多1个,中路尝到了甜头,所以她增大了“支援发育路走大龙路”的概率。
这个多出来的“甜头”,就叫做“优势”(Advantage)。

对于 NLP 任务来说,在 St -> At 情况下,如果 Critic 对 At 的收益预测为 Vt,但实际总收益为 Rt + γ * Vt+1,因此我们定义优势为:
在这里插入图片描述
然后我们用 Advt 去替换 Vt,则此刻 actor_loss 为:
在这里插入图片描述

4.3 Rt 即时收益的设计

总结一下,到目前为止,我们的 actor_loss 形式为:
在这里插入图片描述
那么,Rt 表示为每个 Actor 产生出 token At 带来的即时收益
在这里插入图片描述
但在deepspeed-chat的RLHF实践中,对 Rt 做了另一种设计:
在这里插入图片描述

  • kl_ctl: 常量,可以理解成是一个控制比例的缩放因子,在deepspeed-chat中默认设为0.1。
  • log(P(At|St)) - log(Pref(At|St)): 这一项你是不是非常眼熟,这就是我们在3.2部分介绍的 Actor 和 Ref 模型间的 KL 散度呀,写成更容易理解的形式,就是 ref_log_probs - log_probs。在3.2中我们说过,为了防止模型训歪,我们需要把这个 KL 散度加入 loss 计算中,所以这里我们就在做这件事。

基于上述描述,我们可以把 Rt 设计的原因归于:

  • 当 t ≠ T,我们更加关系 Actor 是否有在 Ref 的约束下输出 Token At,这就是即时收益。

也就是 t != T 的时候,我们更关心的是Actor模型是否在参考模型(Ref)的约束下生成了token At,即训练的方向是正确的。这里的 T 通常指的是序列的最后一个时间步。

  • 当t = T,我们不仅关心 Actor 是否遵从了 Ref 的约束输出了 Token At,同时 care 真正的即时收益 Rt

当 t=T 时,我们不仅关心 Actor 模型是否遵循了 Ref 模型的约束,也关心真正的即时收益 Rt。这是因为在最后一个时间步,我们有了完整的 prompt(提示)和response(响应),可以评估整个交互的奖励。
在这里插入图片描述

def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                        action_mask):
        """
        reward_function:计算最终的reward分数
        复习一下几个相关参数的默认值:
        self.kl_ctl = 0.1
        self.clip_reward_value = 5
        
        对于batch中的某个prompt来说,它最终的reward分数为:
        (1) 先计算actor和ref_model的logit相似度: -self.kl_ctl * (log_probs - ref_log_probs)
            其实写成self.kl_ctl * (ref_log_probs - log_probs)更好理解些
            这个值越大,说明ref_model对actor生成的结果的认可度越高(即表明rlhf没有训歪),
            没有训歪的情况下我们也应该给模型一些奖励,这个奖励就是self.kl_ctl * (ref_log_probs - log_probs)
            
        (2)由于我们只取最后一个token对应位置的分数作为reward_score,因此我们只需要:
            self.kl_ctl * (ref_log_probs - log_probs)的最后一位 + reward_score
         
         (3) 同时我们对reward_score也做了大小限制,最大不超过self.clip_reward_value(超过统一给成self.clip_reward_value),
             最小不低于-self.clip_reward_value(低于统一给成-self.clip_reward_value)
        
         (4) 最后返回的rewards大小为:(batch_size, 各条数据的长度),对batch中的每条数据来说:
             - response的最后一位:self.kl_ctl * (ref_log_probs - log_probs)的最后一位 + reward_score
             - response的其余位置:self.kl_ctl * (ref_log_probs - log_probs)
        
        """

        kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
        rewards = kl_divergence_estimate
        # ---------------------------------------------------------------------------------------------------
        # response开始的位置:
        # prompts 是输入的提示(prompt),经过 padding 处理后,每个样本的长度一致。
        # prompts.shape[1] 是每个样本的长度。
        # start = prompts.shape[1] - 1 表示 response 的起始位置是最后一个 token 的位置(假设最后一个 token 是 response 的第一个 token)
        # ---------------------------------------------------------------------------------------------------
        start = prompts.shape[1] - 1
        # ---------------------------------------------------------------------------------------------------
        # response结束的位置
        # action_mask[:, start:] 提取从 response 开始位置到结束的所有掩码值。即遍历所有的行,从第start列开始取
        # .sum(1) 对每个样本的掩码值求和,得到 response 的长度。
        # (所以这里end是加s的,ends的尺寸是(batch_size,)
        # ---------------------------------------------------------------------------------------------------
        ends = start + action_mask[:, start:].sum(1) + 1
        # ---------------------------------------------------------------------------------------------------
        # 对rewards_score做限制
        # ---------------------------------------------------------------------------------------------------
        reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                  self.clip_reward_value)
        # ---------------------------------------------------------------------------------------------------
        # 对每个样本的响应部分(response)的最后一个 token 的奖励值进行更新。
        # for j in range(batch_size):遍历批次中的每个样本,j 是样本的索引。
        # rewards[j, start:ends[j]]:start 是响应部分的起始位置(固定值)。ends[j] 是第 j 个样本的响应部分的结束位置(每个样本可能不同)。start:ends[j] 表示提取第 j 个样本的响应部分的奖励值。
[-1]:表示响应部分的最后一个 token 的位置。
        batch_size = log_probs.shape[0]
        for j in range(batch_size):
            rewards[j, start:ends[j]][-1] += reward_clip[j] # 

        return rewards

强化学习中奖励分段计算过程:

  1. 计算 KL 散度估计值
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
  1. 初始化奖励数组
rewards = kl_divergence_estimate
  1. 确定响应部分的起始和结束位置
start = prompts.shape[1] - 1
ends = start + action_mask[:, start:].sum(1) + 1
  1. 对外部奖励分数进行限制
reward_clip = torch.clamp(reward_score, -self.clip_reward_value, self.clip_reward_value)
  1. 更新响应部分的最后一个 token 的奖励值
batch_size = log_probs.shape[0]
for j in range(batch_size):
    rewards[j, start:ends[j]][-1] += reward_clip[j]
4.4 优势函数的设计

好,再总结一下,目前为止我们的actor_loss为:

actor_loss = -Advt * logP(At|St)

其中,Advt = Rt + γ * V(t+1) - Vt
同时,我们对 Rt 进行来改造,使其能够衡量 Actor 模型是否遵从了 Ref 模型的约束。

现在我们把改造焦点放在 Advt 上,回想一下,既然对于收益而言,分为即时和未来,那么对于优势而言,是不是也能引入对未来优势的考量呢?这样,我们就可以把 Advt 改写成如下形式:

Advt = (Rt + γ * V(t+1) - Vt) + γ * λ * Adv(t+1)

(熟悉强化学习的朋友应该能一眼看出这是GAE*,这里我们不打算做复杂的介绍,一切都站在直觉的角度理解)
其中,新引入的
也是一个常量,可将其理解为权衡因子,直觉上看它控制了在计算当前优势时对未来优势的考量。(从强化学习的角度上,它控制了优势估计的方差和偏差)
其中,新引入的 λ 也是一个常量,可将其理解为权衡因子,直觉上看它控制了在计算当前优势时对未来优势的考量。(从强化学习的角度上,它控制了优势估计的方差和偏差)

看到这里,你可能想问:这个代表未来优势的 Advt+1,我要怎么算呢?
注意到,对于最后一个时刻 t,它的未来收益 V(t+1) 和未来优势 Adv(t+1) 都是0,也就是 Advt = Rt - Vt,这是可以直接算出来的。而有了 Advt,我们不就能从后往前,通过动态规划的方法,把所有时刻的优势都依次算出来了吗?

 def get_advantages_and_returns(self, values, rewards, start):
        """
        Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
        
        没有引入GAE前的t时刻的优势值:
        detal_t = r_t + gamma * V_t+1 - V_t
        其中:
            - r_t表示t时刻的即时收益
            - V_t+1表示未来时刻的预期收益
            - r_t + gamma * V_t+1可理解成t时刻的实际预期收益
            - V_t可理解成t时刻的预估预期收益(是模型,例如critic model自己估算出来的)
        
        引入GAE后的t时刻的优势值:
        A_t = delta_t + gamma * lambda * A_t+1
        粗暴理解为在t时刻时,不仅考虑当下优势,还考虑了未来的优势
        为了知道A_t, 我们得知道A_t+1,所以在本算法中采取了从后往前做动态规划求解的方法,也即:
        假设T是最后一个时刻,则有A_T+1 = 0, 所以有: A_T = delta_T
        知道了A_T, 就可以依次往前倒推,把A_t-1, A_t-2之类都算出来了
        
        引入GAE后t时刻的实际预期收益
        returns_t = A_t + V_t
                  = delta_t + gamma * lambda * A_t+1 + V_t
                  = r_t + gamma * V_t+1 - V_t + gamma * lambda * A_t+1 + V_t
                  = r_t + gamma * (V_t+1 + lambda * A_t+1)
        
        注意,这里不管是advantages还是returns,都只算response的部分
        """
        
        # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
        lastgaelam = 0
        advantages_reversed = []
        length = rewards.size()[-1]
        # 注意这里用了reversed,是采取从后往前倒推计算的方式
        for t in reversed(range(start, length)):
            nextvalues = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1) # 优势
        returns = advantages + values[:, start:] # 实际收益
        # values: 预期收益
        return advantages.detach(), returns
4.4 PPO-epoch: 引入新约束
;