My implementation of a small GPT language model in PyTorch
Given a sequence of tokens, MyGPT predicts the next token:
vocab = ["cat", "hat", "the", "in"]
mygpt = MyGPT(vocab)
prediction = mygpt(["the", "cat", "in", "the"])
# prediction = "hat"
Longer text can be generated by appending the predicted token to the prompt, then feeding the longer prompt back into MyGPT:
num_tokens_to_predict = 5
prompt = ["the", "cat", "in", "the"]
for _ in range(num_tokens_to_predict):
    prediction = mygpt(prompt)
    prompt.append(prediction)
    print(prompt)
# ["the", "cat", "in", "the", "hat"]
# ["the", "cat", "in", "the", "hat", "is"]
# ["the", "cat", "in", "the", "hat", "is", "a"]
# ["the", "cat", "in", "the", "hat", "is", "a", "great"]
# ["the", "cat", "in", "the", "hat", "is", "a", "great", "book"]
See the mygpt notebook for a deeper explanation of how MyGPT works
See the output folder for sample text produced by a ~825k-parameter MyGPT after ~5 minutes of training on an M1 Pro CPU. Predictions were made at the character level
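Character-level prediction means the vocabulary is the set of individual characters rather than whole words. A minimal sketch of such a tokenizer (the helper names here are illustrative, not MyGPT's actual API):

```python
# Build a character-level vocabulary from a corpus.
text = "the cat in the hat"
vocab = sorted(set(text))  # unique characters, e.g. [' ', 'a', 'c', ...]
stoi = {ch: i for i, ch in enumerate(vocab)}  # character -> token id
itos = {i: ch for ch, i in stoi.items()}      # token id -> character

def encode(s):
    return [stoi[ch] for ch in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

print(decode(encode("the hat")))  # prints "the hat"
```

Round-tripping a string through encode and decode returns it unchanged, which is the basic sanity check for any tokenizer.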
- transformer.py is an implementation of the Transformer decoder architecture described in the Attention Is All You Need paper
- It can be thought of as a mathematical function that transforms an input sequence of tokens into a prediction of the next token
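For illustration, one decoder block of this kind can be sketched in PyTorch as masked self-attention plus a feed-forward network, each wrapped in a residual connection (the pre-norm layout, module names, and dimensions below are illustrative, not necessarily what transformer.py uses):

```python
import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    """One pre-norm decoder block: masked self-attention + feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # Causal mask: position i may only attend to positions <= i,
        # so the model cannot peek at future tokens.
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x

x = torch.randn(2, 5, 64)   # (batch, sequence length, embedding size)
y = DecoderBlock()(x)
print(y.shape)  # torch.Size([2, 5, 64])
```

The block maps a sequence of embeddings to a sequence of the same shape, so blocks can be stacked; a final linear layer then turns the last position's embedding into next-token logits.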
- pretrain.py performs a training loop that improves MyGPT's ability to predict tokens
- Output text is periodically sampled during training to help visualize how the model's predictions improve over time
- The weights of the pre-trained model are saved to the weights folder and can be loaded for inference
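The training objective is next-token prediction: shift the sequence by one position and minimize cross-entropy between the model's logits and the actual next tokens. A toy sketch of that loop (the two-layer stand-in model below is illustrative, not MyGPT itself):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for MyGPT: embed token ids, project to vocab logits.
# (No attention here; the point is the shape of the loop, not the model.)
vocab_size, d_model = 4, 16
model = nn.Sequential(nn.Embedding(vocab_size, d_model),
                      nn.Linear(d_model, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# Toy corpus of token ids; the target at each position is the next token.
data = torch.tensor([[2, 0, 3, 2, 1]])  # e.g. "the cat in the hat"
inputs, targets = data[:, :-1], data[:, 1:]

losses = []
for step in range(100):
    logits = model(inputs)  # (batch, sequence, vocab)
    loss = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

torch.save(model.state_dict(), "weights.pt")  # save for later inference
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

After training, `model.load_state_dict(torch.load("weights.pt"))` restores the saved weights for inference.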
- generate.py takes a prompt, feeds it into a pre-trained MyGPT model, and generates text
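Rather than always taking the single most likely token, generation commonly samples from the predicted distribution, with a temperature parameter controlling how adventurous the samples are. A sketch of that sampling step (`sample_next` is a hypothetical helper, not part of MyGPT's API):

```python
import torch

def sample_next(logits, temperature=1.0):
    # Lower temperature sharpens the distribution (closer to greedy);
    # higher temperature flattens it (more random).
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.tensor([2.0, 0.5, -1.0, 0.0])  # one score per vocab token
token_id = sample_next(logits, temperature=0.8)
print(token_id)  # an index into the vocabulary, 0..3
```

Feeding the sampled token back into the prompt, as in the loop above, repeats this step once per generated token.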
PyTorch is the only dependency

macOS/Linux:
$ python3 -m venv venv
$ source venv/bin/activate
$ pip install .

Windows:
$ python -m venv venv
$ venv\Scripts\activate.bat
$ pip install .

Optional, for NVIDIA GPUs (installs a CUDA 11.6 build of PyTorch):
$ pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116