Skip to content

Commit

Permalink
[FEATS][PRModel -> PRM][Swarms]
Browse files: browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 26, 2023
1 parent f99e919 commit 8532805
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
12 changes: 11 additions & 1 deletion prm_example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import torch
from process_supervision.prm import PRM
from swarms.models import OpenAIChat
import os
from dotenv import load_dotenv

# Load settings from a local .env file (if present) into the process
# environment so the API key below can be picked up via os.getenv.
load_dotenv()

# Read the OpenAI API key from the environment; os.getenv returns None
# when OPENAI_API_KEY is not set, and that value is passed on unchanged.
api_key = os.getenv("OPENAI_API_KEY")

# LLM initialization
llm = OpenAIChat(api_key=api_key)

# Pick the compute device: CUDA device index 0 when torch reports a GPU,
# otherwise the string "cpu".
device = 0 if torch.cuda.is_available() else "cpu"

# Model initialization
prm_model = PRMModel(
prm_model = PRM(
model_name="lvwerra/gpt2-imdb-pos-v2",
ref_model_name="lvwerra/gpt2-imdb",
reward_model_name="lvwerra/distilbert-imdb",
Expand Down
29 changes: 23 additions & 6 deletions process_supervision/prm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead
from typing import List, Any

class PRMModel:
def __init__(self, model_name: str = "lvwerra/gpt2-imdb-pos-v2", ref_model_name: str = "lvwerra/gpt2-imdb", reward_model_name: str = "lvwerra/distilbert-imdb", device: torch.device):

class PRM:
def __init__(
self,
model_name: str = "lvwerra/gpt2-imdb-pos-v2",
ref_model_name: str = "lvwerra/gpt2-imdb",
reward_model_name: str = "lvwerra/distilbert-imdb",
device=None,
):
"""
Initialize the PRM model with specified models and tokenizer.
Expand All @@ -14,6 +22,11 @@ def __init__(self, model_name: str = "lvwerra/gpt2-imdb-pos-v2", ref_model_name:
reward_model_name (str): Name of the reward model.
device (int or str): Device to run the model on ('cpu' or 'cuda').
"""
self.model_name = model_name
self.ref_model_name = ref_model_name
self.reward_model_name = reward_model_name
self.device = device

self.model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name
).to(device)
Expand All @@ -28,7 +41,9 @@ def __init__(self, model_name: str = "lvwerra/gpt2-imdb-pos-v2", ref_model_name:
self.tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token

def generate_responses(self, queries: List[str], gen_len: int, gen_kwargs: Dict[str, Any]) -> List[str]:
def generate_responses(
self, queries: List[str], gen_len: int, gen_kwargs: Dict[str, Any]
) -> List[str]:
"""
Generate responses for a batch of queries.
Expand All @@ -43,7 +58,7 @@ def generate_responses(self, queries: List[str], gen_len: int, gen_kwargs: Dict[
responses = []
for query in queries:
input_ids = self.tokenizer.encode(query, return_tensors="pt").to(
device
self.device
)
output_ids = self.model.generate(
input_ids, max_new_tokens=gen_len, **gen_kwargs
Expand All @@ -54,7 +69,9 @@ def generate_responses(self, queries: List[str], gen_len: int, gen_kwargs: Dict[
responses.append(response)
return responses

def score_responses(self, responses: List[str], sent_kwargs: Dict[str, Any]) -> List[float]:
def score_responses(
self, responses: List[str], sent_kwargs: Dict[str, Any]
) -> List[float]:
"""
Score a batch of responses using the reward pipeline.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ transformers = "*"
black = "*"
autopep8 = "*"
ruff = "*"
swarms = "*"

[tool.poetry.group.lint.dependencies]
ruff = "^0.0.249"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ trl
transformers
black
autopep8
ruff
ruff
swarms

0 comments on commit 8532805

Please sign in to comment.