Skip to content

Commit

Permalink
Merge pull request google-research#8 from kashif/hub
Browse files Browse the repository at this point in the history
[Hub] Download weights automatically from hub
  • Loading branch information
siriuz42 authored May 10, 2024
2 parents e2a5274 + 2a83a0f commit 2fafb79
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ tfm = timesfm.TimesFm(
model_dims=1280,
backend=<backend>,
)
tfm.load_from_checkpoint(<checkpoint_path>)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
```

Note that the four parameters are fixed to load the 200m model
Expand Down
14 changes: 11 additions & 3 deletions src/timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@

import logging
import multiprocessing
from os import path
import time
from typing import Any, Literal, Sequence
from typing import Any, Literal, Optional, Sequence

import einshape as es
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
from paxml import checkpoints
from paxml import tasks_lib
from praxis import base_hyperparams
Expand Down Expand Up @@ -222,17 +224,23 @@ def _get_sample_inputs(self):

def load_from_checkpoint(
self,
checkpoint_path: str,
checkpoint_path: Optional[str] = None,
repo_id: str = "google/timesfm-1.0-200m",
checkpoint_type: checkpoints.CheckpointType = checkpoints.CheckpointType.FLAX,
step: int | None = None,
) -> None:
"""Loads a checkpoint and compiles the decoder.
Args:
checkpoint_path: path to the checkpoint directory.
checkpoint_path: Optional path to the checkpoint directory.
repo_id: Hugging Face Hub repo id.
checkpoint_type: type of PAX checkpoint
step: step of the checkpoint to load. If `None`, load latest checkpoint.
"""
# Download the checkpoint from Hugging Face Hub if not given
if checkpoint_path is None:
checkpoint_path = path.join(snapshot_download(repo_id), "checkpoints")

# Initialize the model weights.
self._logging("Constructing model weights.")
start_time = time.time()
Expand Down

0 comments on commit 2fafb79

Please sign in to comment.