forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)
- Loading branch information
1 parent
9cfa288
commit cdf5ea2
Showing
3 changed files
with
230 additions
and
37 deletions.
There are no files selected for viewing
150 changes: 150 additions & 0 deletions
150
tests/torchtune/models/llama3_2_vision/test_llama_vision_lora.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import pytest | ||
import torch | ||
from tests.test_utils import fixed_init_model | ||
from torchtune.models.llama3_2_vision._component_builders import ( | ||
lora_llama3_2_vision_decoder, | ||
lora_llama3_2_vision_encoder, | ||
LoRATrainable, | ||
) | ||
from torchtune.modules.model_fusion import DeepFusionModel | ||
from torchtune.modules.peft import get_adapter_params | ||
from torchtune.training.seed import set_seed | ||
|
||
EMBED_DIM = 128 | ||
NUM_LAYERS = 4 | ||
NUM_HEADS = 16 | ||
NUM_KV_HEADS = 8 | ||
VOCAB_SIZE = 32000 | ||
MAX_SEQ_LEN = 2048 | ||
BSZ = 2 | ||
SEQ_LEN = 100 | ||
LORA_ATTN_MODULES = ["q_proj", "k_proj", "v_proj", "output_proj"] | ||
LORA_RANK = 8 | ||
LORA_ALPHA = 16 | ||
IMAGE_SIZE = 140 | ||
PATCH_SIZE = 14 | ||
|
||
|
||
def lora_llama3_2_vision( | ||
decoder_type, | ||
encoder_type, | ||
fusion_type, | ||
) -> DeepFusionModel: | ||
encoder = lora_llama3_2_vision_encoder( | ||
encoder_lora=encoder_type == LoRATrainable.LORA, | ||
fusion_lora=fusion_type == LoRATrainable.LORA, | ||
lora_attn_modules=LORA_ATTN_MODULES, | ||
apply_lora_to_mlp=False, | ||
apply_lora_to_output=False, | ||
patch_size=PATCH_SIZE, | ||
num_heads=NUM_HEADS, | ||
clip_embed_dim=EMBED_DIM, | ||
clip_num_layers=NUM_LAYERS, | ||
clip_hidden_states=[2], | ||
decoder_embed_dim=EMBED_DIM, | ||
num_layers_projection=NUM_LAYERS, | ||
tile_size=IMAGE_SIZE, | ||
max_num_tiles=1, | ||
in_channels=3, | ||
lora_rank=LORA_RANK, | ||
lora_alpha=LORA_ALPHA, | ||
lora_dropout=0.0, | ||
use_dora=False, | ||
quantize_base=False, | ||
) | ||
decoder = lora_llama3_2_vision_decoder( | ||
decoder_lora=decoder_type == LoRATrainable.LORA, | ||
fusion_lora=fusion_type == LoRATrainable.LORA, | ||
lora_attn_modules=LORA_ATTN_MODULES, | ||
apply_lora_to_mlp=False, | ||
apply_lora_to_output=False, | ||
vocab_size=VOCAB_SIZE, | ||
num_layers=NUM_LAYERS, | ||
fusion_interval=2, | ||
num_special_tokens=8, | ||
num_heads=NUM_HEADS, | ||
num_kv_heads=NUM_KV_HEADS, | ||
embed_dim=EMBED_DIM, | ||
max_seq_len=MAX_SEQ_LEN, | ||
encoder_max_seq_len=2020, # 20*101 | ||
rope_base=500_000, | ||
intermediate_dim=14336, | ||
lora_rank=LORA_RANK, | ||
lora_alpha=LORA_ALPHA, | ||
lora_dropout=0.0, | ||
use_dora=False, | ||
quantize_base=False, | ||
) | ||
return DeepFusionModel( | ||
encoder=encoder, | ||
decoder=decoder, | ||
encoder_trainable=encoder_type != LoRATrainable.FROZEN, | ||
decoder_trainable=decoder_type != LoRATrainable.FROZEN, | ||
fusion_trainable=fusion_type != LoRATrainable.FROZEN, | ||
) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def random(): | ||
set_seed(16) | ||
|
||
|
||
class TestLlamaVisionLora: | ||
@pytest.fixture | ||
def inputs(self): | ||
return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN)) | ||
|
||
def test_lora_args(self): | ||
model = lora_llama3_2_vision( | ||
LoRATrainable.LORA, | ||
LoRATrainable.FROZEN, | ||
LoRATrainable.FROZEN, | ||
) | ||
encoder = set(get_adapter_params(model).keys()) | ||
assert len(encoder) == 32, "Only the clip encoder should be trainable." | ||
|
||
model = lora_llama3_2_vision( | ||
LoRATrainable.FROZEN, | ||
LoRATrainable.LORA, | ||
LoRATrainable.FROZEN, | ||
) | ||
decoder = set(get_adapter_params(model).keys()) | ||
assert ( | ||
len(decoder) == 32 | ||
), "Only the decoder self attention layers should be trainable." | ||
|
||
model = lora_llama3_2_vision( | ||
LoRATrainable.FROZEN, | ||
LoRATrainable.FROZEN, | ||
LoRATrainable.LORA, | ||
) | ||
fusion = set(get_adapter_params(model).keys()) | ||
assert len(fusion) == 48, "Only the fusion layers should be trainable." | ||
|
||
all_params = set.union(encoder, decoder, fusion) | ||
assert ( | ||
len(all_params) == 48 + 32 + 32 | ||
), "There should be no overlap between options." | ||
|
||
def test_forward(self, inputs): | ||
model = lora_llama3_2_vision( | ||
LoRATrainable.LORA, | ||
LoRATrainable.LORA, | ||
LoRATrainable.LORA, | ||
) | ||
fixed_init_model(model, min_val=-0.25, max_val=0.5) | ||
tokens = torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN)) | ||
image = torch.randn(BSZ, 1, 1, 3, IMAGE_SIZE, IMAGE_SIZE) | ||
aspect_ratio = torch.tensor([[1, 1] for _ in range(BSZ)]) | ||
actual = model( | ||
tokens, encoder_input={"images": image, "aspect_ratio": aspect_ratio} | ||
) | ||
expected = torch.tensor(-3.9763) | ||
assert actual.shape == (BSZ, SEQ_LEN, VOCAB_SIZE) | ||
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters