forked from callummcdougall/ARENA_2.0
Commit 3b8c912 (1 parent: 9a20842)
Showing 18 changed files with 3,507 additions and 12 deletions.
Binary file added (+16.1 KB): ...er1_transformers/exercises/monthly_algorithmic_problems/november23_cumsum/cumsum_model.pt
chapter1_transformers/exercises/monthly_algorithmic_problems/november23_cumsum/dataset.py (68 additions, 0 deletions)
import torch as t
from torch.utils.data import Dataset


class CumsumDataset(Dataset):

    def __init__(self, size: int, max_value: int, seq_len: int, seed: int = 42, p: float = 0.7):
        '''
        Dataset for the cumulative sum problem. Contains tokens and labels.
        Tokens are integers between [-max_value, +max_value] inclusive. Labels are
        2 if the cumsum is strictly positive, 1 if it's zero, 0 if strictly negative.
        The argument `p` is the probability that the next value will be of opposite
        sign to the previous value. This is to make zeros a bit more common.
        '''
        self.vocab = [f"{i:+}" for i in range(-max_value, max_value+1)]
        self.vocab_out = ["neg", "zero", "pos"]
        self.size = size
        self.max_value = max_value
        t.manual_seed(seed)  # for reproducible results

        # Generate our sequences and labels
        self.toks = t.randint(-max_value, max_value+1, (size, seq_len))

        for seq_pos in range(1, seq_len):
            prev_cumsum = self.toks[:, :seq_pos].sum(dim=1)
            # Get random 1s & 0s, where 1s have probability p
            p_rand_seed = t.rand(size) < p
            # If prev_cumsum is positive, we want the next value to be negative with prob p
            next_value_is_negative = t.where(prev_cumsum > 0, p_rand_seed, ~p_rand_seed)
            # Get new random values, and make them negative where next_value_is_negative is True
            new_values = t.randint(0, max_value+1, (size,))
            new_values = t.where(next_value_is_negative, -new_values, new_values)
            # Add the new values to self.toks
            self.toks[:, seq_pos] = new_values

        # Create labels: 0 if neg, 1 if zero, 2 if pos
        toks_pos = (self.toks.cumsum(dim=1) > 0)
        toks_neg = (self.toks.cumsum(dim=1) < 0)
        self.labels = t.where(toks_pos, 2, t.where(toks_neg, 0, 1))

        self.str_toks = [
            [f"{tok:+}" for tok in toks]
            for toks in self.toks
        ]
        str_labels_dict = {0: "neg", 1: "zero", 2: "pos"}
        self.str_labels = [
            [str_labels_dict[label.item()] for label in labels]
            for labels in self.labels
        ]

    def __getitem__(self, index):
        return self.toks[index], self.labels[index]

    def __len__(self):
        return self.size

    def to(self, device: str):
        self.toks = self.toks.to(device)
        self.labels = self.labels.to(device)
        return self


# %%
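
For orientation, the dataset can be instantiated and inspected as in the sketch below. The constructor values are illustrative placeholders, not settings taken from this commit; the attribute and method names (str_toks, str_labels, __getitem__, __len__) all come from the class above.

from monthly_algorithmic_problems.november23_cumsum.dataset import CumsumDataset

# Illustrative values only, not the settings used elsewhere in this commit
dataset = CumsumDataset(size=4, max_value=5, seq_len=10, seed=0)

toks, labels = dataset[0]       # one sequence of tokens and its per-position labels
print(dataset.str_toks[0])      # e.g. ["+3", "-4", ...]: signed string form of the tokens
print(dataset.str_labels[0])    # e.g. ["pos", "neg", ...]: sign of the running cumulative sum
print(len(dataset))             # == size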
chapter1_transformers/exercises/monthly_algorithmic_problems/november23_cumsum/model.py (54 additions, 0 deletions)
import torch as t
import numpy as np
from typing import Optional, List
from transformer_lens import HookedTransformer, HookedTransformerConfig


def create_model(
    max_value: int,
    seq_len: int,
    seed: int,
    d_model: int,
    d_head: int,
    n_layers: int,
    n_heads: int,
    d_mlp: Optional[int],
    normalization_type: Optional[str],
    device: str = "cuda",
    **kwargs  # ignore other kwargs
) -> HookedTransformer:

    t.manual_seed(seed)
    np.random.seed(seed)

    attn_only = (d_mlp is None)

    cfg = HookedTransformerConfig(
        n_layers=n_layers,
        n_ctx=seq_len,
        d_model=d_model,
        d_head=d_head,
        n_heads=n_heads,
        d_mlp=d_mlp,
        attn_only=attn_only,
        act_fn="relu",

        # Tokens run from [-max_value, +max_value] inclusive, so 2*max_value+1
        d_vocab=2*max_value+1,
        # We only have 3 classifications
        d_vocab_out=3,

        # It's a small transformer, so we may as well use these hooks
        use_attn_result=True,
        use_split_qkv_input=True,
        use_hook_tokens=True,

        # LayerNorm makes things way more accurate, even though it makes
        # mech interp a little more annoying!
        normalization_type=normalization_type,

        device=device,
    )

    model = HookedTransformer(cfg)
    return model
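
As a sanity check, a model can be built with create_model and run directly on dataset tokens, mirroring how training.py calls self.model(toks). The hyperparameter values below are placeholders for illustration (this commit does not show the config actually used), and the commented-out line assumes cumsum_model.pt stores a plain state dict.

import torch as t
from monthly_algorithmic_problems.november23_cumsum.dataset import CumsumDataset
from monthly_algorithmic_problems.november23_cumsum.model import create_model

dataset = CumsumDataset(size=4, max_value=5, seq_len=10, seed=0)

# Placeholder hyperparameters, not the ones used to train cumsum_model.pt
model = create_model(
    max_value=5, seq_len=10, seed=42,
    d_model=24, d_head=12, n_layers=1, n_heads=1,
    d_mlp=8, normalization_type="LN", device="cpu",
)

# If the checkpoint is a saved state dict, it could presumably be loaded like this:
# model.load_state_dict(t.load("cumsum_model.pt", map_location="cpu"))

logits = model(dataset.toks)   # shape (4, 10, 3): one logit per (sequence, position, output class)
preds = logits.argmax(-1)      # 0 = "neg", 1 = "zero", 2 = "pos"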
chapter1_transformers/exercises/monthly_algorithmic_problems/november23_cumsum/training.py (133 additions, 0 deletions)
from dataclasses import dataclass
from typing import Tuple, Optional, List
from tqdm import tqdm
import torch as t
from torch import Tensor
import torch.nn.functional as F
from copy import deepcopy
from torch.utils.data import DataLoader
import einops
import wandb

from monthly_algorithmic_problems.november23_cumsum.dataset import CumsumDataset
from monthly_algorithmic_problems.november23_cumsum.model import create_model


@dataclass
class TrainArgs:
    max_value: int
    seq_len: int
    trainset_size: int
    valset_size: int
    epochs: int
    batch_size: int
    lr_start: float
    lr_end: float
    weight_decay: float
    seed: int
    d_model: int
    d_head: int
    n_layers: int
    n_heads: int
    d_mlp: int
    normalization_type: Optional[str]
    use_wandb: bool
    device: str


class Trainer:
    def __init__(self, args: TrainArgs):
        self.args = args
        self.model = create_model(**args.__dict__)
        if args.use_wandb:
            wandb.init(project="sum-model")
            wandb.watch(self.model)

    def training_step(self, batch: Tuple[Tensor, Tensor]) -> t.Tensor:
        logits, labels = self._shared_train_validation_step(batch)
        logprobs = logits.log_softmax(-1)
        loss = F.cross_entropy(
            einops.rearrange(logprobs, "batch seq vocab_out -> (batch seq) vocab_out"),
            einops.rearrange(labels, "batch seq -> (batch seq)"),
        )
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor]) -> t.Tensor:
        logits, labels = self._shared_train_validation_step(batch)
        accuracy = (logits.argmax(-1) == labels).float().sum().item()
        return accuracy

    def _shared_train_validation_step(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        toks, labels = batch
        toks = toks.to(self.args.device)
        labels = labels.to(self.args.device)
        logits = self.model(toks)
        return logits, labels

    def train_dataloader(self, seed: int):
        trainset = CumsumDataset(size=self.args.trainset_size, max_value=self.args.max_value, seq_len=self.args.seq_len, seed=seed)
        return DataLoader(trainset, batch_size=self.args.batch_size, shuffle=True)

    def val_dataloader(self, seed: int):
        valset = CumsumDataset(size=self.args.valset_size, max_value=self.args.max_value, seq_len=self.args.seq_len, seed=seed)
        return DataLoader(valset, batch_size=self.args.batch_size, shuffle=False)

    def configure_optimizers(self):
        optimizer = t.optim.Adam(self.model.parameters(), lr=self.args.lr_start, weight_decay=self.args.weight_decay)
        return optimizer


def train(args: TrainArgs):

    trainer = Trainer(args)
    optimizer = trainer.configure_optimizers()

    train_dataloader = trainer.train_dataloader(seed=args.seed)
    val_dataloader = trainer.val_dataloader(seed=args.seed+1)

    # Save the best model (based on validation accuracy)
    best_model = deepcopy(trainer.model)
    best_epoch = None
    best_accuracy = None
    best_loss = None

    for epoch in range(args.epochs):

        # Update learning rate (linear decay)
        lr = args.lr_start + (args.lr_end - args.lr_start) * ((epoch + 1) / args.epochs)
        optimizer.param_groups[0]["lr"] = lr

        progress_bar = tqdm(total=args.trainset_size//args.batch_size)

        # Training
        for batch in train_dataloader:
            # Optimization step on training set
            optimizer.zero_grad()
            loss = trainer.training_step(batch)
            loss.backward()
            optimizer.step()
            # Log variables, update progress bar
            if args.use_wandb: wandb.log({"training_loss": loss})
            progress_bar.update()
            progress_bar.set_description(f"Epoch {epoch:02}, Train loss = {loss:.4f}")

        # Validation
        with t.inference_mode():
            # Calculate accuracy on validation set
            accuracy = sum(trainer.validation_step(batch) for batch in val_dataloader) / (args.valset_size * args.seq_len)
            # Log variables, update progress bar
            if args.use_wandb: wandb.log({"test_accuracy": accuracy})
            progress_bar.set_description(f"Epoch {epoch:02}, Train loss = {loss:.4f}, Accuracy: {accuracy:.4f}, LR = {lr:.2e}")

        # If validation accuracy is the best it's been so far, save this model
        if best_accuracy is None or (accuracy > best_accuracy) or (accuracy >= best_accuracy and loss <= best_loss):
            best_epoch = epoch
            best_accuracy = accuracy
            best_loss = loss
            best_model = deepcopy(trainer.model)

    if args.use_wandb:
        wandb.finish()

    print(f"Returning best model from epoch {best_epoch}/{args.epochs}, with accuracy {best_accuracy:.3f}")
    return best_model
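
Finally, a training run could be launched roughly as follows, assuming training.py is importable from the same package as dataset.py and model.py. Every field name comes from the TrainArgs dataclass above, but the specific values are placeholders; the hyperparameters actually used to produce cumsum_model.pt are not shown in this commit.

from monthly_algorithmic_problems.november23_cumsum.training import TrainArgs, train

# Placeholder values for every TrainArgs field; adjust to taste
args = TrainArgs(
    max_value=5, seq_len=10,
    trainset_size=100_000, valset_size=5_000,
    epochs=10, batch_size=256,
    lr_start=1e-3, lr_end=1e-4, weight_decay=0.0,
    seed=42,
    d_model=24, d_head=12, n_layers=1, n_heads=1, d_mlp=8,
    normalization_type="LN",
    use_wandb=False,
    device="cpu",
)
best_model = train(args)   # returns the copy of the model with the best validation accuracy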