add new algo problem
callummcdougall committed Nov 2, 2023
1 parent 9a20842 commit 3b8c912
Showing 18 changed files with 3,507 additions and 12 deletions.
Binary file not shown.
@@ -0,0 +1,68 @@
import torch as t
from torch.utils.data import Dataset



class CumsumDataset(Dataset):

def __init__(self, size: int, max_value: int, seq_len: int, seed: int = 42, p: float = 0.7):
        '''
        Dataset for the cumulative sum problem. Contains tokens and labels.
        Tokens are integers in [-max_value, +max_value] inclusive. Labels are 2 if
        the cumulative sum up to that position is strictly positive, 1 if it's zero,
        and 0 if it's strictly negative.
        The argument `p` is the probability that each new value will have the opposite
        sign to the running cumulative sum, which makes zero cumulative sums a bit
        more common than they would be with uniform sampling.
        '''

self.vocab = [f"{i:+}" for i in range(-max_value, max_value+1)]
self.vocab_out = ["neg", "zero", "pos"]
self.size = size
self.max_value = max_value
t.manual_seed(seed) # for reproducible results

        # Generate our sequences and labels
self.toks = t.randint(-max_value, max_value+1, (size, seq_len))

for seq_pos in range(1, seq_len):
prev_cumsum = self.toks[:, :seq_pos].sum(dim=1)
            # Get random booleans, which are True with probability p
            p_rand_seed = t.rand(size) < p
            # If the running cumsum is positive, make the next value negative with
            # probability p; if it's zero or negative, make it non-negative with probability p
            next_value_is_negative = t.where(prev_cumsum > 0, p_rand_seed, ~p_rand_seed)
            # Draw new values in [0, max_value], then negate the ones flagged as negative
            new_values = t.randint(0, max_value+1, (size,))
            new_values = t.where(next_value_is_negative, -new_values, new_values)
# Add the new values to self.toks
self.toks[:, seq_pos] = new_values

# Create labels: 0 if neg, 1 if zero, 2 if pos
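        # e.g. toks = [+2, -2, +1]  ->  cumsum = [2, 0, 1]  ->  labels = [2, 1, 2]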
toks_pos = (self.toks.cumsum(dim=1) > 0)
toks_neg = (self.toks.cumsum(dim=1) < 0)
self.labels = t.where(toks_pos, 2, t.where(toks_neg, 0, 1))

self.str_toks = [
[f"{tok:+}" for tok in toks]
for toks in self.toks
]
str_labels_dict = {0: "neg", 1: "zero", 2: "pos"}
self.str_labels = [
[str_labels_dict[label.item()] for label in labels]
for labels in self.labels
]

def __getitem__(self, index):
return self.toks[index], self.labels[index]

def __len__(self):
return self.size

def to(self, device: str):
self.toks = self.toks.to(device)
self.labels = self.labels.to(device)
return self


# %%
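# Hedged usage sketch (not part of the original commit): build a tiny CumsumDataset
# and sanity-check one sequence by hand. All argument values here are illustrative.
if __name__ == "__main__":
    dataset = CumsumDataset(size=4, max_value=5, seq_len=6, seed=0)

    toks, labels = dataset[0]       # __getitem__ returns (tokens, labels) for one sequence
    print(dataset.str_toks[0])      # e.g. ['+3', '-4', ...] (exact values depend on the seed)
    print(dataset.str_labels[0])    # "pos" / "zero" / "neg" for each prefix sum

    # Each label classifies the cumulative sum up to and including that position
    expected = t.where(toks.cumsum(0) > 0, 2, t.where(toks.cumsum(0) < 0, 0, 1))
    assert (labels == expected).all()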

@@ -0,0 +1,54 @@
import torch as t
import numpy as np
from typing import Optional
from transformer_lens import HookedTransformer, HookedTransformerConfig

def create_model(
max_value: int,
seq_len: int,
seed: int,
d_model: int,
d_head: int,
n_layers: int,
n_heads: int,
d_mlp: Optional[int],
normalization_type: Optional[str],
device: str = "cuda",
**kwargs # ignore other kwargs
) -> HookedTransformer:

t.manual_seed(seed)
np.random.seed(seed)

attn_only = (d_mlp is None)

cfg = HookedTransformerConfig(

n_layers=n_layers,
n_ctx=seq_len,
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
d_mlp=d_mlp,
attn_only=attn_only,
act_fn="relu",

        # Tokens cover [-max_value, +max_value] inclusive, so d_vocab = 2*max_value + 1
        d_vocab=2*max_value+1,
        # There are only 3 output classes (neg / zero / pos)
        d_vocab_out=3,

# it's a small transformer so may as well use these hooks
use_attn_result=True,
use_split_qkv_input=True,
use_hook_tokens=True,

# Layernorm makes things way more accurate, even though it makes
# mech interp a little more annoying!
normalization_type=normalization_type,

device=device,
)

model = HookedTransformer(cfg)
return model
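

# Hedged usage sketch (not part of the original commit): instantiate a small
# attention-only model for this task. Every hyperparameter value below is an
# illustrative assumption, not the configuration used for the trained model.
if __name__ == "__main__":
    model = create_model(
        max_value=5,            # tokens cover [-5, +5], so d_vocab = 11
        seq_len=20,
        seed=42,
        d_model=24,
        d_head=12,
        n_layers=1,
        n_heads=2,
        d_mlp=None,             # None => attn_only=True
        normalization_type="LN",
        device="cpu",
    )
    print(model.cfg.d_vocab, model.cfg.d_vocab_out)  # 11, 3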
@@ -0,0 +1,133 @@
from dataclasses import dataclass
from typing import Tuple, Optional
from tqdm import tqdm
import torch as t
from torch import Tensor
import torch.nn.functional as F
from copy import deepcopy
from torch.utils.data import DataLoader
import einops
import wandb

from monthly_algorithmic_problems.november23_cumsum.dataset import CumsumDataset
from monthly_algorithmic_problems.november23_cumsum.model import create_model


@dataclass
class TrainArgs:
max_value: int
seq_len: int
trainset_size: int
valset_size: int
epochs: int
batch_size: int
lr_start: float
lr_end: float
weight_decay: float
seed: int
d_model: int
d_head: int
n_layers: int
n_heads: int
d_mlp: int
normalization_type: Optional[str]
use_wandb: bool
device: str


class Trainer:
def __init__(self, args: TrainArgs):
self.args = args
self.model = create_model(**args.__dict__)
if args.use_wandb:
wandb.init(project="sum-model")
wandb.watch(self.model)

    def training_step(self, batch: Tuple[Tensor, Tensor]) -> t.Tensor:
        logits, labels = self._shared_train_validation_step(batch)
        logprobs = logits.log_softmax(-1)
        # Flatten the batch & seq dims so every position is scored independently.
        # nll_loss expects log-probabilities, which is exactly what log_softmax gives
        # (passing logprobs to F.cross_entropy would apply log_softmax a second time).
        loss = F.nll_loss(
            einops.rearrange(logprobs, "batch seq vocab_out -> (batch seq) vocab_out"),
            einops.rearrange(labels, "batch seq -> (batch seq)"),
        )
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor]) -> float:
        logits, labels = self._shared_train_validation_step(batch)
        # Return the raw count of correct predictions; train() divides by the total
        # number of (sequence, position) pairs to get an accuracy
        correct = (logits.argmax(-1) == labels).float().sum().item()
        return correct

def _shared_train_validation_step(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
toks, labels = batch
toks = toks.to(self.args.device)
labels = labels.to(self.args.device)
logits = self.model(toks)
return logits, labels

def train_dataloader(self, seed: int):
trainset = CumsumDataset(size=self.args.trainset_size, max_value=self.args.max_value, seq_len=self.args.seq_len, seed=seed)
return DataLoader(trainset, batch_size=self.args.batch_size, shuffle=True)

def val_dataloader(self, seed: int):
valset = CumsumDataset(size=self.args.valset_size, max_value=self.args.max_value, seq_len=self.args.seq_len, seed=seed)
return DataLoader(valset, batch_size=self.args.batch_size, shuffle=False)

def configure_optimizers(self):
optimizer = t.optim.Adam(self.model.parameters(), lr=self.args.lr_start, weight_decay=self.args.weight_decay)
return optimizer


def train(args: TrainArgs):

trainer = Trainer(args)
optimizer = trainer.configure_optimizers()

train_dataloader = trainer.train_dataloader(seed=args.seed)
val_dataloader = trainer.val_dataloader(seed=args.seed+1)

# Save the best model (based on validation accuracy)
best_model = deepcopy(trainer.model)
best_epoch = None
best_accuracy = None
best_loss = None

for epoch in range(args.epochs):

# Update learning rate (linear decay)
lr = args.lr_start + (args.lr_end - args.lr_start) * ((epoch + 1) / args.epochs)
optimizer.param_groups[0]["lr"] = lr

progress_bar = tqdm(total=args.trainset_size//args.batch_size)

# Training
for batch in train_dataloader:
# Optimization step on training set
optimizer.zero_grad()
loss = trainer.training_step(batch)
loss.backward()
optimizer.step()
# Log variables, update progress bar
if args.use_wandb: wandb.log({"training_loss": loss})
progress_bar.update()
            progress_bar.set_description(f"Epoch {epoch:02}, Train loss = {loss:.4f}")

# Validation
with t.inference_mode():
# Calculate accuracy on validation set
accuracy = sum(trainer.validation_step(batch) for batch in val_dataloader) / (args.valset_size * args.seq_len)
# Log variables, update progress bar
if args.use_wandb: wandb.log({"test_accuracy": accuracy})
            progress_bar.set_description(f"Epoch {epoch:02}, Train loss = {loss:.4f}, Accuracy: {accuracy:.4f}, LR = {lr:.2e}")

# If validation accuracy is the best it's been so far, save this model
if best_accuracy is None or (accuracy > best_accuracy) or (accuracy >= best_accuracy and loss <= best_loss):
best_epoch = epoch
best_accuracy = accuracy
best_loss = loss
best_model = deepcopy(trainer.model)

if args.use_wandb:
wandb.finish()

print(f"Returning best model from epoch {best_epoch}/{args.epochs}, with accuracy {best_accuracy:.3f}")
return best_model
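

# Hedged usage sketch (not part of the original commit): one way to call train().
# Every value below is an illustrative assumption; only the field names come from
# the TrainArgs dataclass above.
if __name__ == "__main__":
    args = TrainArgs(
        max_value=5,
        seq_len=20,
        trainset_size=10_000,
        valset_size=1_000,
        epochs=5,
        batch_size=256,
        lr_start=1e-3,
        lr_end=1e-4,
        weight_decay=0.0,
        seed=42,
        d_model=24,
        d_head=12,
        n_layers=1,
        n_heads=2,
        d_mlp=48,
        normalization_type="LN",
        use_wandb=False,
        device="cuda" if t.cuda.is_available() else "cpu",
    )
    best_model = train(args)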