Skip to content

Commit

Permalink
fix(vLLM): importlib relative import
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 24, 2024
1 parent 9d7c437 commit 4f72f4a
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions ChatTTS/model/velocity/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Utilities for selecting and loading models."""

import contextlib
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights
import importlib

from .llama import LlamaModel


@contextlib.contextmanager
Expand All @@ -22,16 +21,7 @@ def _set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    """Return the ``LlamaModel`` class from the sibling ``llama`` module.

    Args:
        config: Hugging Face model configuration. It is not consulted here;
            the lookup always targets ``LlamaModel``.

    Returns:
        The ``LlamaModel`` attribute of the ``.llama`` module, or ``None`` if
        that module does not define such an attribute.
    """
    # A relative module name (leading dot) must be resolved against an anchor
    # package: ``import_module`` raises ``TypeError`` when given a relative
    # name without the ``package`` argument. ``__package__`` is this module's
    # own package, so ``.llama`` resolves to the sibling module.
    llama_module = importlib.import_module(".llama", package=__package__)
    return getattr(llama_module, "LlamaModel", None)


def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)

# Get the (maybe quantized) linear method.
linear_method = None
if model_config.quantization is not None:
Expand Down Expand Up @@ -63,7 +53,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device("cuda"):
model = model_class(model_config.hf_config, linear_method)
model = LlamaModel(model_config.hf_config, linear_method)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
Expand Down

0 comments on commit 4f72f4a

Please sign in to comment.