Make sure you have CUDA installed.
pip install -r prerequirements.txt
pip install -r requirements.txt
- place a higher loss weight on the last token (the final answer) to possibly speed up training
- smaller mamba (fewer layers) to train faster
- train transformer (at least try to)
- test whether the tasks are actually fully sequential or there are heuristics for solving them
- check that with the AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b') tokenizer my task still tokenizes to 1+2*num_steps tokens
Mamba trained: the 130m model stripped down to 6 layers, so about 30m.
- for interventions into the mamba architecture, mamba-minimal will be the easiest to work with
- it's not efficient though, so later modify the code in the original repo, and then sanity check that it behaves the same as the modified mamba-minimal
cd /workspace && git clone <REPO_URL> && cd sneaky-mamba && pip install -r requirements.txt
nohup sneaky_mamba/ EXP_NAME > LOG_NAME.log 2>&1 &
/bin/sh -c "cd /workspace && git clone <REPO_URL> && cd sneaky-mamba && pip install -r requirements.txt && tail -f /dev/null"
old loss:
def compute_loss(self, model, inputs, return_outputs=False):
    """Return the next-token cross-entropy loss averaged over all positions.

    Args:
        model: Callable that takes the ``input_ids`` tensor and returns an
            object exposing ``.logits``; the logits are indexed here as
            (batch, sequence, vocabulary).
        inputs: Mapping that contains ``"input_ids"``. The key is popped,
            so the mapping is modified in place.
        return_outputs: Accepted but not used; only the loss is returned.

    Returns:
        Scalar tensor with the mean cross-entropy between the logits at
        position ``t`` and the token at position ``t + 1``.
    """
    input_ids = inputs.pop("input_ids")
    lm_logits = model(input_ids).logits

    # The input tokens double as targets; keep them on the logits' device.
    # NOTE(review): the right-hand side was missing from the original
    # snippet and is restored as the usual causal-LM target -- confirm.
    labels = input_ids.to(lm_logits.device)

    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = lm_logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss()
    lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
    return lm_loss
fancy loss, with separate loss for final answer:
def compute_loss(self, model, inputs, return_outputs=False):
    """Return a weighted mix of the reasoning loss and the final-answer loss.

    The task part of every example (everything up to and including the
    "answer" token) is excluded. The remaining targets are split into the
    reasoning tokens and the last token (the final answer), each with its
    own cross-entropy loss.

    Reads two names defined outside this function: ``answer_token`` (the id
    compared against the labels) and ``final_answer_loss_contribution`` (the
    weight of the final-answer loss; the reasoning loss gets ``1 -`` that).

    NOTE(review): the slicing and the loss arguments were missing from the
    original snippet and are rebuilt from its comments -- confirm the exact
    boundaries against the training code.

    Args:
        model: Callable that takes the ``input_ids`` tensor and returns an
            object exposing ``.logits``; the logits are indexed here as
            (batch, sequence, vocabulary).
        inputs: Mapping that contains ``"input_ids"``. The key is popped,
            so the mapping is modified in place.
        return_outputs: Accepted but not used; only the loss is returned.

    Returns:
        Scalar tensor: ``reasoning_loss * (1 - w) + final_answer_loss * w``
        with ``w = final_answer_loss_contribution``.
    """
    input_ids = inputs.pop("input_ids")

    # batched generation
    lm_logits = model(input_ids).logits
    labels = input_ids.to(lm_logits.device)
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = lm_logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()

    # cut out the task part (the part before "answer")
    reasoning_shift_logits = []
    reasoning_labels = []
    final_answer_shift_logits = []
    final_answer_labels = []
    for ex_shift_logits, ex_labels in zip(shift_logits, labels):
        # find the index of the "answer" token; int() only succeeds when
        # the token occurs exactly once in the example
        answer_index = torch.where(ex_labels == answer_token)[0]
        answer_index = int(answer_index)

        # cut out the task part: keep the targets after "answer" and
        # before the last token
        reasoning_shift_logits.append(ex_shift_logits[answer_index + 1 : -1])
        reasoning_labels.append(ex_labels[answer_index + 1 : -1])

        # loss for the final answer will be calculated separately
        final_answer_shift_logits.append(ex_shift_logits[-1])
        final_answer_labels.append(ex_labels[-1])

    # calculate loss only for the tokens after "answer"
    loss_fct = torch.nn.CrossEntropyLoss()
    reasoning_lm_loss = loss_fct(
        torch.cat(reasoning_shift_logits), torch.cat(reasoning_labels)
    )
    final_answer_lm_loss = loss_fct(
        torch.stack(final_answer_shift_logits), torch.stack(final_answer_labels)
    )
    return reasoning_lm_loss * (1 - final_answer_loss_contribution) + final_answer_lm_loss * final_answer_loss_contribution