Implement BoN for training and eval #528

Open. Wants to merge 40 commits into base branch main.

Commits (40):
f04446d - Implementing support for dense rewards (Dahoas, Jun 5, 2023)
13a01fc - added "num_return_sequences" param which corresponds to n in Best-of-… (SharathRaparthy, Jun 16, 2023)
5421a73 - updates to "num_return_sequences" param (SharathRaparthy, Jun 16, 2023)
2f3ac28 - BoN implementation (SharathRaparthy, Jun 16, 2023)
2f1dace - Changed back to default. (SharathRaparthy, Jun 19, 2023)
f58170d - TopK sampling instead of Top1 (SharathRaparthy, Jun 19, 2023)
be8bc1a - summed along dim=1 (SharathRaparthy, Jun 26, 2023)
608d812 - Generating samples in chunks (SharathRaparthy, Jun 26, 2023)
d8557e7 - added gen_chunk_size parameter (SharathRaparthy, Jun 26, 2023)
8ef9c36 - chunking in forward prop (SharathRaparthy, Jun 26, 2023)
4c1d82d - chunking generations in train and eval (SharathRaparthy, Jun 26, 2023)
ecd5107 - Implementing support for dense rewards (Dahoas, Jun 5, 2023)
4071604 - Fix distributed ref_mean, ref_var bug for dense rewards (Dahoas, Jun 15, 2023)
5f41413 - Make generation respect max seq length (Dahoas, Jun 23, 2023)
22ae83f - Make experience before first round of training (Dahoas, Jun 23, 2023)
7d0a4be - Refactoring .generate/.generate_eval (Dahoas, Jun 27, 2023)
b79dd19 - Fix BoN metric support (Dahoas, Jun 29, 2023)
cb49dc5 - Enforce chunk_size param for eval generation when present (Dahoas, Jul 3, 2023)
e290412 - Fix: Don't shuffle prompt dataset (Dahoas, Jul 4, 2023)
391d04c - Move inputs to device (Dahoas, Jul 18, 2023)
8de84e4 - Fix style (Dahoas, Jul 18, 2023)
3d7e0d5 - Fix chunked generation (Dahoas, Jul 21, 2023)
1fda0ce - fix(accelerate_base_trainer): order of keyword arguments (maxreciprocate, Jul 22, 2023)
4ac1707 - Merging main (Dahoas, Aug 7, 2023)
de3d854 - Merge branch 'BoN' of https://github.com/CarperAI/trlx into BoN (Dahoas, Aug 7, 2023)
3ce3c2b - Removing old example (Dahoas, Aug 7, 2023)
2635de5 - Fix: remove extraneous method args (Dahoas, Aug 7, 2023)
1be2c3c - Fix: Always set generate_experience_kwargs (Dahoas, Aug 7, 2023)
3cba0db - Fix: Remove mask from RunningMoments update call (Dahoas, Aug 7, 2023)
0cb91c4 - Fix: style (Dahoas, Aug 7, 2023)
cc92911 - Fix: rename 'gen_chunk_size' to 'chunk_size' (Dahoas, Aug 7, 2023)
4297f98 - Fix: generated samples padding (Dahoas, Aug 7, 2023)
36f06af - Remove prints (Dahoas, Aug 7, 2023)
a2980dd - Rename 'num_train_sequences' to 'num_topk_samples' (Dahoas, Aug 21, 2023)
3d5a639 - Address nits (Dahoas, Aug 21, 2023)
87837b6 - Fix: style (Dahoas, Aug 21, 2023)
ed93be8 - Set 'num_return_sequences' to 1 by default (Dahoas, Aug 21, 2023)
24925c8 - Fix: typo (Dahoas, Aug 21, 2023)
a022d3f - Merge branch 'main' into BoN (maxreciprocate, Sep 1, 2023)
9680c9f - Merge branch 'main' into bon-x (maxreciprocate, Sep 1, 2023)
Changes from 1 commit:
TopK sampling instead of Top1
SharathRaparthy authored and Dahoas committed Jul 18, 2023
commit f58170dc3022f1c21f7bd53c5c88882984240751
1 change: 1 addition & 0 deletions trlx/data/default_configs.py
@@ -50,6 +50,7 @@ def default_ppo_config():
ref_std=None,
cliprange_reward=10,
num_return_sequences=1,
+ num_train_sequences=1,
gen_kwargs=dict(
max_new_tokens=40,
top_k=0,
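For context, a minimal sketch of how these two fields could be set together when enabling Best-of-N style rollouts; the values used here (4 candidates per prompt, top 2 kept for training) are illustrative assumptions, not defaults from this PR:

from trlx.data.default_configs import default_ppo_config

# Start from the stock PPO config touched in this diff.
config = default_ppo_config()
# Generate 4 candidate completions per prompt (the "n" in Best-of-N).
config.method.num_return_sequences = 4
# Keep only the 2 highest-scoring candidates from each group of 4 for PPO training.
config.method.num_train_sequences = 2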
1 change: 1 addition & 0 deletions trlx/models/modeling_ppo.py
@@ -131,6 +131,7 @@ class PPOConfig(MethodConfig):
cliprange_reward: float
gen_kwargs: dict
num_return_sequences: int
+ num_train_sequences: int
gen_experience_kwargs: Optional[dict] = None

def get_advantages_and_returns(
21 changes: 9 additions & 12 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -320,10 +320,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
else:
scores = all_scores[0].clone().detach()
# Best-of-N Sampling.
- max_score_indices = self.get_max_indices(scores, self.config.method.num_return_sequences, device)
- scores = scores.index_select(0, max_score_indices)
- samples = samples.index_select(0, max_score_indices)
- prompt_tensors = prompt_tensors.index_select(0, max_score_indices)
+ train_indices = self.get_topk_indices(input_tensor=scores, window_size=self.config.method.num_return_sequences, k=self.config.method.num_train_sequences, device=device)
+ scores = scores.index_select(0, train_indices)
+ samples = samples.index_select(0, train_indices)
+ prompt_tensors = prompt_tensors.index_select(0, train_indices)

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)

@@ -514,14 +514,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
self.push_to_store(ppo_rl_elements)

@staticmethod
- def get_max_indices(input_tensor, window_size, device):
+ def get_topk_indices(input_tensor, window_size: int, k: int, device):
Review comment (Collaborator): Nit: maybe a docstring should be added specifying that this isn't the same as a regular topk but rather a topk over window_size.

# Use unfold to create the sliding windows
unfolded = input_tensor.unfold(0, window_size, window_size)

- # Find the max values and indices along the unfolded dimension
- values, indices = unfolded.max(dim=2)
+ # Find the topk values and indices along the unfolded dimension
+ _, indices = torch.topk(unfolded, k, dim=2)
# Adjust indices to be relative to original tensor
- indices += torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1)
- return indices.squeeze()
+ indices = indices.squeeze(1) + torch.arange(0, input_tensor.size(0) - window_size + 1, window_size).to(device).unsqueeze(1)
+ return indices.reshape(-1)
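To illustrate the reviewer's point that this is a per-window top-k rather than a global one, here is a self-contained sketch of the helper as it reads after this commit, plus a toy usage example. The score values are made up, and the (batch, 1) column shape of the scores is an assumption inferred from the dim=2 top-k; the trainer may feed a differently shaped tensor.

import torch


def get_topk_indices(input_tensor, window_size: int, k: int, device):
    """Top-k per window of `window_size`, not a global top-k over the whole tensor."""
    # Split dim 0 into consecutive, non-overlapping windows of `window_size`
    # (one window per prompt, holding its num_return_sequences candidates).
    unfolded = input_tensor.unfold(0, window_size, window_size)
    # Top-k indices within each window (indices are relative to the window).
    _, indices = torch.topk(unfolded, k, dim=2)
    # Shift each window's indices by its starting offset in the original tensor.
    indices = indices.squeeze(1) + torch.arange(
        0, input_tensor.size(0) - window_size + 1, window_size
    ).to(device).unsqueeze(1)
    return indices.reshape(-1)


# Two prompts with 4 sampled completions each; keep the best 2 per prompt.
scores = torch.tensor([[0.1], [0.9], [0.5], [0.3], [0.7], [0.2], [0.8], [0.4]])
idx = get_topk_indices(scores, window_size=4, k=2, device="cpu")
print(idx)  # tensor([1, 2, 6, 4]): rows 1 and 2 win the first window, rows 6 and 4 the second

The flattened index list can then be passed to tensor.index_select(0, idx), which is how the make_experience hunk above keeps only the selected scores, samples, and prompts.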