Implementation for course project of PKU 2023 NLPDL

muzhancun/MineCraftGPT

MineCraftGPT

Implementation of course project for PKU 2023 NLPDL.

Requirements

conda install --file requirements.txt

Then install the trl package from source:

git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .

Datasets

We use the datasets provided by MineDojo.

To get the Reddit dataset, first follow the PRAW instructions to obtain your own Reddit client_id, client_secret, and user_agent, then fill them in in get_reddit_data and preprocess_reddit_data in utils.py.

Then run utils.py to get the wiki and Reddit datasets.
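
As an alternative to hard-coding the credentials, they can be read from the environment. The following is a minimal sketch, not the code in utils.py: the helper names and the environment variable names are hypothetical, and only the three keyword arguments (client_id, client_secret, user_agent) are taken from PRAW itself.

```python
import os


def load_reddit_credentials(env=os.environ):
    """Collect the three PRAW credentials from environment variables.

    The variable names below are hypothetical, chosen for this sketch.
    """
    names = {
        "client_id": "REDDIT_CLIENT_ID",
        "client_secret": "REDDIT_CLIENT_SECRET",
        "user_agent": "REDDIT_USER_AGENT",
    }
    missing = [var for var in names.values() if not env.get(var)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return {key: env[var] for key, var in names.items()}


def make_reddit_client(credentials):
    """Build a read-only PRAW client; requires `pip install praw`."""
    import praw  # imported lazily so load_reddit_credentials works without praw

    return praw.Reddit(**credentials)
```

With the three variables exported in the shell, `make_reddit_client(load_reddit_credentials())` would return a client that could be used in place of hard-coded values, which keeps secrets out of version control.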

Training

To train the models for wiki generation, run

python wiki_train.py

after setting the wandb project name and model path in wiki_train.py.

To train the models for Reddit reply generation, run

python reddit_train.py

If you want to try RLHF on the Reddit dataset, first run

python reward_model.py

to train the reward model, then run

python PPO.py

to train the RLHF model.
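
This README does not state which loss reward_model.py uses. A common choice for reward models trained on preference pairs, assumed here purely for illustration, is the pairwise loss -log(sigmoid(r_chosen - r_rejected)), which pushes the score of the preferred reply above the score of the rejected one. The function name below is hypothetical and the sketch uses plain Python floats rather than model outputs.

```python
import math


def pairwise_reward_loss(chosen_score, rejected_score):
    """Return -log(sigmoid(chosen_score - rejected_score)), computed stably."""
    margin = chosen_score - rejected_score
    # -log(sigmoid(m)) equals log(1 + exp(-m)); branch to avoid overflow in exp.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```

The loss is small when the preferred reply already scores much higher than the rejected one, and grows roughly linearly when the ranking is inverted.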

Evaluation

Running the evaluation requires access to the Google Gemini model. Follow the Gemini instructions to set up access.

Run python elo_rating.py to see the results.
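
For readers unfamiliar with Elo ratings, the sketch below shows the standard update rule applied to one pairwise comparison between two models. It is not taken from elo_rating.py: the function names, the K-factor of 32, and the 400-point scale are conventional defaults assumed here, and the judge's verdict is reduced to a single score (1.0 if model A wins, 0.5 for a tie, 0.0 if it loses).

```python
def expected_score(rating_a, rating_b):
    """Probability that A beats B under the standard Elo model."""
    return 1.0 / (1.0 + 10 ** ((rating_b - rating_a) / 400.0))


def update_elo(rating_a, rating_b, score_a, k=32.0):
    """Return both ratings after one comparison.

    score_a is 1.0 if A won, 0.5 for a tie, 0.0 if A lost.
    """
    exp_a = expected_score(rating_a, rating_b)
    new_a = rating_a + k * (score_a - exp_a)
    # B's actual and expected scores are the complements of A's.
    new_b = rating_b + k * ((1.0 - score_a) - (1.0 - exp_a))
    return new_a, new_b


print(update_elo(1000, 1000, 1.0))  # (1016.0, 984.0)
```

The update is zero-sum: whatever one model gains, the other loses, so the average rating across models stays constant over many comparisons.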
