diff --git a/_posts/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.md b/_posts/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.md new file mode 100644 index 00000000..b3a6e865 --- /dev/null +++ b/_posts/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.md @@ -0,0 +1,680 @@ +--- +layout: distill +title: The N Implementation Details of RLHF with PPO +description: Reinforcement Learning from Human Feedback (RLHF) is pivotal in the modern application of language modeling, as exemplified by ChatGPT. This blog post delves into an in-depth exploration of RLHF, attempting to reproduce the results from OpenAI's inaugural RLHF paper, published in 2019. Our detailed examination provides valuable insights into the implementation details of RLHF, which often go unnoticed. + +date: 2024-05-07 +future: true +htmlwidgets: true + +# Anonymize when submitting +# authors: +# - name: Anonymous + + +authors: + - name: Shengyi Costa Huang + affiliations: + name: Hugging Face + - name: Tianlin Liu + affiliations: + name: University of Basel + - name: Leandro von Werra + affiliations: + name: Hugging Face + + +# authors: +# - name: Albert Einstein +# url: "https://en.wikipedia.org/wiki/Albert_Einstein" +# affiliations: +# name: IAS, Princeton +# - name: Boris Podolsky +# url: "https://en.wikipedia.org/wiki/Boris_Podolsky" +# affiliations: +# name: IAS, Princeton +# - name: Nathan Rosen +# url: "https://en.wikipedia.org/wiki/Nathan_Rosen" +# affiliations: +# name: IAS, Princeton + +# must be the exact same name as your blogpost +bibliography: 2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.bib + +# Add a table of contents to your post. +# - make sure that TOC names match the actual section names +# for hyperlinks within the post to work correctly. +# - please use this format rather than manually creating a markdown table of contents. +toc: + - name: Matching Learning Curves + - name: General Implementation Details + - name: Reward Model Implementation Details + - name: Policy Training Implementation Details + - name: PyTorch Adam optimizer numerical issues w.r.t RLHF + - name: Limitations + - name: Conclusion + +# Below is an example of injecting additional post-specific styles. +# This is used in the 'Layouts' section of this post. +# If you use this post as a template, delete this _styles block. +_styles: > + .fake-img { + background: #bbb; + border: 1px solid rgba(0, 0, 0, 0.1); + box-shadow: 0 0px 4px rgba(0, 0, 0, 0.1); + margin-bottom: 12px; + } + .fake-img p { + font-family: monospace; + color: white; + text-align: left; + margin: 12px 0; + text-align: center; + font-size: 16px; + } +--- + + +**Reinforcement Learning from Human Feedback** (RLHF) has been an impactful technique for training modern language models such as ChatGPT. In our quest to research more on RLHF, this blog post closely examines OpenAI’s inaugural RLHF paper published in 2019 together with its open-source codebase at available at [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences). Despite being based on TensorFlow-1, the code base released by OpenAI is very well-evaluated and benchmarked, making it a good place to study RLHF implementation engineering details. + +We aim to: + +1. reproduce OpenAI’s results in stylistic tasks and match the learning curves of [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences), using the modern PyTorch and JAX frameworks in conjunction with HuggingFace Transformers that are predominantly used by the open-source community nowadays; +2. present a checklist of implementation details, similar to the spirit of [*The 37 Implementation Details of Proximal Policy Optimization*](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) and [*Debugging RL, Without the Agonizing Pain*](https://andyljones.com/posts/rl-debugging.html); +3. provide a simple-to-read and minimal reference implementation of RLHF; + +This work is just for educational / learning purposes. For advanced users requiring more features, such as running larger models with parameter-efficient fine-tuning, [*huggingface/trl*](https://github.com/huggingface/trl) would be a great choice. + +- In [Matching Learning Curves](#matching-learning-curves), we show our main contribution: creating a codebase that can reproduce OpenAI’s results in the stylistic tasks and matching learning curves very closely with [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences). +- We then take a technical deep dive into the implementation details that are relevant to reproducing OpenAI’s work. In [General Implementation Details](#general-implementation-details), we talk about basic details, such as how rewards/values are generated and how responses are generated. In [Reward Model Implementation Details](#reward-model-implementation-details), we talk about details such as reward normalization. In [Policy Training Implementation Details](#policy-training-implementation-details), we discuss details such as rejection sampling and reward “whitening”. + - In [**PyTorch Adam optimizer numerical issues w.r.t RLHF**](#pytorch-adam-optimizer-numerical-issues-wrt-rlhf), we highlight a very interesting implementation difference in Adam between TensorFlow and PyTorch, which causes an aggressive update in the model training. +- Next, we examine the effect of training different base models (e.g., gpt2-xl, falcon-1b,) given that the reward labels are produced with `gpt2-large`. +- Finally, we conclude our work with limitations and discussions. + + + + +Here are the important links: + +- 💾 [Our reproduction codebase](https://github.com/vwxyzjn/lm-human-preference-details) +- 🤗 [Demo of RLHF model comparison](https://huggingface.co/spaces/lm-human-preference-details/rlhf-demo) +- 🐝 [All w&b training logs](https://wandb.ai/openrlbenchmark/lm_human_preference_details) + +# Matching Learning Curves + +Our main contribution is to reproduce OpenAI’s results in stylistic tasks, such as sentiment and descriptiveness. As shown in the figure below, our codebase (orange curves) can produce nearly identical learning curves as OpenAI’s codebase (blue curves). + + +
+{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching.png" class="img-fluid" %} +
+ + +## A note on running openai/lm-human-preferences + +To make a direct comparison, we ran the original RLHF code at [*openai/lm-human-preferences*](https://github.com/openai/lm-human-preferences), which will offer valuable metrics to help validate and diagnose our reproduction. We were able to set the original TensorFlow 1.x code up, but it requires a hyper-specific setup: + +- OpenAI’s dataset was partially corrupted/lost (so we replaced them with similar HF datasets, which may or may not cause a performance difference) + - Specifically, its book dataset was lost during OpenAI’s GCP - Azure migration ([https://github.com/openai/lm-human-preferences/issues/17#issuecomment-1044051496](https://github.com/openai/lm-human-preferences/issues/17#issuecomment-1044051496)). We replaced the book dataset with Hugging Face’s `bookcorpus` dataset, which is, in principle, what OpenAI used. +- It can’t run on 1 V100 because it doesn’t implement gradient accumulation. Instead, it uses a large batch size and splits the batch across 8 GPUs, and will OOM on just 1 GPU. +- It can’t run on 8x A100 because it uses TensorFlow 1.x, which is incompatible with Cuda 8+ +- It can’t run on 8x V100 (16GB) because it will OOM +- It can only run on 8x V100 (32GB), which is only offered by AWS as the `p3dn.24xlarge` instance. + +# General Implementation Details + +We now take a technical deep dive into the implementation details that are relevant to reproducing OpenAI’s work. In this section, we talk about basic details, such as how rewards/values are generated and how responses are generated. Here are these details in no particular order: + +1. **The reward model and policy’s value head take input as the concatenation of `query` and `response`** + 1. The reward model and policy’s value head do *not* only look at the response. Instead, it concatenates the `query` and `response` together as `query_response` ([lm_human_preferences/rewards.py#L105-L107](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L105-L107)). + 2. So, for example, if `query = "he was quiet for a minute, his eyes unreadable"`., and the `response = "He looked at his left hand, which held the arm that held his arm out in front of him."`, then the reward model and policy’s value do a forward pass on `query_response = "he was quiet for a minute, his eyes unreadable. He looked at his left hand, which held the arm that held his arm out in front of him."` and produced rewards and values of shape `(B, T, 1)`, where `B` is the batch size, `T` is the sequence length, and `1` is the reward head dimension of 1 ([lm_human_preferences/rewards.py#L105-L107](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L105-L107), [lm_human_preferences/policy.py#L111](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L111)). + 3. The `T` means that each token has a reward associated with it and its previous context. For example, the `eyes` token would have a reward corresponding to `he was quiet for a minute, his eyes`. +2. **Pad with a special padding token and truncate inputs.** + 1. OpenAI sets a fixed input length for query `query_length`; it **pads** sequences that are too short with `pad_token` ([lm_human_preferences/language/datasets.py#L66-L67](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L66-L67)) and **truncates** sequences that are too long ([lm_human_preferences/language/datasets.py#L57](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L57)). See [here](https://huggingface.co/docs/transformers/pad_truncation) for a general introduction to the concept). When padding the inputs, OpenAI uses a token beyond the vocabulary ([lm_human_preferences/language/encodings.py#L56](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/encodings.py#L56)). + 1. **Note on HF’s transformers — padding token.** According to ([transformers#2630#issuecomment-578159876](https://github.com/huggingface/transformers/issues/2630#issuecomment-578159876)), padding tokens were not used during the pre-training of GPT and GPT-2; therefore transformer’s gpt2 models have no official padding token associated with its tokenizer. A common practice is to set `tokenizer.pad_token = tokenizer.eos_token`, but in this work, we shall distinguish these two special tokens to match OpenAI’s original setting, so we will use `tokenizer.add_special_tokens({"pad_token": "[PAD]"})`. + + Note that having no padding token is a default setting for decoder models, since they train with “packing” during pretraining, which means that many sequences are concatenated and separated by the EOS token and chunks of this sequence that always have the max length are fed to the model during pretraining. + 2. When putting everything together, here is an example + + ```python + import transformers + tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + query_length = 5 + texts = [ + "usually, he would", + "she thought about it", + ] + tokens = [] + for text in texts: + tokens.append(tokenizer.encode(text)[:query_length]) + + print("tokens", tokens) + inputs = tokenizer.pad( + {"input_ids": tokens}, + padding="max_length", + max_length=query_length, + return_tensors="pt", + return_attention_mask=True, + ) + print("inputs", inputs) + + """prints are + tokens [[23073, 11, 339, 561], [7091, 1807, 546, 340]] + inputs {'input_ids': tensor([[23073, 11, 339, 561, 50257], + [ 7091, 1807, 546, 340, 50257]]), 'attention_mask': tensor([[1, 1, 1, 1, 0], + [1, 1, 1, 1, 0]])} + """ + ``` + +3. **Adjust position indices correspondingly for padding tokens** + 1. When calculating the logits, OpenAI’s code works by masking out padding tokens properly. This is achieved by finding out the token indices corresponding to the padding tokens ([lm_human_preferences/language/model.py#L296-L297](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L296-L297)), followed by adjusting their position indices correspondingly ([lm_human_preferences/language/model.py#L320](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L320)). + 2. For example, if the `query=[23073, 50259, 50259]` and `response=[11, 339, 561]`, where (`50259` is OpenAI’s padding token), it then creates position indices as `[[0 1 1 1 2 3]]` and logits as follows. Note how the logits corresponding to the padding tokens remain the same as before! This is the effect we should be aiming for in our reproduction. + + ```python + all_logits [[[ -35.28693 -34.2875 -38.16074 ... -41.595802 -41.082108 + -35.36577 ] + [ -35.28693 -34.2875 -38.16074 ... -41.595802 -41.082108 + -35.36577 ] + [ -35.28693 -34.2875 -38.16074 ... -41.595802 -41.082108 + -35.36577 ] + [-111.303955 -110.94471 -112.90624 ... -113.13064 -113.7788 + -109.17345 ] + [-111.51512 -109.61077 -114.90231 ... -118.43514 -111.56671 + -112.12478 ] + [-122.69775 -121.84468 -128.27417 ... -132.28055 -130.39604 + -125.707756]]] (1, 6, 50257) + ``` + + 3. **Note on HF’s transformers — `position_ids` and `padding_side`.** We can replicate the exact logits using Hugging Face’s transformer with 1) left padding and 2) pass in the appropriate `position_ids`: + + ```python + import torch + import transformers + tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + pad_id = tokenizer.pad_token_id + query = torch.tensor([ + [pad_id, pad_id, 23073], + ]) + response = torch.tensor([ + [11, 339, 561], + ]) + temperature = 1.0 + + query = torch.tensor(query) + response = torch.tensor(response).long() + context_length = query.shape[1] + query_response = torch.cat((query, response), 1) + pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2") + def forward(policy, query_responses, tokenizer): + attention_mask = query_responses != tokenizer.pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + input_ids = query_responses.clone() + input_ids[~attention_mask] = 0 + return policy( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + output = forward(pretrained_model, query_response, tokenizer) + logits = output.logits + logits /= temperature + print(logits) + + """ + tensor([[[ -26.9395, -26.4709, -30.0456, ..., -33.2208, -33.2884, + -27.4360], + [ -27.1677, -26.7330, -30.2386, ..., -33.6813, -33.6931, + -27.5928], + [ -35.2869, -34.2875, -38.1608, ..., -41.5958, -41.0821, + -35.3658], + [-111.3040, -110.9447, -112.9062, ..., -113.1306, -113.7788, + -109.1734], + [-111.5152, -109.6108, -114.9024, ..., -118.4352, -111.5668, + -112.1248], + [-122.6978, -121.8447, -128.2742, ..., -132.2805, -130.3961, + -125.7078]]], grad_fn=) + """ + ``` + + 4. **Note on HF’s transformers — `position_ids` during `generate`:** during generate we should not pass in `position_ids` because the `position_ids` are already adjusted in `transformers` (see [huggingface/transformers#/7552](https://github.com/huggingface/transformers/pull/7552)). + + Usually, we almost never pass `position_ids` in transformers. All the masking and shifting logic are already implemented e.g. in the `generate` function (need permanent code link). +4. **Response generation samples a fixed-length response without padding.** + 1. During response generation, OpenAI uses `top_k=0, top_p=1.0` and just do categorical samples across the vocabulary ([lm_human_preferences/language/sample.py#L43](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/sample.py#L43)) and the code would keep sampling until a fixed-length response is generated ([lm_human_preferences/policy.py#L103](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L103)). Notably, even if it encounters EOS (end-of-sequence) tokens, it will keep sampling. + 2. **Note on HF’s transformers — sampling could stop at `eos_token`:** in `transformers`, the generation could stop at `eos_token` ([src/transformers/generation/utils.py#L2248-L2256](https://github.com/huggingface/transformers/blob/67b85f24def79962ce075353c2627f78e0e53e9f/src/transformers/generation/utils.py#L2248-L2256)), which is not the same as OpenAI’s setting. To align the setting, we need to do set `pretrained_model.generation_config.eos_token_id = None, pretrained_model.generation_config.pad_token_id = None`. Note that `transformers.GenerationConfig(eos_token_id=None, pad_token_id=None, ...)` does not work because `pretrained_model.generation_config` would override and set a `eos_token`. + + ```python + import torch + import transformers + tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2", padding_side="right") + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + pad_id = tokenizer.pad_token_id + query = torch.tensor([ + [pad_id, pad_id, 23073], + ]) + response = torch.tensor([ + [11, 339, 561], + ]) + response_length = 4 + temperature = 0.7 + pretrained_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2") + pretrained_model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to + pretrained_model.generation_config.pad_token_id = None # generate tokens without truncation / padding + generation_config = transformers.GenerationConfig( + max_new_tokens=response_length, + min_new_tokens=response_length, + temperature=temperature, + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + context_length = query.shape[1] + attention_mask = query != tokenizer.pad_token_id + input_ids = query.clone() + input_ids[~attention_mask] = 0 # set padding tokens to 0 + output = pretrained_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # generation collapsed if this was turned on. + generation_config=generation_config, + return_dict_in_generate=True, + ) + print(output.sequences) + + """ + tensor([[ 0, 0, 23073, 16851, 11, 475, 991]]) + """ + ``` + + 3. Note that in a more recent codebase https://github.com/openai/summarize-from-feedback, OpenAI does stop sampling when encountering EOS token ([summarize_from_feedback/utils/experiment_helpers.py#L19](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/utils/experiment_helpers.py#L19)). However in this work we aim to do a 1:1 replication, so we align the setting that could keep sampling even eos_token is encountered +5. **Learning rate annealing for reward model and policy training.** + 1. As Ziegler et al. (2019) suggested, the reward model is trained for a single epoch to avoid overfitting the limited amount of human annotation data (e.g., the `descriptiveness` task only had about 5000 labels). During this single epoch, the learning rate is annealed to zero ([lm_human_preferences/train_reward.py#L249](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L249)). + 2. Similar to reward model training, the policy's learning rate is annealed to zero ([lm_human_preferences/train_policy.py#L172-L173](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L172-L173)). +6. **Use different seeds for different processes** + 1. When spawning 8 GPU processes to do data parallelism, OpenAI sets a different random seed per process ([lm_human_preferences/utils/core.py#L108-L111](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/utils/core.py#L108-L111)). Implementation-wise, this is done via `local_seed = args.seed + process_rank * 100003`. The seed is going to make the model produce different responses and get different scores, for example. + 1. Note: We believe the dataset shuffling has a bug — the dataset is shuffled using the same seed for some reason ([lm_human_preferences/lm_tasks.py#L94-L97](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/lm_tasks.py#L94-L97)). + +# Reward Model Implementation Details + +In this section, we discuss reward-model-specific implementation details. We talk about details such as reward normalization and layer initialization. Here are these details in no particular order: + +1. **The reward model only outputs the value at the last token.** + 1. Notice that the rewards obtained after the forward pass on the concatenation of `query` and `response` will have the shape `(B, T, 1)`, where `B` is the batch size, `T` is the sequence length (which is always the same; it is `query_length + response_length = 64 + 24 = 88` in OpenAI’s setting for stylistic tasks, see [launch.py#L9-L11](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L9-L11)), and `1` is the reward head dimension of 1. For RLHF purposes, the original codebase extracts the reward of the last token ([lm_human_preferences/rewards.py#L132](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L132)), so that the rewards will only have shape `(B, 1)`. + 2. Note that in a more recent codebase [*openai/summarize-from-feedback*](https://github.com/openai/summarize-from-feedback), OpenAI stops sampling when encountering EOS token ([summarize_from_feedback/utils/experiment_helpers.py#L19](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/utils/experiment_helpers.py#L19)). When extracting rewards, it is going to identify the `last_response_index`, the index before the EOS token ([#L11-L13](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/reward_model.py#L11-L13)), and extract the reward at that index ([summarize_from_feedback/reward_model.py#L59](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/reward_model.py#L59)). However in this work we just stick with the original setting. +2. **Reward head layer initialization** + 1. The weight of the reward head is initialized according to \\( \mathcal{N}\left(0,1 /\left(\sqrt{d_{\text {model }}+1}\right)\right) \\) ([lm_human_preferences/language/model.py#L368,](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L368) [lm_human_preferences/language/model.py#L251-L252](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L251-L252)). This aligns with the settings in Stiennon et al., 2020 ([summarize_from_feedback/query_response_model.py#L106-L107](https://github.com/openai/summarize-from-feedback/blob/8af822a428c93432aa80ffbe5b065a8f93895669/summarize_from_feedback/query_response_model.py#L106-L107)) (P.S., Stiennon et al., 2020 had a typo on page 17 saying the distribution is \\( \mathcal{N}\left(0,1 /\left(d_{\text {model }}+1\right)\right) \\) without the square root) + 2. The bias of the reward head is set to 0 ([lm_human_preferences/language/model.py#L254](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L254)). +3. **Reward model normalization before and after** + 1. In the paper, Ziegler el al. (2019) mentioned that "to keep the scale of the reward model consistent across training, we normalize it so that it has mean 0 and variance 1 for + \\( x \sim \mathcal{D}, y \sim \rho(·|x) \\)." To perform the normalization process, the code first creates a `reward_gain` and `reward_bias`, such that the reward can be calculated by `reward = reward * reward_gain + reward_bias` ([lm_human_preferences/rewards.py#L50-L51](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/rewards.py#L50-L51)). + 2. When performing the normalization process, the code first sets `reward_gain=1, reward_bias=0` ([lm_human_preferences/train_reward.py#L211](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L211)), followed by collecting sampled queries from the target dataset (e.g., `bookcorpus, tldr, cnndm`), completed responses, and evaluated rewards. It then gets the **empirical mean and std** of the evaluated reward ([lm_human_preferences/train_reward.py#L162-L167](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L162-L167)) and tries to compute what the `reward_gain` and `reward_bias` should be. + 3. Let us use \\( \mu_{\mathcal{D}} \\) to denote the empirical mean, \\( \sigma_{\mathcal{D}} \\) the empirical std, \\(g\\) the `reward_gain`, \\(b\\) `reward_bias`, \\( \mu_{\mathcal{T}} = 0\\) **target mean** and \\( \sigma_{\mathcal{T}}=1\\) **target std**. Then we have the following formula. + + $$ + \begin{aligned}g*\mathcal{N}(\mu_{\mathcal{D}}, \sigma_{\mathcal{D}}) + b &= \mathcal{N}(g*\mu_{\mathcal{D}}, g*\sigma_{\mathcal{D}}) + b\\&= \mathcal{N}(g*\mu_{\mathcal{D}} + b, g*\sigma_{\mathcal{D}}) \\&= \mathcal{N}(\mu_{\mathcal{T}}, \sigma_{\mathcal{T}}) \\g &= \frac{\sigma_{\mathcal{T}}}{\sigma_{\mathcal{D}}} \\b &= \mu_{\mathcal{T}} - g*\mu_{\mathcal{D}}\end{aligned} + $$ + + 4. The normalization process is then applied **before** and **after** reward model training ([lm_human_preferences/train_reward.py#L232-L234](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L232-L234), [lm_human_preferences/train_reward.py#L252-L254](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_reward.py#L252-L254)). + + + 5. Note that responses \\( y \sim \rho(·|x) \\) we generated for the normalization purpose are from the pre-trained language model \\(\rho \\). The model + \\(\rho \\) is fixed as a reference and is not updated in reward learning ([lm_human_preferences/train_reward.py#L286C1-L286C31](https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_reward.py#L286C1-L286C31)). + +# Policy Training Implementation Details + +In this section, we will delve into details, such as layer initialization, data post-processing, and dropout settings. We will also explore techniques, such as of rejection sampling and reward "whitening", and adaptive KL. Here are these details in no particular order: + +1. **Scale the logits by sampling temperature.** + 1. When calculating the log probability of responses, the model first outputs the logits of the tokens in the responses, followed by dividing the logits with the sampling temperature ([lm_human_preferences/policy.py#L121](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L121)). I.e., `logits /= self.temperature` + 2. In an informal test, we found that without this scaling, the KL would rise faster than expected, and performance would deteriorate. +2. **Value head layer initialization** + 1. The weight of the value head is initialized according to \\(\mathcal{N}\left(0,0\right)\\) ([lm_human_preferences/language/model.py#L368,](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L368) [lm_human_preferences/language/model.py#L251-L252](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L251-L252)). This is + 2. The bias of the reward head is set to 0 ([lm_human_preferences/language/model.py#L254](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/model.py#L254)). +3. **Select query texts that start and end with a period** + 1. This is done as part of the data preprocessing; + 1. Tries to select text only after `start_text="."` ([lm_human_preferences/language/datasets.py#L51](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L51)) + 2. Tries select text just before `end_text="."` ([lm_human_preferences/language/datasets.py#L61](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L61)) + 3. Then pad the text ([lm_human_preferences/language/datasets.py#L66-L67](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/language/datasets.py#L66-L67)) + 2. When running `openai/lm-human-preferences`, OpenAI’s datasets were partially corrupted/lost ([openai/lm-human-preferences/issues/17#issuecomment-104405149](https://github.com/openai/lm-human-preferences/issues/17#issuecomment-1044051496)), so we had to replace them with similar HF datasets, which may or may not cause a performance difference) + 3. For the book dataset, we used [https://huggingface.co/datasets/bookcorpus](https://huggingface.co/datasets/bookcorpus), which we find not necessary to extract sentences that start and end with periods because the dataset ) is already pre-processed this way (e.g., `"usually , he would be tearing around the living room , playing with his toys ."`) To this end, we set `start_text=None, end_text=None` for the `sentiment` and `descriptiveness` tasks. +4. **Disable dropout** + 1. Ziegler et al. (2019) suggested, “We do not use dropout for policy training.” This is also done in the code ([lm_human_preferences/policy.py#L48](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/policy.py#L48)). +5. **Rejection sampling** + 1. Ziegler et al. (2019) suggested, “We use rejection sampling to ensure there is a period between tokens 16 and 24 and then truncate at that period (This is a crude approximation for ‘end of sentence.’ We chose it because it is easy to integrate into the RL loop, and even a crude approximation is sufficient for the intended purpose of making the human evaluation task somewhat easier). During the RL finetuning, we penalize continuations that don’t have such a period by giving them a fixed reward of −1.” + 2. Specifically, this is achieved with the following steps: + 1. **Token truncation**: We want to truncate at the first occurrence of `truncate_token` that appears at or after position `truncate_after` in the responses ([lm_human_preferences/train_policy.py#L378](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L378)) + 2. **Run reward model on truncated response:** After the response has been truncated by the token truncation process, the code then runs the reward model on the **truncated response**. + 3. **Rejection sampling**: if there is not a period between tokens 16 and 24, then replace the score of the response with a fixed low value (such as -1)([lm_human_preferences/train_policy.py#L384](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L384), [lm_human_preferences/train_policy.py#L384-L402](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L384-L402)) + 4. To give some examples in `descriptiveness`: + + {% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/descriptiveness-samples.png" class="img-fluid" %} + +6. **Discount factor = 1** + 1. The discount parameter \\(\gamma\\) is set to 1 ([lm_human_preferences/train_policy.py#L56](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L56)), which means that future rewards are given the same weight as immediate rewards. +7. **Terminology of the training loop: batches and minibatches in PPO** + 1. OpenAI uses the following training loop ([lm_human_preferences/train_policy.py#L184-L192](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L184-L192)). Note: we additionally added the `micro_batch_size` to help deal with the case in gradient accumulation. At each epoch, it shuffles the batch indices. + + ```python + + import numpy as np + batch_size = 8 + nminibatches = 2 + gradient_accumulation_steps = 2 + mini_batch_size = batch_size // nminibatches + micro_batch_size = mini_batch_size // gradient_accumulation_steps + data = np.arange(batch_size).astype(np.float32) + print("data:", data) + print("batch_size:", batch_size) + print("mini_batch_size:", mini_batch_size) + print("micro_batch_size:", micro_batch_size) + for epoch in range(4): + batch_inds = np.random.permutation(batch_size) + print("epoch:", epoch, "batch_inds:", batch_inds) + for mini_batch_start in range(0, batch_size, mini_batch_size): + mini_batch_end = mini_batch_start + mini_batch_size + mini_batch_inds = batch_inds[mini_batch_start:mini_batch_end] + + # `optimizer.zero_grad()` set optimizer to zero for gradient accumulation + for micro_batch_start in range(0, mini_batch_size, micro_batch_size): + micro_batch_end = micro_batch_start + micro_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + print("____⏩ a forward pass on", data[micro_batch_inds]) + # `optimizer.step()` + print("⏪ a backward pass on", data[mini_batch_inds]) + + # data: [0. 1. 2. 3. 4. 5. 6. 7.] + # batch_size: 8 + # mini_batch_size: 4 + # micro_batch_size: 2 + # epoch: 0 batch_inds: [6 4 0 7 3 5 1 2] + # ____⏩ a forward pass on [6. 4.] + # ____⏩ a forward pass on [0. 7.] + # ⏪ a backward pass on [6. 4. 0. 7.] + # ____⏩ a forward pass on [3. 5.] + # ____⏩ a forward pass on [1. 2.] + # ⏪ a backward pass on [3. 5. 1. 2.] + # epoch: 1 batch_inds: [6 7 3 2 0 4 5 1] + # ____⏩ a forward pass on [6. 7.] + # ____⏩ a forward pass on [3. 2.] + # ⏪ a backward pass on [6. 7. 3. 2.] + # ____⏩ a forward pass on [0. 4.] + # ____⏩ a forward pass on [5. 1.] + # ⏪ a backward pass on [0. 4. 5. 1.] + # epoch: 2 batch_inds: [1 4 5 6 0 7 3 2] + # ____⏩ a forward pass on [1. 4.] + # ____⏩ a forward pass on [5. 6.] + # ⏪ a backward pass on [1. 4. 5. 6.] + # ____⏩ a forward pass on [0. 7.] + # ____⏩ a forward pass on [3. 2.] + # ⏪ a backward pass on [0. 7. 3. 2.] + # epoch: 3 batch_inds: [7 2 4 1 3 0 6 5] + # ____⏩ a forward pass on [7. 2.] + # ____⏩ a forward pass on [4. 1.] + # ⏪ a backward pass on [7. 2. 4. 1.] + # ____⏩ a forward pass on [3. 0.] + # ____⏩ a forward pass on [6. 5.] + # ⏪ a backward pass on [3. 0. 6. 5.] + ``` + +8. **Per-token KL penalty** + - The code adds a per-token KL penalty ([lm_human_preferences/train_policy.py#L150-L153](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L150-L153)) to the rewards, in order to discourage the policy to be very different from the original policy. + - Using the `"usually, he would"` as an example, it gets tokenized to `[23073, 11, 339, 561]`. Say we use `[23073]` as the query and `[11, 339, 561]` as the response. Then under the default `gpt2` parameters, the response tokens will have log probabilities of the reference policy `logprobs=[-3.3213, -4.9980, -3.8690]` . + - During the first PPO update epoch and minibatch update, so the active policy will have the same log probabilities `new_logprobs=[-3.3213, -4.9980, -3.8690]`. , so the per-token KL penalty would be `kl = new_logprobs - logprobs = [0., 0., 0.,]` + - However, after the first gradient backward pass, we could have `new_logprob=[3.3213, -4.9980, -3.8690]` , so the per-token KL penalty becomes `kl = new_logprobs - logprobs = [-0.3315, -0.0426, 0.6351]` + - Then the `non_score_reward = beta * kl` , where `beta` is the KL penalty coefficient \\(\beta\\), and it’s added to the `score` obtained from the reward model to create the `rewards` used for training. The `score` is only given at the end of episode; it could look like `[0.4,]` , and we have `rewards = [beta * -0.3315, beta * -0.0426, beta * 0.6351 + 0.4]`. +9. **Per-minibatch reward and advantage whitening, with optional mean shifting** + 1. OpenAI implements a `whiten` function that looks like below, basically normalizing the `values` by subtracting its mean followed by dividing by its standard deviation. Optionally, `whiten` can shift back the mean of the whitened `values` with `shift_mean=True`. + + ```python + def whiten(values, shift_mean=True): + mean, var = torch.mean(values), torch.var(values, unbiased=False) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + ``` + + 1. In each minibatch, OpenAI then whitens the reward `whiten(rewards, shift_mean=False)` without shifting the mean ([lm_human_preferences/train_policy.py#L325](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L325)) and whitens the advantages `whiten(advantages)` with the shifted mean ([lm_human_preferences/train_policy.py#L338](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L338)). + 2. **Optimization note:** if the number of minibatches is one (which is the case in this reproduction) we only need to whiten rewards, calculate and whiten advantages once since their values won’t change. + 3. **TensorFlow vs PyTorch note:** Different behavior of `tf.moments` vs `torch.var`: The behavior of whitening is different in torch vs tf because the variance calculation is different: + + ```jsx + import numpy as np + import tensorflow as tf + import torch + + def whiten_tf(values, shift_mean=True): + mean, var = tf.nn.moments(values, axes=list(range(values.shape.rank))) + mean = tf.Print(mean, [mean], 'mean', summarize=100) + var = tf.Print(var, [var], 'var', summarize=100) + whitened = (values - mean) * tf.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + def whiten_pt(values, shift_mean=True, unbiased=True): + mean, var = torch.mean(values), torch.var(values, unbiased=unbiased) + print("mean", mean) + print("var", var) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + rewards = np.array([ + [1.2, 1.3, 1.4], + [1.5, 1.6, 1.7], + [1.8, 1.9, 2.0], + ]) + + with tf.Session() as sess: + print(sess.run(whiten_tf(tf.constant(rewards, dtype=tf.float32), shift_mean=False))) + print(whiten_pt(torch.tensor(rewards), shift_mean=False, unbiased=True)) + print(whiten_pt(torch.tensor(rewards), shift_mean=False, unbiased=False)) + ``` + + ```jsx + mean[1.5999999] + var[0.0666666627] + [[0.05080712 0.4381051 0.8254035 ] + [1.2127019 1.6000004 1.9872988 ] + [2.3745968 2.7618952 3.1491938 ]] + mean tensor(1.6000, dtype=torch.float64) + var tensor(0.0750, dtype=torch.float64) + tensor([[0.1394, 0.5046, 0.8697], + [1.2349, 1.6000, 1.9651], + [2.3303, 2.6954, 3.0606]], dtype=torch.float64) + mean tensor(1.6000, dtype=torch.float64) + var tensor(0.0667, dtype=torch.float64) + tensor([[0.0508, 0.4381, 0.8254], + [1.2127, 1.6000, 1.9873], + [2.3746, 2.7619, 3.1492]], dtype=torch.float64) + + ``` + +10. **Clipped value function** + 1. As done in the original PPO ([baselines/ppo2/model.py#L68-L75](https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/ppo2/model.py#L68-L75)), the value function is clipped ([lm_human_preferences/train_policy.py#L343-L348](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L343-L348)) in a similar fashion as the policy objective. +11. **Adaptive KL** + - The KL divergence penalty coefficient \\(\beta\\) is modified adaptively based on the KL divergence between the current policy and the previous policy. If the KL divergence is outside a predefined target range, the penalty coefficient is adjusted to bring it closer to the target range ([lm_human_preferences/train_policy.py#L115-L124](https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/lm_human_preferences/train_policy.py#L115-L124)). It’s implemented as follows: + + ```python + class AdaptiveKLController: + def __init__(self, init_kl_coef, hparams): + self.value = init_kl_coef + self.hparams = hparams + + def update(self, current, n_steps): + target = self.hparams.target + proportional_error = np.clip(current / target - 1, -0.2, 0.2) + mult = 1 + proportional_error * n_steps / self.hparams.horizon + self.value *= mult + ``` + + - For the `sentiment` and `descriptiveness` tasks examined in this work, we have `init_kl_coef=0.15, hparams.target=6, hparams.horizon=10000`. + +## **PyTorch Adam optimizer numerical issues w.r.t RLHF** + +- This implementation detail is so interesting that it deserves a full section. +- PyTorch Adam optimizer ([torch.optim.Adam.html](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html)) has a different implementation compared to TensorFlow’s Adam optimizer (TF1 Adam at [tensorflow/v1.15.2/adam.py](https://github.com/tensorflow/tensorflow/blob/v1.15.2/tensorflow/python/training/adam.py), TF2 Adam at [keras/adam.py#L26-L220](https://github.com/keras-team/keras/blob/v2.13.1/keras/optimizers/adam.py#L26-L220)). In particular, **PyTorch follows Algorithm 1** of the Kingma and Ba’s Adam , but **TensorFlow uses the formulation just before Section 2.1** of the paper and its `epsilon` referred to here is `epsilon hat` in the paper. In a pseudocode comparison, we have the following + +```python +### pytorch adam implementation: +bias_correction1 = 1 - beta1 ** step +bias_correction2 = 1 - beta2 ** step +step_size = lr / bias_correction1 +bias_correction2_sqrt = _dispatch_sqrt(bias_correction2) +denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) +param.addcdiv_(exp_avg, denom, value=-step_size) + +### tensorflow adam implementation: +lr_t = lr * _dispatch_sqrt((1 - beta2 ** step)) / (1 - beta1 ** step) +denom = exp_avg_sq.sqrt().add_(eps) +param.addcdiv_(exp_avg, denom, value=-lr_t) +``` + +- Let’s compare the update equations of pytorch-style and tensorflow-style adam. Following the notation of the adam paper [(Kingma and Ba, 2014)](https://arxiv.org/abs/1412.6980), we have the gradient update rules for pytorch adam (Algorithm 1 of Kingma and Ba’s paper) and tensorflow-style adam (the formulation just before Section 2.1 of Kingma and Ba’s paper) as below: + +$$\begin{aligned}\text{pytorch adam :}\quad \theta_t & =\theta_{t-1}-\alpha \cdot \hat{m}_t /\left(\sqrt{\hat{v}_t}+\varepsilon\right) \\& =\theta_{t-1}- \alpha \underbrace{\left[m_t /\left(1-\beta_1^t\right)\right]}_{=\hat{m}_t} /\left[\sqrt{\underbrace{v_t /\left(1-\beta_2^t\right)}_{=\hat{v}_t} }+\varepsilon\right]\\& =\theta_{t-1}- \alpha\left[m_t /\left(1-\beta_1^t\right)\right]\frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{green}{\varepsilon \sqrt{1-\beta_2^t}}}\end{aligned}$$ + +$$\begin{aligned}\text{tensorflow adam:}\quad \theta_t & =\theta_{t-1}-\alpha_t m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\& =\theta_{t-1}-\underbrace{\left[\alpha \sqrt{1-\beta_2^t} /\left(1-\beta_1^t\right)\right]}_{=\alpha_t} m_t /\left(\sqrt{v_t}+\hat{\varepsilon}\right) \\& =\theta_{t-1}- \alpha\left[m_t /\left(1-\beta_1^t\right)\right] \frac{\sqrt{1-\beta_2^t}}{\sqrt{v_t}+\color{green}{\hat{\varepsilon}}} \end{aligned}$$ + + +- The equations above highlight that the distinction between pytorch and tensorflow implementation is their **normalization terms**, \\(\color{green}{\varepsilon \sqrt{1-\beta_2^t}}\\) and \\(\color{green}{\hat{\varepsilon}}\\). The two versions are equivalent if we set \\(\hat{\varepsilon} =\varepsilon \sqrt{1-\beta_2^t}\\) . However, in the pytorch and tensorflow APIs, we can only set \\(\varepsilon\\) (pytorch) and \\(\hat{\varepsilon}\\) (tensorflow) via the `eps` argument, causing differences in their update equations. What if we set \\(\varepsilon\\) and \\(\hat{\varepsilon}\\) to the same value, say, 1e-5? Then for tensorflow adam, the normalization term \\(\hat{\varepsilon} = \text{1e-5}\\) is just a constant. But for pytorch adam, the normalization term \\({\varepsilon \sqrt{1-\beta_2^t}}\\) changes over time. Importantly, initially much smaller than 1e-5 when the timestep \\(t\\) is small, the term \\({\varepsilon \sqrt{1-\beta_2^t}}\\) gradually approaches to 1e-5 as timesteps increase. The plot below compares these two normalization terms over timesteps: + +{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/norma_const_comparison.png" class="img-fluid" %} + +- The above figure shows that, if we set the same `eps` in pytorch adam and tensorflow adam, then pytorch-adam uses a much smaller normalization term than tensorflow-adam in the early phase of training. In other words, pytorch adam goes for **more aggressive gradient updates early in the training**. Our experiments support this finding, as we will demonstrate below. +- How does this impact reproducibility and performance? To align settings, we record the original query, response, and rewards from [https://github.com/openai/lm-human-preferences](https://github.com/openai/lm-human-preferences) and save them. We also record the metrics of the first two epochs of training with TF1’s `AdamOptimizer` optimizer as the ground truth. Below are some key metrics: + + + | | OpenAI’s TF1 Adam | PyTorch’s Adam | Our custom Tensorflow-style Adam | + | --- | --- | --- | --- | + | policy/approxkl | 0.00037167023 | 0.0023672834504395723 | 0.000374998344341293 | + | policy/clipfrac | 0.0045572915 | 0.02018229104578495 | 0.0052083334885537624 | + | ratio_mean | 1.0051285 | 1.0105520486831665 | 1.0044583082199097 | + | ratio_var | 0.0007716546 | 0.005374275613576174 | 0.0007942612282931805 | + | ratio_max | 1.227216 | 1.8121057748794556 | 1.250215768814087 | + | ratio_min | 0.7400441 | 0.4011387825012207 | 0.7299948930740356 | + | logprob_diff_mean | 0.0047487603 | 0.008101251907646656 | 0.004073789343237877 | + | logprob_diff_var | 0.0007207897 | 0.004668936599045992 | 0.0007334011606872082 | + | logprob_diff_max | 0.20474821 | 0.594489574432373 | 0.22331619262695312 | + | logprob_diff_min | -0.30104542 | -0.9134478569030762 | -0.31471776962280273 | +- **PyTorch’s `Adam` produces a more aggressive update** for some reason. Here are some evidence: + - **PyTorch’s `Adam`'s `logprob_diff_var`** **is 6x higher**. Here `logprobs_diff = new_logprobs - logprobs` is the difference between the log probability of tokens between the initial and current policy after two epochs of training. Having a larger `logprob_diff_var` means the scale of the log probability changes is larger than that in OpenAI’s TF1 Adam. + - **PyTorch’s `Adam` presents a more extreme ratio max and min.** Here `ratio = torch.exp(logprobs_diff)`. Having a `ratio_max=1.8121057748794556` means that for some token, the probability of sampling that token is 1.8x more likely under the current policy, as opposed to only 1.2x with OpenAI’s TF1 Adam. + - **Larger `policy/approxkl` `policy/clipfrac`.** Because of the aggressive update, the ratio gets clipped **4.4x more often, and the approximate KL divergence is 6x larger.** + - The aggressive update is likely gonna cause further issues. E.g., `logprob_diff_mean` is 1.7x larger in PyTorch’s `Adam`, which would correspond to 1.7x larger KL penalty in the next reward calculation; this could get compounded. In fact, this might be related to the famous KL divergence issue — KL penalty is much larger than it should be and the model could pay more attention and optimizes for it more instead, therefore causing negative KL divergence. +- **Larger models get affected more.** We conducted experiments comparing PyTorch’s `Adam` (codename `pt_adam`) and our custom TensorFlow-style (codename `tf_adam`) with `gpt2` and `gpt2-xl`. We found that the performance are roughly similar under `gpt2`; however with `gpt2-xl`, we observed a more aggressive updates, meaning that larger models get affected by this issue more. + - When the initial policy updates are more aggressive in `gpt2-xl`, the training dynamics get affected. For example, we see a much larger `objective/kl` and `objective/scores` spikes with `pt_adam`, especially with `sentiment` — *the biggest KL was as large as 17.5* in one of the random seeds, suggesting an undesirable over-optimization. + - Furthermore, because of the larger KL, many other training metrics are affected as well. For example, we see a much larger `clipfrac` (the fraction of time the `ratio` gets clipped by PPO’s objective clip coefficient 0.2) and `approxkl`. + + +
+{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2.png" class="img-fluid" %} + + +{% include figure.html path="assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2_xl.png" class="img-fluid" %} +
+ + +# Limitations + +Noticed this work does not try to reproduce the summarization work in CNN DM or TL;DR. This was because we found the training to be time-consuming and brittle. + +The particular training run we had showed poor GPU utilization (around 30%), so it takes almost 4 days to perform a training run, which is highly expensive (only AWS sells p3dn.24xlarge, and it costs $31.212 per hour) + +Additionally, training was brittle. While the reward goes up, we find it difficult to reproduce the “smart copier” behavior reported by Ziegler et al. (2019) . Below are some sample outputs — clearly, the agent overfits somehow. + +``` +QUERY: The modern take on Death & Taxes still plays it. I know as a Tron player I find Mindcensor +pretty frustrating. ⏎ However, the answer to your question is probably that the decks that +are currently taking up the lion's share of the metagame don't really care about its effect. +It has a chance to act as an instant speed Stone Rain + Gut Shot in response to a fetch +crack, but after that it's a 2/1 flyer that dies to all relevant removal. ⏎ It's especially +powerful against Tron since so much of the deck's consistency is in its tutor effects -- +Sylvan Scrying, Expedition Map, and Eye of Ugin. This combined with Pyroclasm and Oblivion +Stone being the only creature removal the deck has available to it before Karn make +Mindcensor fairly potent. ⏎ I'm having trouble thinking of any other popular decks (ignoring +the interaction with fetches) that Mindcensor really sticks it to, however. + +PPO GENERATED RESPONSE: Mindcensor is a very good card, but it's not really worth playing.<|endoftext|>Image +caption The family was forced from their home after they were not allowed to leave the +house⏎⏎A family of four has been jailed for five years after admitting breaching home rule +by refusing to leave their home.⏎ + +reward: -0.48 +kl: +11.37 +total reward: -0.82 + +QUERY: To make a long story short, my SO (27/m) and I live in my hometown. We've been together for +4 years. His family lives across the country. A lot of the time we spend with my parents but +besides that, we aren't very social. His family not being here doesn't make me feel like I +can't be with mine. I just don't like to. ⏎ I find my extended family (so not my parents) +extremely irritating. Especially the ones around my age. They're dramatic and rude. Every +family event is some Jerry Springer episode I don't want to participate in. I don't have any +issues with anyone in particular. I just find them annoying. I'm never uncivil or openly +telling them how I feel. If there's get togethers, I tend to avoid them. I don't ask anyone +to hang out or come over. ⏎ My SO and I have experienced multiple miscarriages due to +unknown reasons. This has made me very sensitive and a bit of a loaner when it comes to the +girls in my family my age. I'm not jealous that my family has children. I just don't seem to +be able to relate to them. When I explained our issues, they reacted very unsympathetic and +just weren't very supportive. I guess that was the time that I started to pull away and not +try to keep up any relationships. ⏎ My SO and I are pregnant, half way! This is something +we're so joyed about. However, when we told everyone, they just didn't really give a +reaction. No congrats or that they were happy for us. Since then, I completely pulled away. +I talked to my parents about this and I just don't feel the need to contact anyone. The +reason I'm posting is because I discussed deleting some relatives off social media with my +mom and she was telling me how it's really rude. I don't think it is. We don't talk. There's +no relationship. Honestly, I think they're just friends with me for gossip purposes. A few +times I tried reaching out to some of my cousins with no reply. ⏎ I guess what I'm asking is +if I should be trying harder to establish something with my family or if it's perfectly OK +just to delete them and not talk to them again? Our primary contact is Facebook even though +we live in the same town. + +PPO GENERATED RESPONSE: We're dating now. We've met, we're dating. We're dating now. We're dating now. We're dating +now. We're dating now. We're dating now.⏎⏎We've met⏎⏎We're dating now.⏎⏎We're dating now.⏎ + +reward: +2.09 +kl: +19.45 +total reward: +1.50 + +QUERY: This tower totally decided to screw me over on the final stage. So a bit of background I +don't have the strongest squad but pretty much I have: ⏎ Guardian Loch (Sacred Jewel) ⏎ +Anima Tiara (Flesh Armor) ⏎ Lord Cayena (sacred Jewel) ⏎ Oracle Twins (sacred Jewel) ⏎ +Guardian Logan (Burny soul stone) ⏎ and was using a friends maxed michelle. ⏎ So throughout +the earlier stages I've seen this guy about 5 times but have never caught him. So I had a +little trouble with stage 9 but felt like with the additional 50% atk boost and bb from +michelle I could do it. Also on stage 9 the reason I died was because I didn't take any +stimulants and this guy managed to paralyze 3 members of my squad. So I replaced mega cures +for stimulants. ⏎ Round 3 I meet him. He gets a lucky crit on my caynea and takes her down. +I have to use my one revive, take him down and capture him. I'm pretty excited. ⏎ Round 7 +I see the warning danger sign. I face him again and manage to take him down without any of +my units dying no capture. At this point I'm starting to get low on healing items. Also I've +only heard of meeting him twice on a run. But then I checked this thread and saw that he +always appears on the last stage. I thought that maybe it glitched and he wouldn't appear on +the final round. ⏎ Nope. ⏎ This guy shows up on the final round showing up 3 times in one +run. Worst part was I manged to get him down to a sliver of health so that as long as one +unit hit him I would've defeated him with the next turn. This was also the turn he decided +to use soul desecration twice. I definitely felt like my soul got desecrated watching all my +units die. So at this point it was give up the captured unit and try all over again (knowing +my luck never capture him again) or gem for the last hit. I was really looking forward to +getting the gem for completion (f2p gotta scrap up everything right?) but I let it go. + +PPO GENERATED RESPONSE: ⏎⏎When you capture a unit try to kill it.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill +them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎⏎Kill them.⏎ + +reward: +0.79 +kl: +10.82 +total reward: +0.46 +``` + +# Conclusion + +In this work, we took a deep dive into OpenAI’s original RLHF codebase and compiled a list of its implementation details. We also created a minimal base which reproduces the same learning curves as OpenAI’s original RLHF codebase, when the dataset and hyperparameters are controlled. Furthermore, we identify surprising implementation details such as the adam optimizer’s setting which causes aggressive updates in early RLHF training. + + diff --git a/assets/bibliography/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.bib b/assets/bibliography/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.bib new file mode 100644 index 00000000..2e589b94 --- /dev/null +++ b/assets/bibliography/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo.bib @@ -0,0 +1,60 @@ +@article{Ziegler2019fine, + title={Fine-tuning language models from human preferences}, + author={Ziegler, Daniel M and Stiennon, Nisan and Wu, Jeffrey and Brown, Tom B and Radford, Alec and Amodei, Dario and Christiano, Paul and Irving, Geoffrey}, + journal={arXiv preprint arXiv:1909.08593}, + year={2019} +} + +@inproceedings{Shengyi2022the37implementation, + author = {Huang, Shengyi and Dossa, Rousslan Fernand Julien and Raffin, Antonin and Kanervisto, Anssi and Wang, Weixun}, + title = {The 37 Implementation Details of Proximal Policy Optimization}, + booktitle = {ICLR Blog Track}, + year = {2022}, + note = {https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/}, + url = {https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/} +} + +@inproceedings{Wolf2020transformers, + title = "Transformers: State-of-the-Art Natural Language Processing", + author = "Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush", + booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations", + month = oct, + year = "2020", + address = "Online", + publisher = "Association for Computational Linguistics", + url = "https://www.aclweb.org/anthology/2020.emnlp-demos.6", + pages = "38--45" +} + +@article{Paszke2019pytorch, + title={Pytorch: An imperative style, high-performance deep learning library}, + author={Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and Antiga, Luca and others}, + journal={Advances in neural information processing systems}, + volume={32}, + year={2019} +} + +@software{Jax2018github, + author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, + title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, + url = {http://github.com/google/jax}, + version = {0.3.13}, + year = {2018}, +} + +@InProceedings{Kingma2015adam, + author = {Kingma, Diederik and Ba, Jimmy}, + booktitle = {International Conference on Learning Representations (ICLR)}, + title = {Adam: A Method for Stochastic Optimization}, + year = {2015}, + address = {San Diega, CA, USA}, +} + +@article{stiennon2020learning, + title={Learning to summarize with human feedback}, + author={Stiennon, Nisan and Ouyang, Long and Wu, Jeffrey and Ziegler, Daniel and Lowe, Ryan and Voss, Chelsea and Radford, Alec and Amodei, Dario and Christiano, Paul F}, + journal={Advances in Neural Information Processing Systems}, + volume={33}, + pages={3008--3021}, + year={2020} +} diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2.png new file mode 100644 index 00000000..8d514aed Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2_xl.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2_xl.png new file mode 100644 index 00000000..804f1dbc Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/adam_gpt2_xl.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching.png new file mode 100644 index 00000000..246f1e71 Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching_all.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching_all.png new file mode 100644 index 00000000..17b95a04 Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/curve-matching_all.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/descriptiveness-samples.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/descriptiveness-samples.png new file mode 100644 index 00000000..7146d192 Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/descriptiveness-samples.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/norma_const_comparison.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/norma_const_comparison.png new file mode 100644 index 00000000..bc6b8b41 Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/norma_const_comparison.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/tldr1.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/tldr1.png new file mode 100644 index 00000000..9a7d490f Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/tldr1.png differ diff --git a/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/tldr2.png b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/tldr2.png new file mode 100644 index 00000000..8dd04acc Binary files /dev/null and b/assets/img/2024-05-07-the-n-implementation-details-of-rlhf-with-ppo/tldr2.png differ