Skip to content

Commit

Permalink
RL
Browse files Browse the repository at this point in the history
  • Loading branch information
carson committed Feb 11, 2023
1 parent 3053995 commit b6e98d2
Show file tree
Hide file tree
Showing 7 changed files with 664 additions and 88 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,13 @@ and the changes you make to example_class_function will be available to you with
**Forward batching**: Since the models can be fairly big and we want to roll out large PPO batches, this can lead to out-of-memory errors when doing the forward passes for text generation and sentiment analysis. We introduce the parameter `forward_batch_size` to split the forward passes into smaller batches. Although this hurts performance a little, this is negligible compared to the computations of the backward passes when optimizing the model. The same parameter is used in the `PPOTrainer` when doing forward passes. The `batch_size` should be a multiple of `forward_batch_size`.

# References and Credits
A lot of the code is from https://github.com/lvwerra/trl

@misc{vonwerra2022trl,
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert},
title = {TRL: Transformer Reinforcement Learning},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/lvwerra/trl}}
}

7 changes: 3 additions & 4 deletions minichatgpt/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .core import LengthSampler
from .languagemodels.modeling_value_head import AutoModelForCausalLMWithValueHead
from .trainer import PPOTrainer
from .processdata.collators import collator
from .processdata.build_dataset import build_dataset


Expand All @@ -18,6 +17,7 @@ def __init__(self,
self.config = config

def set_generation_config(self,
do_sample=True,
output_min_length = 4,
output_max_length = 16,
pad_token_id=50256,
Expand All @@ -28,7 +28,7 @@ def set_generation_config(self,
"min_length":-1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"do_sample": do_sample,
"pad_token_id": pad_token_id,
}

Expand Down Expand Up @@ -76,10 +76,9 @@ def init_ppo_trainer(self,
old_policy,
tokenizer,
dataset,
data_collator = collator,
):

self.ppo_trainer = PPOTrainer(config, new_policy, old_policy, tokenizer, dataset, data_collator)
self.ppo_trainer = PPOTrainer(config, new_policy, old_policy, tokenizer, dataset)

self.batches_per_epoch = len(self.ppo_trainer.dataloader)

Expand Down
2 changes: 1 addition & 1 deletion minichatgpt/processdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .collators import collator
from .collators import dataloader_data_collator
2 changes: 1 addition & 1 deletion minichatgpt/processdata/collators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@


def collator(data):
def dataloader_data_collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
12 changes: 9 additions & 3 deletions minichatgpt/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from ..languagemodels import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig
from ..processdata.collators import collator
from ..processdata.collators import dataloader_data_collator

MODEL_CARD_TEMPLATE = """---
license: apache-2.0
Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
data_collator=collator,
dataloader_data_collator=dataloader_data_collator,
num_shared_layers: Optional[int] = None,
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
):
Expand Down Expand Up @@ -183,9 +183,15 @@ def __init__(
UserWarning,
)
self.dataset = dataset
self.dataloader_data_collator = dataloader_data_collator
self._signature_columns = None

if self.dataset is not None:
self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
self.dataloader = self.prepare_dataloader(
self.dataset,
self.dataloader_data_collator,
)

elif self.dataset is None and self.accelerator.num_processes > 1:
warnings.warn(
"No dataset is provided. In a multi-GPU setting, this will lead to an error. You should",
Expand Down
Loading

0 comments on commit b6e98d2

Please sign in to comment.