Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feats: Add text enviroment examples #560

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update code
  • Loading branch information
Duy Phung committed Sep 11, 2023
commit 72344e472127f1b335a9357b6f2d3e1ceeff8cfc
57 changes: 36 additions & 21 deletions examples/ppo_tool_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,19 @@
def ppo_init_config():
return TRLConfig(
train=TrainConfig(
seq_length=512,
seq_length=768,
epochs=100,
total_steps=10000,
batch_size=32,
minibatch_size=2,
minibatch_size=1,
checkpoint_interval=10000,
eval_interval=100,
pipeline="PromptPipeline",
trainer="AcceleratePPOTrainer",
save_best=False,
save_best=True,
checkpoint_dir="/fsx/home-duyphung/trlx_checkpoints",
),
model=ModelConfig(model_path="codellama/CodeLlama-7b-Instruct-hf", num_layers_unfrozen=8),
model=ModelConfig(model_path="codellama/CodeLlama-7b-Instruct-hf", num_layers_unfrozen=12),
tokenizer=TokenizerConfig(tokenizer_path="codellama/CodeLlama-7b-Instruct-hf", truncation_side="right"),
optimizer=OptimizerConfig(
name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
Expand All @@ -58,7 +59,7 @@ def ppo_init_config():
horizon=10000,
gamma=1,
lam=0.95,
num_value_layers_unfrozen=4,
num_value_layers_unfrozen=8,
cliprange=0.2,
cliprange_value=0.2,
vf_coef=1,
Expand All @@ -75,8 +76,6 @@ def ppo_init_config():
),
)

import torch


def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
Expand All @@ -90,21 +89,21 @@ def exact_match_reward(responses, answers=None):
answer = answer.strip()
reward += float(response == answer)
rewards.append(reward)
print("rewards: ", rewards)
print("responses: ", responses)
print("answers: ", answers)
# print("rewards: ", rewards)
# print("responses: ", responses)
# print("answers: ", answers)
return rewards


def create_reward_function(prompt):
tool_env = ToolEnvironment([load_tool("lvwerra/python-interpreter")], prompt, exact_match_reward)

def reward_fn(samples, prompts, original_output, **kwargs):
# for sample in samples:
# print("sample: ", sample)
# print("======================================")
# for sample in samples:
# print("sample: ", sample)
# print("======================================")
rewards = tool_env.get_reward(samples, **{"answers": original_output})
print("rewards: ", rewards)
# print("rewards: ", rewards)
return rewards

return reward_fn
Expand All @@ -124,17 +123,17 @@ def main(hparams={}):
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt

ds_test = load_dataset("gsm8k", "main", split="test").select(range(1, 100))
ds_test = load_dataset("gsm8k", "main", split="test").select(range(1, 500))
ds_test = ds_test.rename_columns({"question": "query"})
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})

df = ds.to_pandas()
df_test = ds_test.to_pandas()

few_shot_prompt = """\
Example of using a Python API to solve math questions. Make sure that you print the result of your solution.
Instruction: Using a Python API to solve math questions. Write function solution to solve the following questions, then "print(solution())" to output the result.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?

<request><PythonInterpreter>
def solution():
Expand All @@ -150,19 +149,35 @@ def solution():

Result = 72 <submit>

Question: Michael loves to paint and sells his creations. He charges $100 for a large painting and $80 for a small painting. At his last art show, he sold 5 large paintings and 8 small paintings. How much did he earn in all?

<request><PythonInterpreter>
def solution():
price_large = 100
price_small = 80
paintings_large = 5
paintings_small = 8
total_large_price = price_large * paintings_large
total_small_price = price_small * paintings_small
total = total_large_price + total_small_price
result = total
return result
print(solution())
<call>1140<response>

Result = 1140 <submit>

"""

reward_fn = create_reward_function(few_shot_prompt)

generate_prompt = """\
{few_shot_prompt}
Q: {query}
Question: {query}

<request><PythonInterpreter>"""

df["query"] = df["query"].apply(
lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x)
)
df["query"] = df["query"].apply(lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x))
df_test["query"] = df_test["query"].apply(
lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x)
)
Expand Down