Skip to content

Commit

Permalink
Fix issue in loading mixed precision vocab pruned models during torchtune generation for evaluation (pytorch#2043)
Browse files Browse the repository at this point in the history
  • Loading branch information
ifed-ucsd authored Nov 22, 2024
1 parent a9aadf5 commit 009adaa
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions tests/torchtune/training/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,9 @@ def test_validate_expected_param_dtype(self):
m = torch.nn.Linear(10, 10)
with pytest.raises(ValueError, match=f"has dtype {next(m.parameters()).dtype}"):
validate_expected_param_dtype(m.named_parameters(), dtype=torch.float16)

validate_expected_param_dtype(
m.named_parameters(),
dtype=torch.float16,
exclude_param_names=[name for name, _ in m.named_parameters()],
)
10 changes: 8 additions & 2 deletions torchtune/training/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Dict, Generator, Iterable, Optional, Tuple
from typing import Dict, Generator, Iterable, List, Optional, Tuple

import torch

Expand Down Expand Up @@ -147,19 +147,25 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:


def validate_expected_param_dtype(
named_params: Iterable[Tuple[str, torch.nn.Parameter]], dtype: torch.dtype
named_params: Iterable[Tuple[str, torch.nn.Parameter]],
dtype: torch.dtype,
exclude_param_names: Optional[List[str]] = None,
) -> None:
"""
Validates that all input parameters have the expected dtype.
Args:
named_params (Iterable[Tuple[str, torch.nn.Parameter]]): Iterable of named parameters.
dtype (torch.dtype): Expected dtype.
exclude_param_names (Optional[List[str]]): Optional list of name fragments; any parameter whose name contains one of them as a substring is excluded from dtype checking.
Raises:
ValueError: If any parameter has a different dtype than `dtype`.
"""
for name, param in named_params:
if exclude_param_names is not None:
if any(n in name for n in exclude_param_names):
continue
if param.dtype != dtype:
raise ValueError(
f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
Expand Down

0 comments on commit 009adaa

Please sign in to comment.