Commit

first commit
AkshitaB committed Jun 22, 2023
1 parent 43c29d9 commit 442f86c
Showing 5 changed files with 564 additions and 0 deletions.
Empty file added hf_integration/__init__.py
30 changes: 30 additions & 0 deletions hf_integration/configuration_olmo.py
@@ -0,0 +1,30 @@
"""
OLMo configuration
"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from olmo.config import ModelConfig

logger = logging.get_logger(__name__)

OLMO_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class OLMoConfig(PretrainedConfig, ModelConfig):  # trying to keep it as simple as possible.

    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, **kwargs):
        # TODO: confirm name mapping.
        super().__init__(**kwargs)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_dict(self) -> dict:
        # Merge the HF PretrainedConfig fields with the OLMo ModelConfig fields.
        the_dict = PretrainedConfig.to_dict(self)
        the_dict.update(ModelConfig.asdict(self))
        return the_dict
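
For illustration only (not part of the diff above), a minimal sketch of how this config class could be exercised: it builds an `OLMoConfig` the same way the conversion script below does and round-trips it through the standard `PretrainedConfig` serialization helpers. The checkpoint and output paths are placeholders.

```py
# Illustrative sketch; paths are placeholders, not values from this commit.
from olmo import Olmo

from hf_integration.configuration_olmo import OLMoConfig

model = Olmo.from_checkpoint("/path/to/olmo/checkpoint")
config = OLMoConfig(**model.config.asdict())

merged = config.to_dict()                 # HF fields plus the OLMo ModelConfig fields
config.save_pretrained("/output/path")    # writes config.json
reloaded = OLMoConfig.from_pretrained("/output/path")
```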
283 changes: 283 additions & 0 deletions hf_integration/convert_olmo_weights_to_hf.py
@@ -0,0 +1,283 @@
# import argparse
# import gc
# import json
# import math
# import os
# import shutil
# import warnings
#
# import torch
#
# # from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
# from .configuration_olmo import OLMoConfig
# from .modeling_olmo import OLMoForCausalLM
# from olmo.tokenizer import Tokenizer
#
# # try:
# # from transformers import LlamaTokenizerFast
# # except ImportError as e:
# # warnings.warn(e)
# # warnings.warn(
# # "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
# # )
# # LlamaTokenizerFast = None
#
# """
# Sample usage:
#
# ```
# python src/transformers/models/llama/convert_llama_weights_to_hf.py \
# --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
# ```
#
# Thereafter, models can be loaded via:
#
# ```py
# from transformers import LlamaForCausalLM, LlamaTokenizer
#
# model = LlamaForCausalLM.from_pretrained("/output/path")
# tokenizer = LlamaTokenizer.from_pretrained("/output/path")
# ```
#
# Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
# come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
# """
#
# INTERMEDIATE_SIZE_MAP = {
# "7B": 11008,
# "13B": 13824,
# "30B": 17920,
# "65B": 22016,
# }
# NUM_SHARDS = {
# "7B": 1,
# "13B": 2,
# "30B": 4,
# "65B": 8,
# }
#
#
# def compute_intermediate_size(n):
# return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
#
#
# def read_json(path):
# with open(path, "r") as f:
# return json.load(f)
#
#
# def write_json(text, path):
# with open(path, "w") as f:
# json.dump(text, f)
#
#
# def write_model(model_path, input_base_path, model_size):
# os.makedirs(model_path, exist_ok=True)
# tmp_model_path = os.path.join(model_path, "tmp")
# os.makedirs(tmp_model_path, exist_ok=True)
#
# params = read_json(os.path.join(input_base_path, "params.json"))
# num_shards = NUM_SHARDS[model_size]
# n_layers = params["n_layers"]
# n_heads = params["n_heads"]
# n_heads_per_shard = n_heads // num_shards
# dim = params["dim"]
# dims_per_head = dim // n_heads
# base = 10000.0
# inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
#
# # permute for sliced rotary
# def permute(w):
# return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
#
# print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# # Load weights
# if model_size == "7B":
# # Not sharded
# # (The sharded implementation would also work, but this is simpler.)
# loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
# else:
# # Sharded
# loaded = [
# torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
# for i in range(num_shards)
# ]
# param_count = 0
# index_dict = {"weight_map": {}}
# for layer_i in range(n_layers):
# filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
# if model_size == "7B":
# # Unsharded
# state_dict = {
# f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
# loaded[f"layers.{layer_i}.attention.wq.weight"]
# ),
# f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
# loaded[f"layers.{layer_i}.attention.wk.weight"]
# ),
# f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
# f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
# f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
# f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
# f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
# f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
# f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
# }
# else:
# # Sharded
# # Note that in the 13B checkpoint, not cloning the two following weights will result in the checkpoint
# # becoming 37GB instead of 26GB for some reason.
# state_dict = {
# f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
# f"layers.{layer_i}.attention_norm.weight"
# ].clone(),
# f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
# f"layers.{layer_i}.ffn_norm.weight"
# ].clone(),
# }
# state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
# torch.cat(
# [
# loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
# for i in range(num_shards)
# ],
# dim=0,
# ).reshape(dim, dim)
# )
# state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
# torch.cat(
# [
# loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
# for i in range(num_shards)
# ],
# dim=0,
# ).reshape(dim, dim)
# )
# state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
# [
# loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
# for i in range(num_shards)
# ],
# dim=0,
# ).reshape(dim, dim)
#
# state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
# [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
# )
# state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
# [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
# )
# state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
# [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
# )
# state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
# [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
# )
#
# state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
# for k, v in state_dict.items():
# index_dict["weight_map"][k] = filename
# param_count += v.numel()
# torch.save(state_dict, os.path.join(tmp_model_path, filename))
#
# filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
# if model_size == "7B":
# # Unsharded
# state_dict = {
# "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
# "model.norm.weight": loaded["norm.weight"],
# "lm_head.weight": loaded["output.weight"],
# }
# else:
# state_dict = {
# "model.norm.weight": loaded[0]["norm.weight"],
# "model.embed_tokens.weight": torch.cat(
# [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
# ),
# "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
# }
#
# for k, v in state_dict.items():
# index_dict["weight_map"][k] = filename
# param_count += v.numel()
# torch.save(state_dict, os.path.join(tmp_model_path, filename))
#
# # Write configs
# index_dict["metadata"] = {"total_size": param_count * 2}
# write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
#
# config = LlamaConfig(
# hidden_size=dim,
# intermediate_size=compute_intermediate_size(dim),
# num_attention_heads=params["n_heads"],
# num_hidden_layers=params["n_layers"],
# rms_norm_eps=params["norm_eps"],
# )
# config.save_pretrained(tmp_model_path)
#
# # Make space so we can load the model properly now.
# del state_dict
# del loaded
# gc.collect()
#
# print("Loading the checkpoint in a Llama model.")
# model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
# # Avoid saving this as part of the config.
# del model.config._name_or_path
#
# print("Saving in the Transformers format.")
# model.save_pretrained(model_path)
# shutil.rmtree(tmp_model_path)
#
#
# def write_tokenizer(tokenizer_path, input_tokenizer_path):
# # Initialize the tokenizer based on the `spm` model
# tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
# print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
# tokenizer = tokenizer_class(input_tokenizer_path)
# tokenizer.save_pretrained(tokenizer_path)
#
#
import argparse
import os

from olmo import Olmo, ModelConfig
from hf_integration.configuration_olmo import OLMoConfig
from hf_integration.modeling_olmo import OLMoPretrainedModel


def write_model(model_path: str, checkpoint_dir: str):
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    # save config as HF config
    model = Olmo.from_checkpoint(checkpoint_dir)
    config = OLMoConfig(**model.config.asdict())
    config.save_pretrained(model_path)

    # Wrap the OLMo model in the HF model class; the wrapped weights are not
    # written out here yet.
    OLMoPretrainedModel(model)

    # save tokenizer? not needed?


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of OLMo weights and model config.",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer.",
    )
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        checkpoint_dir=args.input_dir,
    )
    # spm_path = os.path.join(args.input_dir, "tokenizer.model")
    # write_tokenizer(args.output_dir, spm_path)


if __name__ == "__main__":
    main()
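
As a hedged usage sketch (not part of the commit), the converter can also be driven programmatically instead of via the `--input_dir`/`--output_dir` flags. Paths are placeholders, and at this stage of the integration only the config is exported, so only the config can be loaded back.

```py
# Illustrative sketch; paths are placeholders, not values from this commit.
from hf_integration.configuration_olmo import OLMoConfig
from hf_integration.convert_olmo_weights_to_hf import write_model

write_model(
    model_path="/output/path",
    checkpoint_dir="/path/to/olmo/checkpoint",
)

config = OLMoConfig.from_pretrained("/output/path")
print(config.model_type)  # "olmo"
```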