From 601b02ae1ad12e77c5e149031dac8cba84da7cc2 Mon Sep 17 00:00:00 2001
From: Hanlin Tang
Date: Tue, 2 May 2023 13:24:16 -0700
Subject: [PATCH] Fix `ComposerHFCausalLM` initialization (#6)

---
 llmfoundry/models/hf/hf_causal_lm.py | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index c83819fa69..aa9c34074d 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -3,20 +3,22 @@
 
 """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""
 
-from typing import Optional
+from typing import Union
 
 from composer.metrics.nlp import (InContextLearningLMAccuracy,
                                   InContextLearningMultipleChoiceAccuracy,
                                   LanguageCrossEntropy, LanguagePerplexity)
 from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
 from llmfoundry.models.utils import init_empty_weights
 
 __all__ = ['ComposerHFCausalLM']
 
+Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+
 
 class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
     """Configures a :class:`.HuggingFaceModel` around a Causal LM.
@@ -40,21 +42,11 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
                 to validation metrics. Default: ``False``.
     """
 
-    def __init__(self,
-                 om_model_config: DictConfig,
-                 om_tokenizer_config: Optional[DictConfig] = None):
+    def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
         config = AutoConfig.from_pretrained(
             om_model_config.pretrained_model_name_or_path,
             **om_model_config.get('config_overrides', {}))
 
-        resolved_om_tokenizer_config = om.to_container(om_tokenizer_config,
-                                                       resolve=True)
-        tokenizer_kwargs = resolved_om_tokenizer_config.get(  # type: ignore
-            'kwargs', {})
-        tokenizer_name = resolved_om_tokenizer_config['name']  # type: ignore
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
-                                                  **tokenizer_kwargs)
-
         train_metrics = [
             LanguageCrossEntropy(len(tokenizer)),
            LanguagePerplexity(len(tokenizer)),
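
Usage sketch (editor's note, not part of the patch): after this change the caller
builds the tokenizer and passes it to `ComposerHFCausalLM` directly, instead of
supplying an omegaconf tokenizer config that the constructor resolves itself. A
minimal illustration of the new calling convention follows; the 'gpt2' model name
and the config keys shown are illustrative assumptions, and the full constructor
(only partially visible in this hunk) may require additional config fields.

    from omegaconf import OmegaConf
    from transformers import AutoTokenizer

    from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

    # Hypothetical model config; only the keys read in the hunk above are shown.
    om_model_config = OmegaConf.create({
        'pretrained_model_name_or_path': 'gpt2',  # illustrative model choice
        'config_overrides': {},
    })

    # The tokenizer is now constructed by the caller and passed in directly,
    # replacing the removed om_tokenizer_config plumbing.
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    model = ComposerHFCausalLM(om_model_config, tokenizer)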