Implementation of our course project for PKU NLPDL (2023).
Install the dependencies with
conda install --file requirements.txt
Then install the trl package from source with
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
We use the datasets provided by MineDojo.
To get the Reddit dataset, first follow the instructions on PRAW to obtain your own Reddit client_id, client_secret and usr_agent, and fill them into get_reddit_data and preprocess_reddit_data in utils.py.
Then run utils.py to build the wiki and Reddit datasets.
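A minimal sketch of supplying the PRAW credentials without hard-coding them, assuming you prefer to keep secrets out of the source. The environment variable names and both helper functions are hypothetical and not part of utils.py, which expects the values to be filled in directly. PRAW's own keyword argument is user_agent (the usr_agent above presumably refers to the same value), and a client built from only these three values is read-only, which is enough for collecting posts.

```python
import os


def load_reddit_credentials(env=None):
    """Read PRAW credentials from environment variables (hypothetical names)."""
    env = os.environ if env is None else env
    # Map PRAW keyword arguments to the hypothetical environment variables.
    names = {
        "client_id": "REDDIT_CLIENT_ID",
        "client_secret": "REDDIT_CLIENT_SECRET",
        "user_agent": "REDDIT_USER_AGENT",
    }
    missing = [var for var in names.values() if not env.get(var)]
    if missing:
        raise KeyError("missing Reddit credentials: " + ", ".join(missing))
    return {arg: env[var] for arg, var in names.items()}


def make_reddit_client(credentials):
    """Create a read-only PRAW client from the credentials dict."""
    import praw  # imported lazily so the helper above works without PRAW installed

    return praw.Reddit(**credentials)
```

Usage: reddit = make_reddit_client(load_reddit_credentials()), after exporting the three variables in your shell.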
To train the models for wiki generation, first set the wandb project name and model path in wiki_train.py, then run
python wiki_train.py
To train the models for Reddit reply generation, run
python reddit_train.py
If you want to try RLHF on the Reddit dataset, first run
python reward_model.py
to train the reward model, then run
python PPO.py
to train the RLHF model.
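The two RLHF stages can be illustrated with a small sketch of the objectives commonly used for them. These are standard formulations, not code taken from reward_model.py or PPO.py. A reward model is typically trained with a pairwise ranking loss, -log(sigmoid(r_chosen - r_rejected)), and PPO typically optimizes a per-token reward that subtracts a KL penalty, beta * (log pi - log pi_ref), against the frozen reference model, with the reward-model score added on the final token. The function names and the default beta=0.1 are hypothetical.

```python
import math


def pairwise_reward_loss(reward_chosen, reward_rejected):
    """Ranking loss -log(sigmoid(r_chosen - r_rejected)) for one preference pair."""
    margin = reward_chosen - reward_rejected
    # Numerically stable softplus(-margin), which equals -log(sigmoid(margin)).
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


def kl_penalized_rewards(score, logprobs_policy, logprobs_ref, beta=0.1):
    """Per-token PPO rewards: a KL penalty on every token, plus the score on the last."""
    if not logprobs_policy or len(logprobs_policy) != len(logprobs_ref):
        raise ValueError("need equal-length, non-empty log-prob sequences")
    # log pi(token) - log pi_ref(token) is a per-token estimate of the KL term.
    rewards = [-beta * (lp - lr) for lp, lr in zip(logprobs_policy, logprobs_ref)]
    # The reward model scores the whole response, so its score goes on the final token.
    rewards[-1] += score
    return rewards
```

The KL penalty keeps the policy from drifting into text that exploits the reward model while moving far from the supervised model it started from.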
Running the evaluation requires access to the Google Gemini API. Please follow the official Gemini instructions to set up access.
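One way to use Gemini as a judge is pairwise comparison of two model outputs. The sketch below is hypothetical: the prompt wording, the one-letter A/B answer format and the helper names are assumptions, not the project's actual evaluation prompt. ask_gemini uses the google-generativeai Python SDK (genai.configure, GenerativeModel, generate_content) with the model name gemini-pro, which may need updating for newer API versions.

```python
import re

# Hypothetical judge prompt; the real evaluation prompt may differ.
JUDGE_TEMPLATE = """You are comparing two replies to the same Reddit post.

Post:
{post}

Reply A:
{reply_a}

Reply B:
{reply_b}

Which reply is better? Answer with exactly one letter: A or B."""


def build_judge_prompt(post, reply_a, reply_b):
    return JUDGE_TEMPLATE.format(post=post, reply_a=reply_a, reply_b=reply_b)


def parse_verdict(text):
    """Return 'A', 'B', or None; assumes the model follows the one-letter instruction."""
    match = re.search(r"\b([AB])\b", text.strip())
    return match.group(1) if match else None


def ask_gemini(prompt, api_key, model_name="gemini-pro"):
    import google.generativeai as genai  # lazy import: only needed for real API calls

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)
    return model.generate_content(prompt).text
```

Swapping the order of the two replies and querying twice is a cheap way to check whether the judge has a position bias.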
Run
python elo_rating.py
to see the results.
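elo_rating.py is not reproduced here. As a sketch, assuming it turns pairwise judge verdicts into standard Elo ratings, the update rule looks like this. The initial rating of 1000, K=32 and the model names in the usage line are illustrative choices.

```python
def expected_score(rating_a, rating_b):
    """Probability that A beats B under the standard Elo model."""
    return 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400.0))


def update_elo(ratings, winner, loser, k=32.0):
    """Apply one pairwise result in place and return the ratings dict."""
    e_win = expected_score(ratings[winner], ratings[loser])
    # The winner scores 1 and the loser 0, so the loser loses exactly what the winner gains.
    delta = k * (1.0 - e_win)
    ratings[winner] += delta
    ratings[loser] -= delta
    return ratings


def elo_from_matches(models, matches, initial=1000.0, k=32.0):
    """Compute ratings from a list of (winner, loser) pairs, applied in order."""
    ratings = {m: initial for m in models}
    for winner, loser in matches:
        update_elo(ratings, winner, loser, k)
    return ratings
```

Usage: elo_from_matches(["sft", "ppo"], [("ppo", "sft")]) gives {"sft": 984.0, "ppo": 1016.0}. Because updates are applied sequentially, the final ratings depend on match order; averaging over several shuffled orderings reduces that effect. Ties are not handled in this sketch.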