
The nanoGPT-style implementation of RWKV Language Model - an RNN with GPT-level LLM performance.



Smith42/nanoRWKV

 
 


nanoRWKV



This is a rewrite of RWKV-v4neo and the HuggingFace implementation that aims to provide a clean RWKV codebase for head-to-head comparison with the GPT series, while keeping the simplicity and practicality of nanoGPT. The same repository can be used to train both GPT and RWKV models.


RWKV is essentially an RNN, which gives it a clear advantage at inference time. Here we benchmark the speed and memory usage of RWKV against its Transformer counterpart (the code can be found here). The results show that:

  • the latency of generating a single token with RWKV is constant, regardless of context length;
  • the overall latency of RWKV grows linearly with context length;
  • the overall memory usage of RWKV is constant, regardless of context length.
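The constant-memory point can be checked with a back-of-the-envelope count. The sketch below assumes an RWKV-4 recurrent state of five `n_embd`-sized vectors per layer (the state layout used for RNN-mode inference in the HuggingFace RWKV-4 implementation) and a Transformer KV cache that stores one key and one value vector per layer per cached token. The function names are hypothetical, and the counts ignore batch size and dtype.

```python
def rwkv_state_floats(n_layer: int, n_embd: int) -> int:
    """Floats in an RWKV-4 recurrent state (assumed: 5 vectors per layer)."""
    # per layer: time-mix previous x, WKV numerator, WKV denominator,
    # running max exponent (numerical stability), channel-mix previous x
    return 5 * n_layer * n_embd


def kv_cache_floats(n_layer: int, n_embd: int, context_len: int) -> int:
    """Floats in a Transformer KV cache after `context_len` tokens."""
    # per layer and per cached token: one key vector and one value vector
    return 2 * n_layer * n_embd * context_len


if __name__ == "__main__":
    n_layer, n_embd = 12, 768  # GPT-2 124M-sized shape, for illustration only
    for t in (256, 1024, 4096):
        print(t, rwkv_state_floats(n_layer, n_embd), kv_cache_floats(n_layer, n_embd, t))
```

The RWKV count does not depend on `context_len` at all, while the KV cache grows with every generated token.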

Figure: benchmark results.

Animations: Time Mixing, Channel Mixing.
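As a rough companion to the channel-mixing animation, here is a pure-Python sketch of an RWKV-4 channel-mixing block for a single token, following the formulation in the RWKV paper: token shift (interpolating with the previous token), a squared-ReLU key and a sigmoid receptance gate. All names are hypothetical and do not mirror modeling_rwkv.py; layer norm and the residual connection are omitted.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    # plain matrix-vector product: one dot product per output row
    return [sum(mij * xj for mij, xj in zip(row, x)) for row in m]


def token_shift(x: Vector, x_prev: Vector, mix: Vector) -> Vector:
    # per-channel interpolation between the current and the previous token
    return [m * a + (1.0 - m) * b for m, a, b in zip(mix, x, x_prev)]


def channel_mix(x: Vector, x_prev: Vector, mix_k: Vector, mix_r: Vector,
                w_key: Matrix, w_value: Matrix, w_recept: Matrix) -> Vector:
    xk = token_shift(x, x_prev, mix_k)
    xr = token_shift(x, x_prev, mix_r)
    k = [max(z, 0.0) ** 2 for z in matvec(w_key, xk)]  # squared ReLU
    kv = matvec(w_value, k)
    r = [1.0 / (1.0 + math.exp(-z)) for z in matvec(w_recept, xr)]  # sigmoid gate
    return [ri * vi for ri, vi in zip(r, kv)]


if __name__ == "__main__":
    eye = [[1.0, 0.0], [0.0, 1.0]]
    print(channel_mix([1.0, -2.0], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], eye, eye, eye))
```

Because the only dependence on earlier tokens is `x_prev`, channel mixing needs just one extra vector of state per layer at inference time.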

Table of Contents

We organize this project as follows:

  • Installation

    how to set up the environment for this project

  • Prerequisites

    some concepts you should be familiar with

  • Tutorial

    a step-by-step tutorial on building RWKV in a Jupyter notebook

  • Reproduction

    reproducing RWKV and GPT under roughly equal conditions

  • Generation

    generating text with a trained RWKV model

  • To-do-List

    things remaining to be done (PRs welcome)

  • Reference

    some useful references about RWKV and Large Language Models

Installation

We recommend using Conda to manage the environment.

conda create -n nanoRWKV python=3.8 
conda activate nanoRWKV
pip install torch numpy transformers datasets tiktoken wandb tqdm ninja
# if loading the CUDA kernel fails, install a CUDA toolkit matching your driver
# (replace * with your driver's CUDA 11 minor version):
# conda install cuda -c nvidia/label/cuda-11.*.0
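Before training, it can help to confirm that everything installed with pip above is importable. The helper below is a hypothetical standard-library sketch; it assumes each package's import name matches its pip name, which holds for these eight packages.

```python
import importlib.util
from typing import Dict, Iterable

# import names of the packages installed with pip above
REQUIRED = ("torch", "numpy", "transformers", "datasets",
            "tiktoken", "wandb", "tqdm", "ninja")


def check_packages(names: Iterable[str] = REQUIRED) -> Dict[str, bool]:
    """Map each package name to whether it can be found in this environment."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    missing = [name for name, ok in check_packages().items() if not ok]
    print("missing:", missing or "none")
```

`find_spec` only locates a module without importing it, so the check is fast even for heavy packages like `torch`.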

Prerequisites

Before kicking off this project, make sure you are familiar with the following concepts:

  • RNN: RNN stands for Recurrent Neural Network, a type of artificial neural network designed for sequential or time-series data. See this tutorial on RNNs.
  • Transformer: a type of deep learning model introduced in the paper Attention Is All You Need. It is designed for sequential data, such as natural language, and processes it with a mechanism called self-attention. See this post to learn more about Transformers.
  • LLM: short for Large Language Model, which has taken the world by storm. See this Awesome-LLM repo and State of GPT.
  • nanoGPT: the simplest, fastest repository for training/finetuning medium-sized GPTs, by Andrej Karpathy. Here you can find the code and the teaching video.
  • RWKV Language Model: an RNN with GPT-level LLM performance that can also be trained directly like a GPT transformer (parallelizable). The model was created by independent researcher Bo Peng. Get more information here.
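The "trained like a GPT, run like an RNN" property comes from RWKV's WKV operator, which has an equivalent parallel form and recurrent form. The sketch below follows the per-channel RWKV-4 formula from the RWKV paper, with a positive decay `w` and a bonus `u` for the current token. It is a pure-Python illustration with hypothetical names, not this repository's CUDA kernel, and it omits the max-subtraction trick real implementations use for numerical stability.

```python
import math
from typing import List


def wkv_parallel(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """GPT-style form: each output is computed directly from all earlier tokens."""
    out = []
    for t in range(len(k)):
        # past tokens i < t, decayed by exp(-(t - 1 - i) * w)
        num = sum(math.exp(-(t - 1 - i) * w + k[i]) * v[i] for i in range(t))
        den = sum(math.exp(-(t - 1 - i) * w + k[i]) for i in range(t))
        # the current token gets the bonus u instead of a decay
        num += math.exp(u + k[t]) * v[t]
        den += math.exp(u + k[t])
        out.append(num / den)
    return out


def wkv_recurrent(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """RNN-style form: a fixed-size state (a, b) is updated once per token."""
    a, b = 0.0, 0.0  # decayed running numerator and denominator
    out = []
    for kt, vt in zip(k, v):
        out.append((a + math.exp(u + kt) * vt) / (b + math.exp(u + kt)))
        a = math.exp(-w) * a + math.exp(kt) * vt
        b = math.exp(-w) * b + math.exp(kt)
    return out


if __name__ == "__main__":
    k, v = [0.1, -0.3, 0.5, 0.2], [1.0, 2.0, -1.0, 0.5]
    print(wkv_parallel(0.5, 0.3, k, v))
    print(wkv_recurrent(0.5, 0.3, k, v))
```

Both functions return the same values (up to floating-point rounding): the parallel form suits training over whole sequences, while the recurrent form needs only two numbers of state per channel during generation.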

Tutorial

We will present a step-by-step tutorial on building RWKV in a Jupyter notebook.

Reproduction

Once everything is set up, let's build RWKV. First, tokenize the dataset (OpenWebText):

python data/openwebtext/prepare.py
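The prepare step stores the tokenized corpus as flat binary files of token ids. Assuming this repository keeps nanoGPT's convention (`train.bin`/`val.bin` holding GPT-2 BPE token ids as uint16), the sketch below shows how such a file can be written and how random training windows are drawn from it. It uses only the standard library instead of NumPy's `memmap`, and all function names are hypothetical.

```python
import os
import random
import tempfile
from array import array
from typing import List, Sequence, Tuple


def write_tokens(path: str, tokens: Sequence[int]) -> None:
    # token ids as unsigned 16-bit ints; GPT-2's 50257-token vocabulary fits below 65536
    with open(path, "wb") as f:
        array("H", tokens).tofile(f)


def read_tokens(path: str) -> array:
    data = array("H")
    with open(path, "rb") as f:
        data.frombytes(f.read())
    return data


def get_batch(data: Sequence[int], batch_size: int, block_size: int,
              rng: random.Random) -> Tuple[List[List[int]], List[List[int]]]:
    # random window starts; each target sequence is its input shifted by one token
    starts = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    x = [list(data[s:s + block_size]) for s in starts]
    y = [list(data[s + 1:s + 1 + block_size]) for s in starts]
    return x, y


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "toy_train.bin")  # hypothetical toy file
    write_tokens(path, range(1000))
    x, y = get_batch(read_tokens(path), batch_size=2, block_size=8, rng=random.Random(0))
    print(x[0], y[0])
```

Sampling windows at random offsets means the training script never needs to load the whole corpus into a Python list; nanoGPT gets the same effect by memory-mapping the file.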

Then train RWKV (130M) on a single node with eight V100 32GB GPUs using PyTorch Distributed Data Parallel (DDP):

torchrun --standalone --nproc_per_node=8 train.py config/train_rwkv.py

For comparison, we also train a GPT-2 model (124M) on the same hardware with:

torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
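For a fair comparison, both runs should consume the same number of tokens per optimizer step. In nanoGPT, train.py divides `gradient_accumulation_steps` by the DDP world size, so the per-step total is the product below. The example values (batch_size=12, block_size=1024, 5 * 8 accumulation steps in total) are nanoGPT's config/train_gpt2.py defaults; whether config/train_rwkv.py uses the same values is an assumption, and the function name is hypothetical.

```python
def tokens_per_iter(batch_size: int, block_size: int,
                    grad_accum_per_process: int, world_size: int) -> int:
    """Tokens consumed by one optimizer step, summed over all DDP processes."""
    # each process runs grad_accum_per_process micro-batches of
    # batch_size sequences, each block_size tokens long
    return batch_size * block_size * grad_accum_per_process * world_size


if __name__ == "__main__":
    world_size = 8              # --nproc_per_node=8
    grad_accum_total = 5 * 8    # nanoGPT config/train_gpt2.py default
    per_process = grad_accum_total // world_size
    print(tokens_per_iter(12, 1024, per_process, world_size))
```

Because the total accumulation is split across processes, the tokens per step stay the same whether the run uses one GPU or eight; only wall-clock time changes.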

We obtained the following results (see this wandb project):

| model | params | train loss | val loss |
|-------|--------|------------|----------|
| GPT-2 | 124M   | 2.82       | 2.86     |
| RWKV  | 130M   | 2.85       | 2.88     |
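Losses are easier to interpret as perplexities. Assuming the reported numbers are mean per-token cross-entropy in nats (what nanoGPT's training loop computes with `F.cross_entropy`), perplexity is simply `exp(loss)`:

```python
import math

# (model, train loss, val loss) as reported in the results table
RESULTS = [("GPT-2", 2.82, 2.86), ("RWKV", 2.85, 2.88)]


def perplexity(loss_nats: float) -> float:
    # mean cross-entropy in nats -> perplexity
    return math.exp(loss_nats)


if __name__ == "__main__":
    for name, train_loss, val_loss in RESULTS:
        print(f"{name}: val perplexity = {perplexity(val_loss):.2f}")
```

A 0.02 gap in validation loss corresponds to a perplexity ratio of exp(0.02), i.e. about 2%.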

baselines

The existing OpenAI GPT-2 and RWKV checkpoints let us establish some baselines on OpenWebText. The numbers can be obtained with:

python train.py config/eval_rwkv4_169m.py   # or: 430m, 1b5, 3b, 7b, 14b
python train.py config/eval_gpt2.py         # or: eval_gpt2_medium.py, eval_gpt2_large.py, eval_gpt2_xl.py

and observe the following losses on the validation set:

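| model | parameters | val loss |
|-------|------------|----------|
| RWKV  | 169M       | 3.11     |
| RWKV  | 430M       | 2.79     |
| RWKV  | 1.5B       | 2.54     |
| RWKV  | 3B         | 2.42     |
| RWKV  | 7B         | 2.32     |
| RWKV  | 14B        | 2.23     |
| GPT-2 | 124M       | 3.11     |
| GPT-2 | 350M       | 2.82     |
| GPT-2 | 774M       | 2.66     |
| GPT-2 | 1.5B       | 2.56     |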

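Note that neither model was trained on OpenWebText (RWKV was trained on the Pile and OpenAI GPT-2 on the private WebText dataset), so there is a dataset domain gap and both could likely be improved by finetuning on OpenWebText. Also, RWKV and GPT-2 use different tokenizers, so their losses are not directly comparable.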
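The ten evaluation runs can be scripted instead of typed one by one. The sketch below only rebuilds the config paths listed above; the helper names are hypothetical, and it defaults to a dry run that prints the commands without executing them.

```python
import subprocess
from typing import List

RWKV_SIZES = ["169m", "430m", "1b5", "3b", "7b", "14b"]
GPT2_SUFFIXES = ["", "_medium", "_large", "_xl"]  # "" selects eval_gpt2.py


def eval_commands() -> List[List[str]]:
    cmds = [["python", "train.py", f"config/eval_rwkv4_{s}.py"] for s in RWKV_SIZES]
    cmds += [["python", "train.py", f"config/eval_gpt2{s}.py"] for s in GPT2_SUFFIXES]
    return cmds


def run_all(dry_run: bool = True) -> None:
    for cmd in eval_commands():
        print(" ".join(cmd))
        if not dry_run:
            # stop at the first failing evaluation
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run=True)
```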

Generation

After training, we can use the following command to generate text:

python sample.py \
    --init_from=$model \
    --start="What is the answer to life, the universe, and everything?" \
    --num_samples=1 --max_new_tokens=100

The $model above can come from either the GPT series or the RWKV series:

# For the GPT series, one of: gpt2, gpt2-medium, gpt2-large, gpt2-xl
model=gpt2-xl

# For the RWKV series, pick model_type from: 169m, 430m, 1b5, 3b
model_type=169m
model=RWKV/rwkv-4-${model_type}-pile
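nanoGPT's sample.py draws each new token by dividing the logits by a temperature and optionally restricting sampling to the top-k candidates. Assuming this repository's sample.py keeps that scheme, one sampling step looks roughly like the hypothetical standard-library function below (it is not the script's actual code).

```python
import math
import random
from typing import List, Optional


def sample_next_token(logits: List[float], temperature: float = 1.0,
                      top_k: Optional[int] = None,
                      rng: Optional[random.Random] = None) -> int:
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        # keep only the k largest logits; the rest get zero probability
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= kth else -math.inf for s in scaled]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # softmax numerators, shifted for stability
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


if __name__ == "__main__":
    print(sample_next_token([1.0, 0.5, -3.0, 2.0], temperature=0.8, top_k=2,
                            rng=random.Random(0)))
```

Lower temperatures sharpen the distribution toward the most likely token, and `top_k=1` reduces sampling to greedy decoding.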

To-do-list

This project is not finished, and a lot remains to be done:

  • Double-check the correctness of the current implementation (help needed).
  • Benchmark generation speed and memory usage.
  • A detailed and thorough Jupyter notebook tutorial about RWKV.
  • More code comments in modeling_rwkv.py.
  • RNN mode for inference [HF Implementation]
  • Rescale parameters for inference [reference]
  • Load RWKV checkpoints for evaluation (may not be comparable to GPT-2 due to the different tokenizer).
  • Test bf16 training (V100 does not support bf16, so sponsorship of an A100 for testing bf16 would be greatly appreciated :) ). Thanks @Smith42 for verifying this.
  • Maybe scale up a little with DeepSpeed? Not sure, since nanoGPT doesn't do this.
  • Keep in line with the optimization setup of the original RWKV implementation. [reference]
  • More analysis of RWKV in scaling_laws.ipynb and transformer_sizeing.ipynb.

Reference

Here are some useful references (offering my sincerest gratitude):
