Commit
update reference model with policy within the spin and dpo trainers themselves, to ready for arbitrary ordering of fine tuning steps
lucidrains committed Jan 31, 2024
1 parent a83249f commit 9673d9b
Showing 5 changed files with 12 additions and 9 deletions.
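The common thread across the diffs below is update_reference_model_with_policy: each trainer now refreshes its frozen reference model from the current policy at the start of its own forward pass, instead of relying on the outer self-rewarding loop to do it afterwards. As a rough sketch of what such a method does, assuming a wrapper that holds a trainable policy alongside a frozen reference copy (the class and attribute names below are illustrative, not the library's actual API):

import copy
import torch
from torch import nn

class PolicyWithReference(nn.Module):
    # illustrative wrapper: a trainable policy plus a frozen reference snapshot

    def __init__(self, policy: nn.Module):
        super().__init__()
        self.policy = policy
        self.ref_model = copy.deepcopy(policy)
        self.ref_model.requires_grad_(False)

    @torch.no_grad()
    def update_reference_model_with_policy(self):
        # copy the current policy weights into the frozen reference model, so the
        # next SPIN / DPO round is scored against the latest policy checkpoint
        self.ref_model.load_state_dict(self.policy.state_dict())
        self.ref_model.eval()

Because the snapshot now happens inside each trainer, the stages can run in any order, which is what the commit message means by readying for arbitrary ordering of fine tuning steps.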
README.md (8 changes: 5 additions & 3 deletions)

@@ -98,9 +98,11 @@ sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), t
 spin_trainer = SPINTrainer(
     transformer,
     max_seq_len = 16,
-    sft_dataset = sft_dataset,
-    spin_λ = 0.1,
-    checkpoint_every = 100
+    train_sft_dataset = sft_dataset,
+    checkpoint_every = 100,
+    spin_kwargs = dict(
+        λ = 0.1,
+    ),
 )

 spin_trainer()
self_rewarding_lm_pytorch/dpo.py (5 changes: 4 additions & 1 deletion)

@@ -434,7 +434,10 @@ def forward(
         self,
         train_self_reward_dataset: Optional[Dataset] = None
     ):
-        self.early_stopper.clear_early_checkpoint_folder()
+        self.model.update_reference_model_with_policy()
+
+        if exists(self.early_stopper):
+            self.early_stopper.clear_early_checkpoint_folder()

         train_dataloader = self.train_dataloader

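For context on why the reference model needs refreshing here: in DPO the policy is scored against reference log-probabilities, so a stale reference would anchor the implicit reward to an old checkpoint. A minimal sketch of the standard DPO objective (Rafailov et al., 2023), with illustrative variable names rather than this repository's internals:

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,     # log π_θ(y_chosen | x), summed over tokens
    policy_rejected_logps: torch.Tensor,   # log π_θ(y_rejected | x)
    ref_chosen_logps: torch.Tensor,        # log π_ref(y_chosen | x)
    ref_rejected_logps: torch.Tensor,      # log π_ref(y_rejected | x)
    beta: float = 0.1
) -> torch.Tensor:
    # preference margin of the policy, measured relative to the reference model
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # -log σ(β · margin); the reference log-probs anchor the implicit reward,
    # which is why the trainer snapshots policy → reference before each round
    return -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()

The added exists() guard simply skips the checkpoint-folder cleanup when no early stopper was configured.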
self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py (4 changes: 0 additions & 4 deletions)

@@ -803,8 +803,6 @@ def forward(

             spin_trainer()

-            self.spin.update_reference_model_with_policy()
-
             self.save(f'spin.{spin_cycle}.ckpt.pt', overwrite = overwrite_checkpoints)


@@ -816,8 +814,6 @@ def forward(

             dpo_trainer(dpo_dataset_from_self_reward)

-            self.dpo.update_reference_model_with_policy()
-
             self.save(f'self-reward.{iterate_num}.ckpt.pt', overwrite = overwrite_checkpoints)

             self.print(f'self-reward training done')
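With the reference updates moved into the SPIN and DPO trainers themselves, the orchestrating loop no longer calls update_reference_model_with_policy between stages, so the stages become order-independent. A sketch of what that buys, assuming already-constructed spin_trainer and dpo_trainer instances and an existing dpo_dataset_from_self_reward (the list below is illustrative, not code from the repository):

# each trainer's forward() now snapshots policy → reference on its own,
# so an outer loop can interleave fine-tuning stages in any order
stages = [
    lambda: spin_trainer(),
    lambda: dpo_trainer(dpo_dataset_from_self_reward),
    lambda: spin_trainer(),   # e.g. an extra SPIN pass after DPO
]

for run_stage in stages:
    run_stage()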
self_rewarding_lm_pytorch/spin.py (2 changes: 2 additions & 0 deletions)

@@ -292,6 +292,8 @@ def forward(self, overwrite_checkpoints: bool = True):
         Algorithm 1 - https://arxiv.org/abs/2401.01335v1
         """

+        self.model.update_reference_model_with_policy()
+
         self.steps = 0
         self.model.train()

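The SPIN objective (Chen et al., 2024, the Algorithm 1 referenced above) is structurally similar to DPO: real SFT completions play the chosen role, completions generated by the previous-iteration model play the rejected role, and both are scored against a frozen reference. A sketch with illustrative names, where λ is the same regularization strength the README now passes via spin_kwargs:

import torch
import torch.nn.functional as F

def spin_loss(
    policy_real_logps: torch.Tensor,       # log π_θ of the ground-truth SFT completion
    policy_generated_logps: torch.Tensor,  # log π_θ of the self-generated completion
    ref_real_logps: torch.Tensor,          # same quantities under the frozen reference
    ref_generated_logps: torch.Tensor,
    λ: float = 0.1
) -> torch.Tensor:
    real_logratios = policy_real_logps - ref_real_logps
    generated_logratios = policy_generated_logps - ref_generated_logps
    # logistic loss on the λ-scaled margin; syncing reference ← policy at the
    # start of forward() means each SPIN round plays against the latest policy
    return -F.logsigmoid(λ * (real_logratios - generated_logratios)).mean()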
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'self-rewarding-lm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.38',
+  version = '0.0.39',
   license='MIT',
   description = 'Self Rewarding LM - Pytorch',
   author = 'Phil Wang',
