Add ValueError in evaluation script if load_path checkpoint is not specified in config for mpt_causal_lm's (mosaicml#535)
j316chuck authored Sep 5, 2023
1 parent 186dd19 commit 3be3d24
Showing 6 changed files with 111 additions and 48 deletions.
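In effect, evaluating an mpt_causal_lm without a checkpoint now fails fast instead of silently scoring a randomly initialized model. A minimal sketch of the guard's behavior (standalone, not part of the diff; names mirror scripts/eval/eval.py below):

from omegaconf import OmegaConf

# A bare `load_path:` entry in YAML parses to None, so the guard fires.
model_cfg = OmegaConf.create("""
model:
  name: mpt_causal_lm
load_path:
""")

load_path = model_cfg.get('load_path', None)
if model_cfg.model.name == 'mpt_causal_lm' and load_path is None:
    raise ValueError('MPT causal LMs require a load_path to the checkpoint'
                     ' for model evaluation.')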
2 changes: 1 addition & 1 deletion mcli/mcli-1b-eval.yaml
@@ -43,7 +43,7 @@ parameters:
      attn_config:
        attn_impl: triton

-    load_path: # Add your (optional) Composer checkpoint path here!
+    load_path: # Add your (non-optional) Composer checkpoint path here!

  device_eval_batch_size: 4
  precision: amp_fp16
6 changes: 6 additions & 0 deletions scripts/eval/eval.py
@@ -138,6 +138,12 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
        [t.name for t in eval_gauntlet_callback.categories])

    load_path = model_cfg.get('load_path', None)
+    if model_cfg.model.name == 'mpt_causal_lm' and load_path is None:
+        raise ValueError(
+            'MPT causal LMs require a load_path to the checkpoint for model evaluation.'
+            +
+            ' Please check your yaml and the model_cfg to ensure that load_path is set.'
+        )

    assert composer_model is not None
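Worth spelling out why a single None check suffices here: OmegaConf maps both a missing load_path key and an empty load_path: entry to None. A minimal sketch (not from the commit) illustrating that equivalence:

from omegaconf import OmegaConf

missing = OmegaConf.create({'model': {'name': 'mpt_causal_lm'}})
empty = OmegaConf.create('model: {name: mpt_causal_lm}\nload_path:')

for cfg in (missing, empty):
    # .get() returns the default for an absent key, and a bare YAML
    # `load_path:` parses to null/None, so both cases trip the new error.
    assert cfg.get('load_path', None) is None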
4 changes: 2 additions & 2 deletions scripts/eval/yamls/mpt_eval.yaml
@@ -24,7 +24,7 @@ models:
    attn_config:
      attn_impl: triton

-  load_path: # Add your (optional) Composer checkpoint path here!
+  load_path: # Add your non-optional Composer checkpoint path here! (must not be empty)

device_eval_batch_size: 16

@@ -38,7 +38,7 @@ fsdp_config:
icl_tasks:
-
  label: jeopardy
-  dataset_uri: eval/local_data/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI
+  dataset_uri: eval/local_data/world_knowledge/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI
  num_fewshot: [0]
  icl_task_type: language_modeling
  continuation_delimiter: "\nAnswer: " # this separates questions from answers
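Note that dataset_uri values in these YAMLs appear to be relative to the scripts/ directory; the test fixture added in tests/test_eval_inputs.py below joins them onto the checkout that way. A small sketch of that resolution (repo_dir is a hypothetical checkout path):

import os

repo_dir = '/path/to/llm-foundry'  # hypothetical checkout location
dataset_uri = 'eval/local_data/world_knowledge/jeopardy_all.jsonl'
# Mirrors the join done in the TestMPTEvalYAMLInputs fixture below.
full_path = os.path.join(repo_dir, 'scripts', dataset_uri)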
40 changes: 40 additions & 0 deletions scripts/eval/yamls/test_eval.yaml
@@ -0,0 +1,40 @@
max_seq_len: 128
seed: 1
precision: fp32
models:
- model_name: tiny_mpt
  model:
    name: mpt_causal_lm
    init_device: meta
    d_model: 128
    n_heads: 2
    n_layers: 2
    expansion_ratio: 4
    max_seq_len: ${max_seq_len}
    vocab_size: 50368
    attn_config:
      attn_impl: torch
    loss_fn: torch_crossentropy
  # Tokenizer
  tokenizer:
    name: EleutherAI/gpt-neox-20b
    kwargs:
      model_max_length: ${max_seq_len}

device_eval_batch_size: 4
icl_subset_num_batches: 1
icl_tasks:
- label: lambada_openai
  dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl
  num_fewshot: [0]
  icl_task_type: language_modeling
eval_gauntlet:
  weighting: EQUAL
  subtract_random_baseline: true
  rescale_accuracy: true
  categories:
  - name: language_understanding_lite
    benchmarks:
    - name: lambada_openai
      num_fewshot: 0
      random_baseline: 0.0
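The ${max_seq_len} entries are OmegaConf interpolations resolved against the top-level key, which keeps the model and tokenizer lengths in sync from a single value. A quick sketch of loading this file (assuming it runs from the repo root):

from omegaconf import OmegaConf

cfg = OmegaConf.load('scripts/eval/yamls/test_eval.yaml')
# Both interpolations resolve to the single top-level max_seq_len: 128.
assert cfg.models[0].model.max_seq_len == 128
assert cfg.models[0].tokenizer.kwargs.model_max_length == 128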
75 changes: 31 additions & 44 deletions tests/test_eval.py
@@ -7,6 +7,10 @@

import omegaconf as om
import pytest
+from composer import Trainer
+
+from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.utils import build_tokenizer

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -26,51 +30,34 @@ def set_correct_cwd():
        os.chdir('..')


-def test_icl_eval(capfd: Any):
-    test_cfg = om.OmegaConf.create("""
-        max_seq_len: 1024
-        seed: 1
-        precision: fp32
-        models:
-        -
-          model_name: tiny_mpt
-          model:
-            name: mpt_causal_lm
-            init_device: meta
-            d_model: 128
-            n_heads: 2
-            n_layers: 2
-            expansion_ratio: 4
-            max_seq_len: ${max_seq_len}
-            vocab_size: 50368
-            attn_config:
-              attn_impl: torch
-            loss_fn: torch_crossentropy
-          # Tokenizer
-          tokenizer:
-            name: EleutherAI/gpt-neox-20b
-            kwargs:
-              model_max_length: ${max_seq_len}
-
-        device_eval_batch_size: 4
-        icl_subset_num_batches: 1
-        icl_tasks:
-        -
-          label: lambada_openai
-          dataset_uri: eval/local_data/language_understanding/lambada_openai.jsonl
-          num_fewshot: [0]
-          icl_task_type: language_modeling
-        eval_gauntlet:
-          weighting: EQUAL
-          subtract_random_baseline: true
-          rescale_accuracy: true
-          categories:
-          - name: language_understanding_lite
-            benchmarks:
-            - name: lambada_openai
-              num_fewshot: 0
-              random_baseline: 0.0
-    """)
+@pytest.fixture()
+def mock_saved_model_path():
+    # load the eval and model config
+    with open('eval/yamls/test_eval.yaml', 'r', encoding='utf-8') as f:
+        eval_cfg = om.OmegaConf.load(f)
+    model_cfg = eval_cfg.models[0]
+    # set device to cpu
+    device = 'cpu'
+    model_cfg.model.init_device = device
+    # build tokenizer
+    tokenizer = build_tokenizer(model_cfg.tokenizer)
+    # build model
+    model = COMPOSER_MODEL_REGISTRY[model_cfg.model.name](model_cfg.model,
+                                                          tokenizer)
+    # create mocked save checkpoint
+    trainer = Trainer(model=model, device=device)
+    saved_model_path = os.path.join(os.getcwd(), 'test_model.pt')
+    trainer.save_checkpoint(saved_model_path)
+    yield saved_model_path
+
+    # clean up the mocked save checkpoint
+    os.remove(saved_model_path)
+
+
+def test_icl_eval(capfd: Any, mock_saved_model_path: Any):
+    with open('eval/yamls/test_eval.yaml', 'r', encoding='utf-8') as f:
+        test_cfg = om.OmegaConf.load(f)
+    test_cfg.models[0].load_path = mock_saved_model_path
    assert isinstance(test_cfg, om.DictConfig)
    main(test_cfg)
    out, _ = capfd.readouterr()
out, _ = capfd.readouterr()
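The mock_saved_model_path fixture relies on pytest's yield-fixture teardown: everything after the yield runs once the test finishes, so the checkpoint written by trainer.save_checkpoint is always cleaned up. A generic sketch of the pattern (not from the commit):

import os
import pytest

@pytest.fixture()
def scratch_checkpoint():
    path = os.path.join(os.getcwd(), 'scratch.pt')
    with open(path, 'wb') as f:  # stand-in for trainer.save_checkpoint(path)
        f.write(b'\x00')
    yield path                   # the test body runs here
    os.remove(path)              # teardown runs even if the test fails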
32 changes: 31 additions & 1 deletion tests/test_eval_inputs.py
@@ -17,7 +17,7 @@
from scripts.eval.eval import main # noqa: E402


-class TestEvalYAMLInputs:
+class TestHuggingFaceEvalYAMLInputs:
    """Validate and tests error handling for the input YAML file."""

    @pytest.fixture
@@ -72,3 +72,33 @@ def test_optional_mispelled_params_raise_warning(self,
            str(warning.message) for warning in warning_list)
        # restore configs.
        cfg = copy.deepcopy(old_cfg)
+
+
+class TestMPTEvalYAMLInputs:
+
+    @pytest.fixture
+    def cfg(self) -> DictConfig:
+        """Create YAML cfg fixture for testing purposes."""
+        conf_path: str = os.path.join(repo_dir,
+                                      'scripts/eval/yamls/mpt_eval.yaml')
+        with open(conf_path, 'r', encoding='utf-8') as config:
+            test_cfg = om.load(config)
+
+        test_cfg.icl_tasks[0].dataset_uri = os.path.join(
+            repo_dir, 'scripts', test_cfg.icl_tasks[0].dataset_uri)
+
+        # make tests use cpu initialized transformer models only
+        test_cfg.models[0].model.init_device = 'cpu'
+        test_cfg.models[0].model.attn_config.attn_impl = 'torch'
+        test_cfg.models[0].model.loss_fn = 'torch_crossentropy'
+        test_cfg.precision = 'fp32'
+        assert isinstance(test_cfg, DictConfig)
+        return test_cfg
+
+    def test_empty_load_path_raises_error(self, cfg: DictConfig) -> None:
+        """Check that empty load paths for mpt models raise an error."""
+        error_string = 'MPT causal LMs require a load_path to the checkpoint for model evaluation.' \
+            + ' Please check your yaml and the model_cfg to ensure that load_path is set.'
+        cfg.models[0].load_path = None
+        with pytest.raises(ValueError, match=error_string):
+            main(cfg)
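One caveat on the new test: pytest.raises treats match as a regular expression applied with re.search, so the '.' characters in error_string match any character. The test passes because '.' also matches a literal dot, but re.escape makes a literal intent explicit. A sketch (not from the commit):

import re
import pytest

def test_literal_match():
    msg = ('MPT causal LMs require a load_path to the checkpoint for model'
           ' evaluation. Please check your yaml and the model_cfg to ensure'
           ' that load_path is set.')
    # re.escape neutralizes the '.' metacharacters for an exact-text match.
    with pytest.raises(ValueError, match=re.escape(msg)):
        raise ValueError(msg)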
