Skip to content

Commit

Permalink
Upload model (NovaSky-AI#29)
Browse files Browse the repository at this point in the history
* fix math eval

* add upload hub script
  • Loading branch information
DachengLi1 authored Jan 20, 2025
1 parent 783c785 commit 037f7d8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
45 changes: 45 additions & 0 deletions skythought/tools/upload_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
From https://github.com/lm-sys/FastChat/
Upload weights to huggingface.
Usage:
python upload_hub.py --model-path ~/model_weights/Sky-T1 --hub-repo-id NovaSky-AI/Sky-T1 --private
"""
import argparse
import tempfile

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def upload_hub(model_path, hub_repo_id, component, private):
if component == "all":
components = ["model", "tokenizer"]
else:
components = [component]

kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": args.private}

if "model" in components:
model = AutoModelForCausalLM.from_pretrained(
model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
with tempfile.TemporaryDirectory() as tmp_path:
model.save_pretrained(tmp_path, **kwargs)

if "tokenizer" in components:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
with tempfile.TemporaryDirectory() as tmp_path:
tokenizer.save_pretrained(tmp_path, **kwargs)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument("--hub-repo-id", type=str, required=True)
parser.add_argument(
"--component", type=str, choices=["all", "model", "tokenizer"], default="all"
)
parser.add_argument("--private", action="store_true")
args = parser.parse_args()

upload_hub(args.model_path, args.hub_repo_id, args.component, args.private)
5 changes: 4 additions & 1 deletion skythought/tools/util/math/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def replace_match(match):
# If the answer is a list of integers (without parenthesis), sort them
if re.fullmatch(r"(\s*-?\d+\s*,)*\s*-?\d+\s*", string):
# Split the string into a list of integers
integer_list = list(map(int, string.split(',')))
try:
integer_list = list(map(int, string.split(',')))
except:
integer_list = list(map(int, "-1,-1".split(',')))

# Sort the list in ascending order
sorted_list = sorted(integer_list)
Expand Down

0 comments on commit 037f7d8

Please sign in to comment.