Remove usage of LRU cache from peft_utils (pytorch#907)
ebsmothers authored Apr 30, 2024
1 parent 74b2883 commit 3b3a820
Showing 2 changed files with 0 additions and 34 deletions.
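The commit message doesn't spell out the motivation, but a standard pitfall of functools.lru_cache applied to functions of an nn.Module is easy to show: the cache is keyed on the module's identity, so it keeps returning stale results after the module is mutated, and it holds a strong reference that keeps every model ever passed in alive. A minimal sketch of the hazard (illustrative only, not code from this commit; cached_param_names is a hypothetical stand-in for the cached get_adapter_params):

import functools

import torch
from torch import nn


@functools.lru_cache()
def cached_param_names(model: nn.Module) -> tuple:
    # nn.Module hashes by identity, so this cache entry is never
    # invalidated when the module's parameters change.
    return tuple(name for name, _ in model.named_parameters())


model = nn.Linear(4, 4)
print(cached_param_names(model))  # ('weight', 'bias')

# Mutate the module: register a new adapter-style parameter.
model.register_parameter("extra", nn.Parameter(torch.zeros(4)))
print(cached_param_names(model))  # still ('weight', 'bias'): stale

Recomputing the parameter dict on each call is cheap next to a training step, so dropping the cache is the simpler fix.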
tests/torchtune/modules/peft/test_peft_utils.py (0 additions, 14 deletions)
@@ -15,7 +15,6 @@
 from torchtune.models.llama2 import llama2, lora_llama2
 from torchtune.modules.peft import LoRALinear
 from torchtune.modules.peft.peft_utils import (
-    _get_base_model_params,
     AdapterModule,
     disable_adapter,
     get_adapter_params,
@@ -167,19 +166,6 @@ def test_get_adapter_params(self, request, model_name, expected_keys):
         expected = request.getfixturevalue(expected_keys)
         assert set(expected) == set(adapter_params.keys())
 
-    @pytest.mark.parametrize(
-        "model_name, expected_keys",
-        [
-            ("dummy_adapter_parent_model", "dummy_model_expected_base_model_keys"),
-            ("lora_llama2_model", "lora_llama2_expected_base_model_keys"),
-        ],
-    )
-    def test_get_base_model_params(self, request, model_name, expected_keys):
-        model = request.getfixturevalue(model_name)
-        base_model_params = _get_base_model_params(model)
-        expected = request.getfixturevalue(expected_keys)
-        assert set(expected) == set(base_model_params.keys())
-
     @pytest.mark.parametrize(
         "model_name, expected_trainable_keys, expected_frozen_keys",
         [
torchtune/modules/peft/peft_utils.py (0 additions, 20 deletions)
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
-import functools
 from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set
 
 from torch import nn
@@ -33,7 +32,6 @@ def adapter_params(self) -> List[str]:
         pass
 
 
-@functools.lru_cache()
 def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]:
     """
     Return the subset of parameters from a model that correspond to an adapter.
@@ -63,24 +61,6 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]:
     return adapter_params
 
 
-@functools.lru_cache()
-def _get_base_model_params(model: nn.Module) -> Dict[str, Any]:
-    """
-    Given a model containing some adapter weights, return the subset of the model's
-    parameters that correspond to the base model. Assumes that any adapter class has
-    defined the :func:`~torchtune.modules.peft.AdapterModule.adapter_params` method.
-
-    Args:
-        model (nn.Module): Instance of model class containing some adapter params.
-
-    Returns:
-        Dict[str, Any]: the subset of adapted model's state dict containing
-            only the base model's parameters.
-    """
-    adapter_params = get_adapter_params(model)
-    return {k: v for k, v in model.state_dict().items() if k not in adapter_params}
-
-
 def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None:
     """
     Set trainable parameters for an nn.Module based on a state dict of adapter parameters.
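With _get_base_model_params deleted, any downstream code that still needs the base-model subset can inline the removed comprehension against the surviving public helper. A sketch under that assumption (base_model_params is a hypothetical local helper, not a torchtune API):

from torch import nn

from torchtune.modules.peft.peft_utils import get_adapter_params


def base_model_params(model: nn.Module) -> dict:
    # Same logic as the deleted helper, minus the lru_cache: everything
    # in the state dict that is not an adapter parameter.
    adapter_params = get_adapter_params(model)
    return {k: v for k, v in model.state_dict().items() if k not in adapter_params}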
