Commit
make auto compatible
AkshitaB committed Jun 22, 2023
1 parent b5da6b4 commit 6935e5f
Showing 4 changed files with 55 additions and 287 deletions.
31 changes: 31 additions & 0 deletions hf_integration/add_hf_config_to_olmo_checkpoint.py
@@ -0,0 +1,31 @@
import argparse

from hf_integration.configuration_olmo import OLMoConfig
from olmo import Olmo


def write_config(checkpoint_dir: str):
    # save config as HF config
    # TODO: add logging
    model = Olmo.from_checkpoint(checkpoint_dir)
    config = OLMoConfig(**model.config.asdict())
    config.save_pretrained(checkpoint_dir)


def main():
    parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory, making it easier to load weights as HF models"
    )
    parser.add_argument(
        "--checkpoint-dir",
        help="Location of OLMo checkpoint.",
    )

    args = parser.parse_args()
    write_config(
        checkpoint_dir=args.checkpoint_dir,
    )


if __name__ == "__main__":
    main()
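
A minimal usage sketch of this new script (the checkpoint path is a hypothetical placeholder, not from this commit): after it runs, the checkpoint directory contains an HF-style config.json alongside the OLMo weights.

# Sketch only: calling write_config directly instead of via the CLI.
# "/path/to/olmo-checkpoint" is a hypothetical local checkpoint directory.
from hf_integration.add_hf_config_to_olmo_checkpoint import write_config

write_config("/path/to/olmo-checkpoint")
# The directory now holds config.json; its "model_type" field is what
# the transformers Auto* classes key on when loading.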
6 changes: 5 additions & 1 deletion hf_integration/configuration_olmo.py
@@ -2,7 +2,7 @@
 OLMo configuration
 """

-from transformers.configuration_utils import PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig
 from transformers.utils import logging

 from olmo.config import ModelConfig
@@ -28,3 +28,7 @@ def __repr__(self):
     #     the_dict = PretrainedConfig.to_dict(self)
     #     the_dict.update(ModelConfig.asdict(self))
     #     return the_dict
+
+
+# Register
+AutoConfig.register("olmo", OLMoConfig)
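
Registering with AutoConfig is what lets transformers map the "model_type": "olmo" entry in config.json back to OLMoConfig. A hedged sketch of the presumable companion registrations for the model and tokenizer classes (not visible in this diff, but implied by the test below; the calls are the standard transformers registration API):

# Sketch only: the usual registration pattern for a custom architecture.
# configuration_olmo.py performs the AutoConfig line; the model/tokenizer
# lines are assumed to live in the corresponding modeling/tokenization modules.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from hf_integration import OLMoConfig, OLMoForCausalLM, OLMoTokenizerFast

AutoConfig.register("olmo", OLMoConfig)
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
AutoTokenizer.register(OLMoConfig, fast_tokenizer_class=OLMoTokenizerFast)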
284 changes: 0 additions & 284 deletions hf_integration/convert_olmo_weights_to_hf.py

This file was deleted.

21 changes: 19 additions & 2 deletions tests/hf_integration/test_hf_integration.py
@@ -1,13 +1,30 @@
 import pytest
 import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

-from hf_integration.configuration_olmo import OLMoConfig
-from hf_integration.modeling_olmo import OLMoForCausalLM
+from hf_integration import OLMoConfig, OLMoForCausalLM, OLMoTokenizerFast
+from hf_integration.add_hf_config_to_olmo_checkpoint import write_config
 from olmo import BlockType, Tokenizer, TrainConfig
 from olmo.data import DataCollator
 from olmo.model import Olmo


+def test_auto_hf_classes(model_path: str):
+    # model_path is an OLMo checkpoint.
+
+    # Creates HF-compatible config.json
+    write_config(model_path)
+
+    config = AutoConfig.from_pretrained(model_path)
+    assert isinstance(config, OLMoConfig)
+
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+    assert isinstance(model, OLMoForCausalLM)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    assert isinstance(tokenizer, OLMoTokenizerFast)
+
+
 @pytest.mark.parametrize(
     "alibi, rope, flash_attn, block_type, multi_query_attention, cuda, dtype",
     [
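
Once the config is written and the classes are registered, downstream code can treat an OLMo checkpoint like any other transformers model. A hedged end-to-end sketch (path hypothetical; assumes OLMoForCausalLM supports the standard generate() interface, which this commit does not itself show):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/olmo-checkpoint"  # hypothetical local checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

inputs = tokenizer("Language modeling is", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))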
