Commit 601b02a
Fix ComposerHFCausalLM initialization (mosaicml#6)
hanlint authored May 2, 2023
1 parent d25e1ab commit 601b02a
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -3,20 +3,22 @@

 """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

-from typing import Optional
+from typing import Union

 from composer.metrics.nlp import (InContextLearningLMAccuracy,
                                   InContextLearningMultipleChoiceAccuracy,
                                   LanguageCrossEntropy, LanguagePerplexity)
 from omegaconf import DictConfig
-from omegaconf import OmegaConf as om
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)

 from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
 from llmfoundry.models.utils import init_empty_weights

 __all__ = ['ComposerHFCausalLM']

+Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+

 class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
     """Configures a :class:`.HuggingFaceModel` around a Causal LM.
@@ -40,21 +42,11 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
         to validation metrics. Default: ``False``.
     """

-    def __init__(self,
-                 om_model_config: DictConfig,
-                 om_tokenizer_config: Optional[DictConfig] = None):
+    def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
         config = AutoConfig.from_pretrained(
             om_model_config.pretrained_model_name_or_path,
             **om_model_config.get('config_overrides', {}))

-        resolved_om_tokenizer_config = om.to_container(om_tokenizer_config,
-                                                       resolve=True)
-        tokenizer_kwargs = resolved_om_tokenizer_config.get(  # type: ignore
-            'kwargs', {})
-        tokenizer_name = resolved_om_tokenizer_config['name']  # type: ignore
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name,
-                                                  **tokenizer_kwargs)
-
         train_metrics = [
             LanguageCrossEntropy(len(tokenizer)),
             LanguagePerplexity(len(tokenizer)),
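The substance of the change: ComposerHFCausalLM no longer builds its own tokenizer from an omegaconf tokenizer config; the caller now constructs the tokenizer and passes it in directly. Below is a minimal sketch of the new call site under this commit, assuming a hypothetical 'gpt2' checkpoint; any config fields beyond pretrained_model_name_or_path are illustrative assumptions, not taken from the diff.

# Sketch of the new calling convention: build the tokenizer yourself and
# hand it to the wrapper, instead of passing an om_tokenizer_config.
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM

# Hypothetical model config: the diff shows the wrapper reading
# pretrained_model_name_or_path and optional config_overrides from it;
# any other fields it may consult are omitted here.
model_cfg = OmegaConf.create(
    {'pretrained_model_name_or_path': 'gpt2'})  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained('gpt2')

model = ComposerHFCausalLM(om_model_config=model_cfg, tokenizer=tokenizer)

Judging from the diff, one practical effect is that a single tokenizer instance can now be shared between the model wrapper and the rest of the training setup (e.g. dataloaders), rather than being constructed separately inside the model.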
