Skip to content

Commit

Permalink
prm
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 25, 2023
1 parent 8fb95a8 commit cd108c3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ device = 0 if torch.cuda.is_available() else "cpu"

# Model initialization
prm_model = PRMModel(
model_name="lvwerra/gpt2-imdb-pos-v2",
ref_model_name="lvwerra/gpt2-imdb",
reward_model_name="lvwerra/distilbert-imdb",
device=device,
model_name="lvwerra/gpt2-imdb-pos-v2",
ref_model_name="lvwerra/gpt2-imdb",
reward_model_name="lvwerra/distilbert-imdb",
device=device,
)

# Generation arguments
gen_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": prm_model.tokenizer.eos_token_id,
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": prm_model.tokenizer.eos_token_id,
}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

Expand Down
4 changes: 3 additions & 1 deletion process_supervision/prm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class PRMModel:
def __init__(self, model_name, ref_model_name, reward_model_name, device):
def __init__(self, model_name: str = "lvwerra/gpt2-imdb-pos-v2", ref_model_name: str = "lvwerra/gpt2-imdb", reward_model_name: str = "lvwerra/distilbert-imdb", device):
"""
Initialize the PRM model with specified models and tokenizer.
Expand All @@ -17,9 +17,11 @@ def __init__(self, model_name, ref_model_name, reward_model_name, device):
self.model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name
).to(device)

self.ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
ref_model_name
).to(device)

self.reward_pipe = pipeline(
"sentiment-analysis", model=reward_model_name, device=device
)
Expand Down

0 comments on commit cd108c3

Please sign in to comment.