MyGPT

My implementation of a small GPT language model in PyTorch

Given a sequence of tokens, MyGPT predicts the next token

vocab = ["cat", "hat", "the", "in"]

mygpt = MyGPT(vocab)
prediction = mygpt(["the", "cat", "in", "the"])

# prediction = "hat"

Longer text can be generated by appending the predicted token to the prompt, then feeding the longer prompt back into MyGPT

num_tokens_to_predict = 5
prompt = ["the", "cat", "in", "the"]

for _ in range(num_tokens_to_predict):
    prediction = mygpt(prompt)
    prompt.append(prediction)
    print(prompt)

# ["the", "cat", "in", "the", "hat"]
# ["the", "cat", "in", "the", "hat", "is"]
# ["the", "cat", "in", "the", "hat", "is", "a"]
# ["the", "cat", "in", "the", "hat", "is", "a", "great"]
# ["the", "cat", "in", "the", "hat", "is", "a", "great", "book"]

See the mygpt notebook for a deeper explanation of how MyGPT works

See the output folder for sample text that a ~825k-parameter MyGPT produced after ~5 minutes of training on an M1 Pro CPU. Predictions were made at the character level
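Character-level prediction means every character is its own token. A minimal sketch of how such a tokenizer can be built (this is a common approach and an assumption about MyGPT's internals, not code from the repo):

```python
# Character-level tokenizer sketch: map each unique character to an
# integer id, and back. Assumed/illustrative, not MyGPT's actual code.
text = "the cat"

vocab = sorted(set(text))                      # [' ', 'a', 'c', 'e', 'h', 't']
stoi = {ch: i for i, ch in enumerate(vocab)}   # char -> id
itos = {i: ch for ch, i in stoi.items()}       # id -> char

ids = [stoi[ch] for ch in text]
print(ids)                                     # [5, 4, 3, 0, 2, 1, 5]
print("".join(itos[i] for i in ids))           # round-trips to "the cat"
```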

GPT?

  • transformer.py is an implementation of the Transformer decoder architecture described in the Attention Is All You Need paper
    • It can be thought of as a mathematical function that transforms an input sequence of tokens into a prediction of the next token
  • pretrain.py performs a training loop that improves MyGPT's ability to predict tokens
    • Output text is periodically sampled during training to help visualize predictive ability improvements over time
    • The weights of the pre-trained model are saved to the weights folder and can be loaded for inference
  • generate.py takes a prompt, feeds it into a pre-trained MyGPT model, and generates text
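The decoder's defining trick is the causal mask: when predicting position t, attention is restricted to positions up to and including t, so the model cannot peek at future tokens during training. A small sketch of that mask pattern (illustrative; the repo's transformer.py may build it differently):

```python
# Causal (look-ahead) mask for a Transformer decoder.
# Entry [row][col] is 1 if position `row` may attend to position `col`.
def causal_mask(seq_len):
    return [[1 if col <= row else 0 for col in range(seq_len)]
            for row in range(seq_len)]

for row in causal_mask(4):
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]
```

In practice the masked positions are set to -inf before the softmax inside attention, which zeroes their attention weights.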

Installation

PyTorch is the only dependency

macOS or Linux

$ python3 -m venv venv
$ source venv/bin/activate
$ pip install .

Windows

> python -m venv venv
> venv\Scripts\activate.bat
> pip install .

For GPU acceleration

$ pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
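After installing, you can check which device PyTorch will pick up (assumes torch is importable; `mps` is the Apple Silicon GPU backend):

```python
# Report the best available PyTorch device (cuda > mps > cpu).
import torch

if torch.cuda.is_available():
    device = "cuda"   # NVIDIA GPU
elif torch.backends.mps.is_available():
    device = "mps"    # Apple Silicon GPU
else:
    device = "cpu"
print(device)
```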
