Original ODD 泛函的范

本篇将介绍ChatGPT的训练流程,及其背后的关键技术之一:RLHF(Reinforcement Learning from Human Feedback)。


(资料图片仅供参考)

传统 RL 通过 Agent 与环境的交互来实现某个任务上的策略最优化,往往需要从数百万次的试错中寻求最优策略,既耗时又耗费资源。为了解决这个问题,研究人员探索如何利用人类知识和反馈来加速RL过程。

RLHF 核心思想是结合强化学习和人类反馈,使智能体能够在更短的时间内学到最优策略。通过利用人类知识和经验,智能体可以避免掉进陷阱,缩小搜索空间,从而更快地找到最优策略。同时,人类反馈也可以帮助智能体更好地理解任务的目标和约束条件。

此外,随着 ChatGPT 的大火,在各行各业掀起了 AI 热潮,使人工智能技术应用进入了一个新阶段。各大网友将ChatGPT应用到自己的生活、工作中,如利用ChatGPT帮忙问诊、做数据分析、写周报、写代码,甚至还有写论文的。

从实际体验中可以感受到,ChatGPT 可以对各种问题给出比较流畅且质量较高的回答。其背后核心技术就是LLM+RLHF:利用RLHF技术,以RL方式结合人类反馈来优化LLM,使LLM的结果更契合人类认知。

作为ML从业者,除了体验ChatGPT给我们工作、生活带来的便利和乐趣外,更要了解其背后的技术原理,本文就带大家一起了解下ChatGPT背后的技术细节。

在具体介绍RLHF之前,先来看下LLM的总体训练流程:

可以将整个训练流程分为3个阶段:

在未标注的大量文本数据上,采用SSL(Self-Supervised Learning)方式(如预测下一个词)来训练LLM,得到预训练好的LLM。

关于SSL的更多内容可以参考之前的文章《A Cookbook of Self-Supervised Learning》,这里就不赘述了。

有了预训练好的LLM之后,我们在高质量的有标注数据上采用Supervised Learning方式对LLM进行微调,即Supervised Fine-Tuning(STF)。 SFT 的目标是优化预训练LLM,使其能生成用户真正想要的结果。

这一点是如何做到的呢?其实,训练过程就是使模型拟合训练数据过程。在SFT中,将训练数据整理成 (promot, response) 格式,向LLM展示如何适当地回应不同Prompts,使模型能拟合在不同Prompts(例如:问答、摘要、翻译)下的Response。STF流程如下图所示:

训练数据是如何标注的呢?在给定一个Prompt之后,由专业的标注人员(通常称为Labeler)来给出相应的Response。这些Labeler通常都是具有一定专业知识的人员,能给出高质量的Response。但反过来,这个标注成本非常昂贵,所以一般只有10K量级的标注量。

OpenAI训练InstructGPT时请了40位Labeler标注了13k的 (prompt, response) 样本对。如下图展示了 (prompt, response) 样本对示例:

有了 (prompt, response) 样本对之后,就可以利用它们来fine-tune预训练好得LLM模型。

从流程图可以看到,RLHF(即图中虚线框部分)包含三大部分:

预训练好的LLM 收集数据,训练奖励模型 (Reward Model,RM) 用RL来微调 LLM

关于预训练LLM的内容在前文已经介绍了,下面主要介绍后面两部分是如何实现的。

通过收集了大量人类对模型在不同Prompts的Response的偏好数据(如喜欢、打分等)之后,我们可以利用这些数据来训练一个RM来拟合人类偏好。从而实现在给定Prompt下,对模型所给出的Response进行打分,来预估人类对这个Response的满意程度。人类反馈数据收集流程如下所示:

使用输入文本数据(最好是生产环境的用户真实数据),由模型生成对应的Response,然后由人类对其进行评估,并给出Reward分(通常是0~5分,也可以是喜欢/不喜欢等)。

有了人类反馈数据后,我们就可以基于这些数据来训练RM。如下图所示:

根据人工对LLM的Response的打分,构造用于训练RM的数据 (sample,reward) 样本对,其中 sample=(prompt, response) 。RM输出的Reward在数值上就代表了人类对改Response的偏好。

引入RM的主要目标是借助RM来模仿人类的奖励标注,从而能在没有人类参与的情况下进行RLHF的离线训练。

需要说明的是:RM可以是一个经过微调的LLM,也可以是仅用人类偏好数据从头训练的LLM。

经过前文的操作,我们已经有了一个经过SFT之后的LLM,和能够评估模型输出好坏(迎合人类偏好)的RM。接下来就是如何使用RL来基于RM优化LLM。其思路是采用RM的输出值作为一个Reward,基于RL的思路进行优化,对问题建模如下:

Agent :一个以Prompt为输入,并输出文本 (或文本的概率分布) 的LLM; Action Space :词表对应的所有Token; Observation Space :输入词元序列(); Reward Function :RM和策略转变约束 (Policy Shift Constraint) 的结合。 Policy Shift Constraint 是一种用于强化学习中的优化技术,旨在减少在训练过程中策略发生显著变化的可能性。在强化学习中,智能体通过不断地与环境交互来学习最优策略。在训练过程中,策略的变化可能会导致学习不稳定,甚至失效。因此,策略转移约束的目的是限制策略的变化幅度,从而提高强化学习算法的稳定性。

如下图所示:

在开始前,复制一份模型并将其权重固定,得到Fronze LM。Fronze LM有助于防止在训练过程中将LM的权重改变过大,避免输出无意义的文本。

在训练过程中,除了RM给出的reward外,还引入了Frozen LM和Trained LM生成文本的差异惩罚项(即Policy shift constraint),来惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中为了得到更高奖励而生成乱码文本来愚弄RM。

在OpenAI、Anthropic和DeepMind的多篇论文中,都采用了Fronze LM和Trained LM出词分布序列之间的 KL散度来作为Policy shift constraint。因此,RL的整体reward为:

其中,和分别表示RM的输出和Fronze LM和Trained LM出词分布序列之间的 KL散度。

最后,使用PPO算法,在RM的辅助下对LM的部分参数迭代更新。如下图所示:

其中:

近端策略优化(Proximal Policy Optimization,PPO) 是一种常见的强化学习算法,用于优化策略函数以最大化累积奖励。在PPO中,策略函数通常表示为神经网络,它将状态作为输入并输出动作的概率分布。PPO的主要优点之一是其稳定性,在强化学习中很受欢迎。它使用两个策略函数来进行训练,一个当前策略和一个旧策略,并使用一个剪切参数来确保新策略不会太远离旧策略。这种剪切机制有助于防止策略跳跃,并使PPO在处理连续动作空间时表现出色。通常使用Adam优化器来更新神经网络参数。

至此,我们已经了解了ChatGPT训练的整体过程,以及RLHF在其所扮演的角色。下图是在训练过程中不同策略所带来的效果提升:

从图中可以看出:

采用SFT能有效提升预训练模型效果; 采用预训练+SFT+RLHF的InstructGPT效果要明显优于其他策略。

下面是几个比较常用的RLHF Repo,感兴趣的可以去动手尝试下:

本文介绍了ChatGPT的总体训练流程,并展开介绍了RLHF的训练过程。RLHF不只适用于LLM的训练,它还对如下两种场景有奇效:

无法构建好的loss function 。如,如何设计一个metrics来衡量模型的输出是否有趣味性? 想要用生产数据进行训练,但又很难为生产数据打标签 。如,如何从ChatGPT中获取带标签的生产数据?必须有人编写正确答案,告诉ChatGPT应该如何回答。

在面对上面两种场景时,我们可以尝试使用RLHF,没准会有意想不到的效果。

泛函的范 日常搬砖,周末更新。关键词:推荐系统、机器学习、经验总结、面试技巧69篇原创内容

公众号

[1] Nathan Lambert, Louis Castricato, Leandro von Werra, Alex Havrilla. Illustrating Reinforcement Learning from Human Feedback (RLHF) Hugging Face. 2022.

[2] Joao Lages. Reinforcement Learning from Human Feedback (RLHF) - a simplified explanation. NebulyAI. 2023.

[3] Chip Huyen. RLHF: Reinforcement Learning from Human Feedback. 2023.

[4] Ben Dickson. What is reinforcement learning from human feedback (RLHF)?. TechTalks. 2023.

[5] Ayush Thakur. Understanding Reinforcement Learning from Human Feedback. Weights & Biases. 2023.

[6] Trlx: /CarperAI/

[7] DeepspeedChat: /microsoft/

[8] ColossalAI-Chat: /hpcaitech/

[9] PaLM-rlhf-pytorch: /lucidrains/

推荐内容