Skip to content

Commit

Permalink
add set_seed to __init__.py (huggingface#127)
Browse files Browse the repository at this point in the history
* add  to __init__.py

* update examples

* fix style
  • Loading branch information
lvwerra authored Feb 1, 2023
1 parent 078182a commit 3173ed2
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/sentiment/scripts/gpt2-sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler, set_seed
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, set_seed
from trl.core import LengthSampler

########################################################################
# This is a fully working simple example to use trl with accelerate.
Expand Down
4 changes: 2 additions & 2 deletions examples/sentiment/scripts/t5-sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from trl.core import LengthSampler, set_seed
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead, set_seed
from trl.core import LengthSampler

########################################################################
# This is a fully working simple example to use trl with accelerate.
Expand Down
1 change: 0 additions & 1 deletion tests/trainer/test_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def tearDownClass(cls):
pass

def setUp(self):

# model_id
model_id = "gpt2"

Expand Down
1 change: 1 addition & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__version__ = "0.2.2.dev0"

from .core import set_seed
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
Expand Down

0 comments on commit 3173ed2

Please sign in to comment.