Using FormattedCheckpointFiles in configs... round 2 (pytorch#2167)
SalmanMohammadi authored Dec 18, 2024
1 parent c0b2cbd commit 3518492
Showing 15 changed files with 27 additions and 15 deletions.
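Background for the change: as the new test below notes, YAML 1.1 parsers treat an unquoted scalar with a leading zero as an octal integer, so a value like max_filename: 00015 silently loads as the number 13 rather than the zero-padded string the checkpointer needs to format filenames. Quoting the value keeps it a string. A minimal sketch of the failure mode using PyYAML (illustrative only; the exact coercion depends on the YAML loader your config system uses):

    import yaml  # PyYAML follows YAML 1.1 implicit typing

    # Unquoted: an all-octal-digit scalar with a leading zero parses as an octal int.
    print(yaml.safe_load("max_filename: 00015"))    # {'max_filename': 13}

    # Quoted: the zero-padded string survives parsing intact.
    print(yaml.safe_load('max_filename: "00015"'))  # {'max_filename': '00015'}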
recipes/configs/llama2/70B_lora.yaml (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-2-70b-hf
   checkpoint_files:
     filename_format: pytorch_model-{}-of-{}.bin
-    max_filename: 00015
+    max_filename: "00015"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA2
recipes/configs/llama2/70B_qlora.yaml (2 changes: 1 addition & 1 deletion)

@@ -36,7 +36,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-2-70b-hf
   checkpoint_files:
     filename_format: pytorch_model-{}-of-{}.bin
-    max_filename: 00015
+    max_filename: "00015"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA2
recipes/configs/llama3/70B_full.yaml (2 changes: 1 addition & 1 deletion)

@@ -41,7 +41,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3/70B_lora.yaml (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3_1/405B_qlora.yaml (2 changes: 1 addition & 1 deletion)

@@ -34,7 +34,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00191
+    max_filename: "00191"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3_1/70B_full.yaml (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3_1/70B_lora.yaml (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@ checkpointer:
   checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3_3/70B_full.yaml (2 changes: 1 addition & 1 deletion)

@@ -40,7 +40,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3_3/70B_lora.yaml (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/llama3_3/70B_qlora.yaml (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@ checkpointer:
   checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00030
+    max_filename: "00030"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: LLAMA3
recipes/configs/qwen2_5/14B_lora_single_device.yaml (2 changes: 1 addition & 1 deletion)

@@ -39,7 +39,7 @@ checkpointer:
   checkpoint_dir: /tmp/Qwen2_5-14B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00008
+    max_filename: "00008"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: QWEN2
recipes/configs/qwen2_5/32B_lora.yaml (2 changes: 1 addition & 1 deletion)

@@ -37,7 +37,7 @@ checkpointer:
   checkpoint_dir: /tmp/Qwen2_5-32B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00017
+    max_filename: "00017"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: QWEN2
recipes/configs/qwen2_5/72B_lora.yaml (2 changes: 1 addition & 1 deletion)

@@ -37,7 +37,7 @@ checkpointer:
   checkpoint_dir: /tmp/Qwen2_5-72B-Instruct
   checkpoint_files:
     filename_format: model-{}-of-{}.safetensors
-    max_filename: 00037
+    max_filename: "00037"
   recipe_checkpoint: null
   output_dir: ${output_dir}
   model_type: QWEN2
@@ -196,11 +196,19 @@ def expected_filenames(self):
             "model_0012_of_0012.pt",
         ]

-    def test_invalid_to_dict(self):
+    def test_invalid_from_dict_no_filename_format(self):
         invalid_dict = {"bad_key": "model_{}_of_{}.pt", "max_filename": "0005"}
         with pytest.raises(ValueError, match="Must pass 'filename_format'"):
             _ = FormattedCheckpointFiles.from_dict(invalid_dict)

+    def test_invalid_from_dict_int_max_filename(self):
+        # 0o00025 is an octal literal. We use this odd-looking value in this test
+        # because YAML treats numbers with a leading 0 as octal, so it is a
+        # realistic example of `from_dict` being called with an invalid config.
+        invalid_dict = {"filename_format": "model_{}_of_{}.pt", "max_filename": 0o00025}
+        with pytest.raises(ValueError, match="`max_filename` must be a string"):
+            _ = FormattedCheckpointFiles.from_dict(invalid_dict)
+
     def test_invalid_filename_format(self):
         formatted_string = "invalid_format_{}.pt"
         formatted_file_dict = {
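To make the octal trap in that test concrete: 0o00025 is the integer 21, which is exactly what an unquoted 00025 would become after YAML 1.1 parsing, and no amount of re-padding recovers the intended filename index. A quick illustration (plain Python, values chosen to match the test above):

    # The octal literal from the test is not twenty-five:
    assert 0o00025 == 21

    # Even zero-padding the parsed int back to five digits yields the
    # wrong shard count, hence the hard requirement for a string.
    assert f"{0o00025:05d}" == "00021"  # not "00025"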
torchtune/training/checkpointing/_utils.py (6 changes: 5 additions & 1 deletion)

@@ -142,6 +142,10 @@ def from_dict(cls, d: dict) -> "FormattedCheckpointFiles":
             raise ValueError(
                 "Must pass 'filename_format' and 'max_filename' keys to generate checkpoint filenames"
             )
+        if not isinstance(d["max_filename"], str):
+            raise ValueError(
+                f"`max_filename` must be a string, but found {type(d['max_filename'])} instead."
+            )
         return cls(
             filename_format=d["filename_format"],
             max_filename=d["max_filename"],

@@ -527,7 +531,7 @@ def validate_checkpoint_files(
     # e.g.
     # checkpoint_files:
     #   filename_format: model-{}-of-{}.safetensors
-    #   max_filename: 00191
+    #   max_filename: "00191"
     # becomes checkpoint_files = [model-00001-of-00191.safetensors, model-00002-of-00191,..]
     if not isinstance(checkpoint_files, List):
         # TODO: this can be a function instead of a class
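For context, the config pattern described in the comment above expands into the full shard list at load time. A sketch of the round trip, assuming build_checkpoint_filenames() as the expansion method on FormattedCheckpointFiles (the name follows torchtune's checkpointing utilities; treat the exact signature as an assumption):

    from torchtune.training.checkpointing._utils import FormattedCheckpointFiles

    formatted_files = FormattedCheckpointFiles.from_dict(
        {"filename_format": "model-{}-of-{}.safetensors", "max_filename": "00191"}
    )
    # ["model-00001-of-00191.safetensors", ..., "model-00191-of-00191.safetensors"]
    filenames = formatted_files.build_checkpoint_filenames()

    # With this commit, a non-string max_filename (e.g. an unquoted octal
    # scalar that YAML coerced to an int) now fails fast in from_dict:
    FormattedCheckpointFiles.from_dict(
        {"filename_format": "model-{}-of-{}.safetensors", "max_filename": 0o00025}
    )  # raises ValueError: `max_filename` must be a string, ...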
