Skip to content

Commit

Permalink
Merge pull request togethercomputer#78 from csris/csris/pythia-support
Browse files Browse the repository at this point in the history
Add Pythia Support
  • Loading branch information
csris authored Mar 30, 2023
2 parents a71963d + 1ed1d6e commit 5180a70
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 69 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ dmypy.json
/data/OIG/files/
/data/wikipedia-3sentence-level-retrieval-index/files/
/pretrained/GPT-NeoX-20B/EleutherAI_gpt-neox-20b/
/pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped/

# ignore training output
/model_ckpts/
Expand Down
144 changes: 77 additions & 67 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# OpenChatKit

OpenChatKit provides a powerful, open-source base to create both specialized and general purpose chatbots for various applications. The kit includes an instruction-tuned 20 billion parameter language model, a 6 billion parameter moderation model, and an extensible retrieval system for including up-to-date responses from custom repositories. It was trained on the OIG-43M training dataset, which was a collaboration between [Together](https://www.together.xyz/), [LAION](https://laion.ai), and [Ontocord.ai](https://ontocord.ai). Much more than a model release, this is the beginning of an open source project. We are releasing a set of tools and processes for ongoing improvement with community contributions.
OpenChatKit provides a powerful, open-source base to create both specialized and general purpose chatbots for various applications. The kit includes an instruction-tuned language models, a moderation model, and an extensible retrieval system for including up-to-date responses from custom repositories. OpenChatKit models were trained on the OIG-43M training dataset, which was a collaboration between [Together](https://www.together.xyz/), [LAION](https://laion.ai), and [Ontocord.ai](https://ontocord.ai).

In this repo, you'll find code for:
- Training an OpenChatKit model
Expand All @@ -9,16 +9,15 @@ In this repo, you'll find code for:

# Contents

- [Requirements](#requirements)
- [Pre-trained Weights](#pre-trained-weights)
- [Datasets](#datasets)
* [Data Contributions](#data-contributions)
- [Pretrained Base Model](#pretrained-base-model)
- [Training and Finetuning](#training-and-finetuning)
- [Getting Started](#getting-started)
* [Requirements](#requirements)
* [Chatting with Pythia-Chat-Base-7B](#chatting-with-pythia-chat-base-7b)
- [Reproducing Pythia-Chat-Base-7B](#reproducing-pythia-chat-base-7b)
* [Downloading training data and the base model](#downloading-training-data-and-the-base-model)
* [(Optional) 8bit Adam](#optional-8bit-adam)
* [Train GPT-NeoX-Chat-Base-20B](#train-gpt-neox-chat-base-20b)
- [Converting Weights to Huggingface Format](#converting-weights-to-huggingface-format)
- [Inference](#inference)
* [Training the model](#training-the-model)
* [Converting weights to Huggingface format](#converting-weights-to-huggingface-format)
* [Testing the new model](#testing-the-new-model)
- [Monitoring](#monitoring)
* [Loguru](#loguru)
* [Weights & Biases](#weights--biases)
Expand All @@ -27,7 +26,15 @@ In this repo, you'll find code for:
- [Citing OpenChatKit](#citing-openchatkit)
- [Acknowledgements](#acknowledgements)

# Requirements
# Getting Started

In this tutorial, you will download Pythia-Chat-Base-7B, an instruction-tuned language model, and run some some inference requests against it using a command-line tool.

Pythia-Chat-Base-7B is a 7B-parameter fine-tuned variant of Pythia-6.9B-deduped from Eleuther AI. Pre-trained weights for this model are available on Huggingface as [togethercomputer/Pythia-Chat-Base-7B](https://huggingface.co/togethercomputer/Pythia-Chat-Base-7B) under an Apache 2.0 license.

More details can be found on the model card for [Pythia-Chat-Base-7B](https://huggingface.co/togethercomputer/Pythia-Chat-Base-7B) on Huggingface.

## Requirements

Before you begin, you need to install PyTorch and other dependencies.

Expand All @@ -49,6 +56,9 @@ conda install mamba -n base -c conda-forge

5. Create an environment called OpenChatKit using the `environment.yml` file at the root of this repo.

> **Note**
> Use `mamba` to create the environment. It's **much** faster than using `conda`.
```shell
mamba env create -f environment.yml
```
Expand All @@ -59,46 +69,62 @@ mamba env create -f environment.yml
conda activate OpenChatKit
```

# Pre-trained Weights
## Chatting with Pythia-Chat-Base-7B

GPT-NeoXT-Chat-Base-20B is a 20B-parameter variant of GPT-NeoX, fine-tuned on conversational datasets. We are releasing pre-trained weights for this model as [togethercomputer/GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) on Huggingface.
To help you try the model, [`inference/bot.py`](inference/bot.py) is a simple command-line test harness that provides a shell inferface enabling you to chat with the model. Simply enter text at the prompt and the model replies. The test harness also maintains conversation history to provide the model with context.

More details can be found on the model card for [GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) on Huggingface.

# Datasets
Start the bot by calling `bot.py` from the root for the repo.

The chat model was trained on the [OIG](https://huggingface.co/datasets/laion/OIG) dataset built by [LAION](https://laion.ai/), [Together](https://www.together.xyz/), and [Ontocord.ai](https://www.ontocord.ai/). To download the dataset from Huggingface run the command below from the root of the repo.
```shell
python inference/bot.py --model togethercomputer/Pythia-Chat-Base-7B
```

Loading the model can take some time, but once it's loaded, you are greeted with a prompt. Say hello.

```shell
python data/OIG/prepare.py
$ python inference/bot.py
Loading /home/csris/src/github.com/togethercomputer/OpenChatKit/inference/../huggingface_models/GPT-NeoXT-Chat-Base-20B to cuda:1...
Welcome to OpenChatKit shell. Type /help or /? to list commands.

>>> Hello.
Hello human.

>>>
```

Once the command completes, the data will be in the `data/OIG/files` directory.
Enter additional queries at the prompt, and the model replies. Under the covers, the shell is forming a prompt with all previous queries and passes that to the model to generate more text.

## Data Contributions
The shell also supports additional commands to inspect hyperparamters, the full prompt, and more. Commands are prefixed with a `/`.

You can help make this chat model better by contributing data! See the [OpenDataHub](https://github.com/togethercomputer/OpenDataHub) repo for more details.
> **Note**
> The `/quit` command exits the shell.
# Pretrained Base Model
Please see [the inference README](inference/README.md) for more details about arguments, running on multiple/specific GPUs, and running on consumer hardware.

As mentioned above, the chat model is a fine-tuned variant of GPT-NeoX-20B from Eleuther AI. To download GPT-NeoX-20B and prepare it for fine tuning, run this command from the root of the repo.
# Reproducing Pythia-Chat-Base-7B

```shell
python pretrained/GPT-NeoX-20B/prepare.py
```
This tutorial walks through reproducing the Pythia-Chat-Base-7B model by fine-tuning Eleuther AI's Pythia-6.9B-deduped model using the OIG dataset.

The weights for this model will be in the `pretrained/GPT-NeoX-20B/EleutherAI_gpt-neox-20b`.
## Downloading training data and the base model

In case you want to fine-tune other gpt-neox models, e.g. [the Pythia model suite](https://huggingface.co/models?sort=downloads&search=pythia), you can specify the HF model name, for example:
The chat model was trained on the [OIG](https://huggingface.co/datasets/laion/OIG) dataset built by [LAION](https://laion.ai/), [Together](https://www.together.xyz/), and [Ontocord.ai](https://www.ontocord.ai/). To download the dataset from Huggingface run the command below from the root of the repo.

```shell
python pretrained/GPT-NeoX-20B/prepare.py --model-name EleutherAI/pythia-6.9b-deduped
python data/OIG/prepare.py
```
> **Note**
> You can help make this chat model better by contributing data! See the [OpenDataHub](https://github.com/togethercomputer/OpenDataHub) repo for more details.
Once the command completes, the data will be in the `data/OIG/files` directory.

And the weights for this model will be in the `pretrained/GPT-NeoX-20B/EleutherAI_pythia-6.9b-deduped`.
Pythia-Chat-Base-7B is a fine-tuned variant of Pythia-6.9B-deduped from Eleuther AI. To download the model and prepare it for fine tuning, run this command from the root of the repo.

```shell
python pretrained/Pythia-6.9B-deduped/prepare.py
```

# Training and Finetuning
The weights for this model will be in the `pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped` directory.

## (Optional) 8bit Adam

Expand All @@ -108,66 +134,48 @@ To use 8bit-adam during training, install the `bitsandbytes` package.
pip install bitsandbytes # optional, to use 8bit-adam
```

## Train GPT-NeoX-Chat-Base-20B
## Training the model

The `training/finetune_GPT-NeoXT-Chat-Base-20B.sh` script configures and runs the training loop. After downloading the dataset and the base model, run:
The `training/finetune_Pythia-Chat-Base-7B.sh` script configures and runs the training loop. After downloading the dataset and the base model, run:

```shell
bash training/finetune_GPT-NeoXT-Chat-Base-20B.sh
bash training/finetune_Pythia-Chat-Base-7B.sh
```

The script launches 8 processes with a pipeline-parallel degree of 8 and a data-parallel degree of 1.

As the training loop runs, checkpoints are saved to the `model_ckpts` directory at the root of the repo.

Please see [the training README](training/README.md) for more details about customizing the training run.

The `training/finetune_Pythia-Chat-Base-7B.sh` script is another example to fine-tune a 7B pythia (gpt-neox) model. The script launches 8 processes with a pipeline-parallel degree of 4 and a data-parallel degree of 2.

# Converting Weights to Huggingface Format
## Converting weights to Huggingface format

Before you can use this model to perform inference, it must be converted to the Huggingface format. Run this command from the root of the repo to do so.

```shell
mkdir huggingface_models \
&& python tools/convert_to_hf_gptneox.py \
--ckpt-path model_ckpts/GPT-Neo-XT-Chat-Base-20B/checkpoint_100 \
--save-path huggingface_models/GPT-NeoXT-Chat-Base-20B \
--n-stages 8 \
--n-layer-per-stage 6 \
--config-name EleutherAI/pythia-6.9b-deduped \
--ckpt-path model_ckpts/Pythia-Chat-Base-7B/checkpoint_100 \
--save-path huggingface_models/Pythia-Chat-Base-7B \
--n-stages 4 \
--n-layer-per-stage 8 \
--fp16
```
where the `--fp16` flag will load and store models in fp16.

Make sure to replace `model_ckpts/GPT-Neo-XT-Chat-Base-20B/checkpoint_100` with the latest checkpoint in the `model_ckpts/GPT-Neo-XT-Chat-Base-20B` directory.
Make sure to replace `model_ckpts/Pythia-Chat-Base-7B/checkpoint_100` with the latest checkpoint in the `model_ckpts/Pythia-Chat-Base-7B` directory.

If you need to convert ckpts of other gpt-neox variants, make sure to specify the correct config name for your variant.
For example, if you want to convert a checkpoint fine-tuned from `EleutherAI/pythia-6.9b-deduped`, you should indicate this as a config name:
```shell
python tools/convert_to_hf_gptneox.py \
--config-name EleutherAI/pythia-6.9b-deduped \
--ckpt-path model_ckpts/Pythia-Chat-Base-7B/checkpoint_100 \
--save-path huggingface_models/Pythia-Chat-Base-7B \
--n-stages 4 \
--n-layer-per-stage 8 \
--fp16
```
## Testing the new model


# Inference

To help you test the model, we provide a simple test command line test harness to interact with the bot.
You can use the OpenChatKit Shell test harness to chat with the new model. From the root of the repo, run

```shell
python inference/bot.py
```

By default the script will load the model named GPT-NeoXT-Chat-Base-20B model under the `huggingface_models` directory, but you can override that behavior by specifying `--model`.

For example, if you want to load the base model from our Huggingface, repo, you can run the following command which downloads the weights from HuggingFace.
By default the script will load the model named Pythia-Chat-Base-7B under the `huggingface_models` directory, but you can override that behavior by specifying `--model`.

```shell
python inference/bot.py --model togethercomputer/GPT-NeoXT-Chat-Base-20B
python inference/bot.py --model ./huggingface_models/GPT-NeoXT-Chat-Base-20B
```

Once the model has loaded, enter text at the prompt and the model will reply.
Expand All @@ -178,13 +186,15 @@ Loading /home/csris/src/github.com/togethercomputer/OpenChatKit/inference/../hug
Welcome to OpenChatKit shell. Type /help or /? to list commands.

>>> Hello.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Hello human.

>>>
```

Commands are prefixed with a `/`, and the `/quit` command exits.
The shell also supports additional commands to inspect hyperparamters, the full prompt, and more. Commands are prefixed with a `/`.

> **Note**
> The `/quit` command exits the shell.
Please see [the inference README](inference/README.md) for more details about arguments, running on multiple/specific GPUs, and running on consumer hardware.

Expand All @@ -208,7 +218,8 @@ And set `--train-log-backend wandb` in the training script to enable logging to

# Experimental: Retrieval-Augmented Models

*Note: Retrieval is still experimental.*
> **Warning**
> Retrieval support is experimental.
The code in `/retrieval` implements a python package for querying a Faiss index of Wikipedia. The following steps explain how to use this index to augment queries in the test harness with context from the retriever.

Expand All @@ -234,7 +245,6 @@ Loading retrieval index...
Welcome to OpenChatKit shell. Type /help or /? to list commands.

>>> Where is Zurich?
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Where is Zurich?
Zurich is located in Switzerland.

Expand Down Expand Up @@ -281,6 +291,6 @@ For full terms, see the LICENSE file. If you have any questions, comments, or co

# Acknowledgements

Our model is a fine-tuned version of [gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b), a large language model trained by [Eleuther AI](https://www.eleuther.ai). We evaluated our model on [HELM](https://crfm.stanford.edu/helm/latest/) provided by the [Center for Research on Foundation Models](https://crfm.stanford.edu). And we collaborated with both [CRFM](https://crfm.stanford.edu) and [HazyResearch](http://hazyresearch.stanford.edu) at Stanford to build this model.
Our models are fine-tuned versions of large language models trained by [Eleuther AI](https://www.eleuther.ai). We evaluated our model on [HELM](https://crfm.stanford.edu/helm/latest/) provided by the [Center for Research on Foundation Models](https://crfm.stanford.edu). And we collaborated with both [CRFM](https://crfm.stanford.edu) and [HazyResearch](http://hazyresearch.stanford.edu) at Stanford to build this model.

We collaborated with [LAION](https://laion.ai/) and [Ontocord.ai](https://www.ontocord.ai/) to build the training data used to fine tune this model.
2 changes: 1 addition & 1 deletion inference/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def main():
)
parser.add_argument(
'--model',
default=f"{INFERENCE_DIR}/../huggingface_models/GPT-NeoXT-Chat-Base-20B",
default=f"{INFERENCE_DIR}/../huggingface_models/Pythia-Chat-Base-7B",
help='name/path of the model'
)
parser.add_argument(
Expand Down
56 changes: 56 additions & 0 deletions pretrained/Pythia-6.9B-deduped/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

DIR = os.path.dirname(os.path.abspath(__file__))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Convert HF checkpoints')
parser.add_argument('--model-name', type=str, default='EleutherAI/pythia-6.9b-deduped',
help='model-name')
parser.add_argument('--save-dir', type=str, default=DIR,
help='model-name')
parser.add_argument('--offload-dir', type=str, default=None,
help='directory to offload from memory')
args = parser.parse_args()

if not os.path.exists(args.save_dir):
os.mkdir(args.save_dir)
save_path = os.path.join(args.save_dir, args.model_name.replace('/', '_'))
if not os.path.exists(save_path):
os.mkdir(save_path)

print('loading model from HF...')
config = AutoConfig.from_pretrained(args.model_name)
config.save_pretrained(save_path)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.save_pretrained(save_path)
# offload model from memory to disk if offload-dir is specified
if args.offload_dir is not None:
if not os.path.exists(args.offload_dir):
os.mkdir(args.offload_dir)
model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, device_map="auto", offload_folder=args.offload_dir)
else:
model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16)
print('loaded model from HF...')

print('converting the embedding layer...')
item = {}
item['embed_in.weight'] = model.gpt_neox.embed_in.weight
torch.save(item, os.path.join(save_path, 'pytorch_embs.pt'))
print('converted the embedding layer.')

for i in range(len(model.gpt_neox.layers)):
print(f'converting the {i}-th transformer layer...')
torch.save(model.gpt_neox.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))
print(f'converted the {i}-th transformer layer.')

print('converting the lm_head layer...')
item = {}
item['embed_out.weight'] = model.embed_out.weight
item['final_layer_norm.weight'] = model.gpt_neox.final_layer_norm.weight
item['final_layer_norm.bias'] = model.gpt_neox.final_layer_norm.bias
torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))
print('converted the lm_head layer.')
2 changes: 1 addition & 1 deletion training/finetune_Pythia-Chat-Base-7B.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export MODEL_NAME=Pythia-Chat-Base-7B

export SHOW_DATA=0

BASE_MODEL="${DIR}/../pretrained/GPT-NeoX-20B/EleutherAI_pythia-6.9b-deduped/"
BASE_MODEL="${DIR}/../pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped/"

CHECKPOINT_STEPS=100

Expand Down

0 comments on commit 5180a70

Please sign in to comment.