Skip to content

Commit

Permalink
Set torch default dtype in a context manager (vllm-project#971)
Browse files Browse the repository at this point in the history
Signed-off-by: Antoni Baum <[email protected]>
  • Loading branch information
Yard1 authored Sep 7, 2023
1 parent 320a622 commit 005ba45
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions vllm/model_executor/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Type

import torch
Expand Down Expand Up @@ -30,6 +31,15 @@
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the global default torch dtype to ``dtype``.

    The dtype that was the default on entry is restored when the ``with``
    block exits, including when the body of the block raises.

    Args:
        dtype: The dtype to install as the default for the duration of the
            ``with`` block.

    Yields:
        None. The context manager is used only for its side effect.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore in ``finally`` so that a failure inside the ``with`` body
        # does not leave the process-wide default dtype changed.
        torch.set_default_dtype(old_dtype)


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
Expand All @@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:

def get_model(model_config: ModelConfig) -> nn.Module:
    """Instantiate the model described by ``model_config`` and move it to CUDA.

    The model class is resolved from ``model_config.hf_config``. The model is
    constructed, filled with weights, and moved to the GPU while
    ``model_config.dtype`` is the default torch dtype; the previous default
    is put back once that work is done.

    Args:
        model_config: Configuration providing ``hf_config``, ``dtype``,
            ``use_dummy_weights``, ``model``, ``download_dir`` and
            ``use_np_weights``.

    Returns:
        The model, on CUDA, after ``eval()`` has been called on it.
    """
    model_class = _get_model_architecture(model_config.hf_config)
    # Scope the dtype change to model construction/loading instead of
    # changing the process-wide default for good.
    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        model = model_class(model_config.hf_config)
        if model_config.use_dummy_weights:
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        else:
            # Load the weights from the cached or downloaded files.
            model.load_weights(model_config.model, model_config.download_dir,
                               model_config.use_np_weights)
            model = model.cuda()
    return model.eval()

0 comments on commit 005ba45

Please sign in to comment.