Skip to content

Commit

Permalink
Fixed a bug that manifested only occasionally in convert_validation_b…
Browse files Browse the repository at this point in the history
…atch(), if all records in a batch happen to be shorter than the max block length
  • Loading branch information
gstrazds committed Apr 13, 2023
1 parent 2e208c9 commit 4bcbad1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
12 changes: 8 additions & 4 deletions conf/overrides/gengo2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ data:

model:
# block_size: 128 # aka seq_len or 'attention window size'
block_size: 992 # spatial extent of the model for its context
# block_size: 992 # spatial extent of the model for its context
block_size: 1024 # spatial extent of the model for its context
# n_layers: 8
n_layers: 17
# n_layers: 17
n_layers: 15
n_heads: 8
d_embd: 384
# d_embd: 384
d_embd: 512

trainer:
batch_size: 8
Expand All @@ -44,4 +47,5 @@ eval:
#checkpoint: ${cwd_path}/saved_models/2023-04-10-20-32-0/mingpt:gpt:pthru-epoch=5-step=6374-val_acc=0.844-val_loss=0.227.ckpt
#checkpoint: ${cwd_path}/saved_models/2023-04-11-X/mingpt:gpt:pthru-epoch=7-step=8182-val_acc=0.674-val_loss=0.220.ckpt
#checkpoint: ${cwd_path}/saved_models/2023-04-11-X/mingpt:gpt:pthru-epoch=4-step=5205-val_acc=0.660-val_loss=0.212.ckpt
checkpoint: ${cwd_path}/saved_models/2023-04-11-20-00-Y/mingpt:gpt:pthru-epoch=4-step=4569-val_acc=0.758-val_loss=0.230.ckpt
#checkpoint: ${cwd_path}/saved_models/2023-04-11-20-00-Y/mingpt:gpt:pthru-epoch=4-step=4569-val_acc=0.758-val_loss=0.230.ckpt
checkpoint: ${cwd_path}/saved_models/2023-04-13-11-44/mingpt:gpt:pthru-epoch=6-step=6801-val_acc=0.842-val_loss=0.242.ckpt
4 changes: 3 additions & 1 deletion mingpt/pl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,9 @@ def convert_validation_batch(batch_x, batch_y, batch_cmd_pos, batch_cmd_len, max

cmd_len = batch_cmd_len[i]
if truncate_to_max: # max_block_size:
prefix_len = block_size - cmd_len
if batch_cmd_pos[i] < block_size-1 or block_size > batch_x.shape[1]:
print(f"block_size:{block_size} cmd_len:{cmd_len} max_block_size:{max_block_size} cmd_pos:{batch_cmd_pos[i]} {batch_x.shape} {new_x.shape}")
prefix_len = batch_x.shape[1] - cmd_len
new_x[i, 0:prefix_len + 1] = batch_x[i, cmd_len - 1:]
new_y[i, 0:prefix_len] = batch_x[i, cmd_len:]
_cmd_pos = batch_cmd_pos[i] - (cmd_len - 1)
Expand Down

0 comments on commit 4bcbad1

Please sign in to comment.