[Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)
pbontrager authored Dec 12, 2024
1 parent 9cfa288 commit cdf5ea2
Showing 3 changed files with 230 additions and 37 deletions.
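
For context, a minimal sketch of how the trainable flags on the LoRA vision model builder touched by this commit are meant to be used. The keyword arguments shown here all appear in the diff below; the import path and any other defaults are assumptions, so treat this as illustrative rather than canonical.

# Minimal usage sketch (assumed import path; this constructs the full 11B
# architecture with random weights, so it is memory-heavy).
from torchtune.models.llama3_2_vision import lora_llama3_2_vision_11b
from torchtune.modules.peft import get_adapter_params

# Request LoRA adapters on the decoder only; encoder and fusion stay frozen.
model = lora_llama3_2_vision_11b(
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],
    decoder_trainable="lora",
    encoder_trainable="frozen",
    fusion_trainable="frozen",
)

# With this fix, only the decoder self-attention layers report adapter
# parameters; previously LoRA modules could be attached to components that
# were not set to "lora".
for name in sorted(get_adapter_params(model)):
    print(name)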
150 changes: 150 additions & 0 deletions tests/torchtune/models/llama3_2_vision/test_llama_vision_lora.py
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import fixed_init_model
from torchtune.models.llama3_2_vision._component_builders import (
lora_llama3_2_vision_decoder,
lora_llama3_2_vision_encoder,
LoRATrainable,
)
from torchtune.modules.model_fusion import DeepFusionModel
from torchtune.modules.peft import get_adapter_params
from torchtune.training.seed import set_seed

EMBED_DIM = 128
NUM_LAYERS = 4
NUM_HEADS = 16
NUM_KV_HEADS = 8
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 2048
BSZ = 2
SEQ_LEN = 100
LORA_ATTN_MODULES = ["q_proj", "k_proj", "v_proj", "output_proj"]
LORA_RANK = 8
LORA_ALPHA = 16
IMAGE_SIZE = 140
PATCH_SIZE = 14


def lora_llama3_2_vision(
decoder_type,
encoder_type,
fusion_type,
) -> DeepFusionModel:
encoder = lora_llama3_2_vision_encoder(
encoder_lora=encoder_type == LoRATrainable.LORA,
fusion_lora=fusion_type == LoRATrainable.LORA,
lora_attn_modules=LORA_ATTN_MODULES,
apply_lora_to_mlp=False,
apply_lora_to_output=False,
patch_size=PATCH_SIZE,
num_heads=NUM_HEADS,
clip_embed_dim=EMBED_DIM,
clip_num_layers=NUM_LAYERS,
clip_hidden_states=[2],
decoder_embed_dim=EMBED_DIM,
num_layers_projection=NUM_LAYERS,
tile_size=IMAGE_SIZE,
max_num_tiles=1,
in_channels=3,
lora_rank=LORA_RANK,
lora_alpha=LORA_ALPHA,
lora_dropout=0.0,
use_dora=False,
quantize_base=False,
)
decoder = lora_llama3_2_vision_decoder(
decoder_lora=decoder_type == LoRATrainable.LORA,
fusion_lora=fusion_type == LoRATrainable.LORA,
lora_attn_modules=LORA_ATTN_MODULES,
apply_lora_to_mlp=False,
apply_lora_to_output=False,
vocab_size=VOCAB_SIZE,
num_layers=NUM_LAYERS,
fusion_interval=2,
num_special_tokens=8,
num_heads=NUM_HEADS,
num_kv_heads=NUM_KV_HEADS,
embed_dim=EMBED_DIM,
max_seq_len=MAX_SEQ_LEN,
encoder_max_seq_len=2020, # 20*101
rope_base=500_000,
intermediate_dim=14336,
lora_rank=LORA_RANK,
lora_alpha=LORA_ALPHA,
lora_dropout=0.0,
use_dora=False,
quantize_base=False,
)
return DeepFusionModel(
encoder=encoder,
decoder=decoder,
encoder_trainable=encoder_type != LoRATrainable.FROZEN,
decoder_trainable=decoder_type != LoRATrainable.FROZEN,
fusion_trainable=fusion_type != LoRATrainable.FROZEN,
)


@pytest.fixture(autouse=True)
def random():
set_seed(16)


class TestLlamaVisionLora:
@pytest.fixture
def inputs(self):
return torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))

def test_lora_args(self):
model = lora_llama3_2_vision(
LoRATrainable.LORA,
LoRATrainable.FROZEN,
LoRATrainable.FROZEN,
)
        decoder = set(get_adapter_params(model).keys())
        assert (
            len(decoder) == 32
        ), "Only the decoder self attention layers should be trainable."

model = lora_llama3_2_vision(
LoRATrainable.FROZEN,
LoRATrainable.LORA,
LoRATrainable.FROZEN,
)
        encoder = set(get_adapter_params(model).keys())
        assert len(encoder) == 32, "Only the clip encoder should be trainable."

model = lora_llama3_2_vision(
LoRATrainable.FROZEN,
LoRATrainable.FROZEN,
LoRATrainable.LORA,
)
fusion = set(get_adapter_params(model).keys())
assert len(fusion) == 48, "Only the fusion layers should be trainable."

all_params = set.union(encoder, decoder, fusion)
assert (
len(all_params) == 48 + 32 + 32
), "There should be no overlap between options."

def test_forward(self, inputs):
model = lora_llama3_2_vision(
LoRATrainable.LORA,
LoRATrainable.LORA,
LoRATrainable.LORA,
)
fixed_init_model(model, min_val=-0.25, max_val=0.5)
tokens = torch.randint(0, VOCAB_SIZE, (BSZ, SEQ_LEN))
image = torch.randn(BSZ, 1, 1, 3, IMAGE_SIZE, IMAGE_SIZE)
aspect_ratio = torch.tensor([[1, 1] for _ in range(BSZ)])
actual = model(
tokens, encoder_input={"images": image, "aspect_ratio": aspect_ratio}
)
expected = torch.tensor(-3.9763)
assert actual.shape == (BSZ, SEQ_LEN, VOCAB_SIZE)
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-4)
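
As a companion to the adapter-parameter checks above, a minimal sketch of how those adapters are typically made the only trainable parameters, assuming the lora_llama3_2_vision helper defined in this test file is in scope and using set_trainable_params from torchtune.modules.peft:

from torchtune.modules.peft import get_adapter_params, set_trainable_params

# The helper's argument order is (decoder, encoder, fusion).
model = lora_llama3_2_vision(
    LoRATrainable.LORA,    # LoRA adapters on the decoder self-attention
    LoRATrainable.FROZEN,  # encoder frozen
    LoRATrainable.FROZEN,  # fusion layers frozen
)

# Restrict training to the adapter weights; everything else, including the
# components marked FROZEN above, ends up with requires_grad=False.
adapter_params = get_adapter_params(model)
set_trainable_params(model, adapter_params)

trainable = {n for n, p in model.named_parameters() if p.requires_grad}
assert trainable == set(adapter_params)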
107 changes: 70 additions & 37 deletions torchtune/models/llama3_2_vision/_component_builders.py
@@ -454,7 +454,7 @@ def lora_llama3_2_vision_encoder(
             **lora_options,
         )
     else:
-        projection_head = lora_llama3_2_vision_projection_head(**projection_options)
+        projection_head = llama3_2_vision_projection_head(**projection_options)

     encoder = Llama3VisionEncoder(clip=clip, projection_head=projection_head)

@@ -549,22 +549,37 @@ def lora_llama3_2_vision_decoder(
     for idx in range(1, num_layers + 1):

         # Self attention layers for text decoder
-        self_attn = lora_llama3_attention(
-            lora_modules=lora_attn_modules,
-            pos_embeddings=rope,
-            head_dim=head_dim,
-            embed_dim=embed_dim,
-            num_heads=num_heads,
-            num_kv_heads=num_kv_heads,
-            max_seq_len=max_seq_len,
-            attn_dropout=0.0,
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
-            use_dora=use_dora,
-            quantize_base=quantize_base,
-        )
-        if apply_lora_to_mlp:
+        if decoder_lora:
+            self_attn = lora_llama3_attention(
+                lora_modules=lora_attn_modules,
+                pos_embeddings=rope,
+                head_dim=head_dim,
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                num_kv_heads=num_kv_heads,
+                max_seq_len=max_seq_len,
+                attn_dropout=0.0,
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                use_dora=use_dora,
+                quantize_base=quantize_base,
+            )
+        else:
+            self_attn = MultiHeadAttention(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                num_kv_heads=num_kv_heads,
+                head_dim=head_dim,
+                q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
+                k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+                v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+                output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
+                pos_embeddings=rope,
+                max_seq_len=max_seq_len,
+                attn_dropout=0.0,
+            )
+        if apply_lora_to_mlp and decoder_lora:
             mlp = lora_llama3_mlp(
                 dim=embed_dim,
                 hidden_dim=hidden_dim,
@@ -588,25 +603,43 @@ def lora_llama3_2_vision_decoder(
         # cross attention layers, mixing text and vision,
         # placed every `fusion_interval` layers
         if idx % fusion_interval == 0:
-            attn = lora_llama3_attention(
-                lora_modules=lora_attn_modules,
-                pos_embeddings=None,
-                head_dim=head_dim,
-                embed_dim=embed_dim,
-                num_heads=num_heads,
-                num_kv_heads=num_kv_heads,
-                q_norm=RMSNorm(dim=head_dim, eps=1e-05),
-                k_norm=RMSNorm(dim=head_dim, eps=1e-05),
-                max_seq_len=encoder_max_seq_len,
-                is_causal=False,
-                attn_dropout=0.0,
-                lora_rank=lora_rank,
-                lora_alpha=lora_alpha,
-                lora_dropout=lora_dropout,
-                use_dora=use_dora,
-                quantize_base=quantize_base,
-            )
-            if apply_lora_to_mlp:
+            if fusion_lora:
+                attn = lora_llama3_attention(
+                    lora_modules=lora_attn_modules,
+                    pos_embeddings=None,
+                    head_dim=head_dim,
+                    embed_dim=embed_dim,
+                    num_heads=num_heads,
+                    num_kv_heads=num_kv_heads,
+                    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
+                    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
+                    max_seq_len=encoder_max_seq_len,
+                    is_causal=False,
+                    attn_dropout=0.0,
+                    lora_rank=lora_rank,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout,
+                    use_dora=use_dora,
+                    quantize_base=quantize_base,
+                )
+            else:
+                attn = MultiHeadAttention(
+                    embed_dim=embed_dim,
+                    num_heads=num_heads,
+                    num_kv_heads=num_kv_heads,
+                    head_dim=head_dim,
+                    q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
+                    k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+                    v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+                    output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
+                    q_norm=RMSNorm(dim=head_dim, eps=1e-05),
+                    k_norm=RMSNorm(dim=head_dim, eps=1e-05),
+                    pos_embeddings=None,
+                    max_seq_len=encoder_max_seq_len,
+                    is_causal=False,
+                    attn_dropout=0.0,
+                )
+            if apply_lora_to_mlp and fusion_lora:
                 mlp = lora_llama3_mlp(
                     dim=embed_dim,
                     hidden_dim=hidden_dim,
@@ -645,7 +678,7 @@ def lora_llama3_2_vision_decoder(
             alpha=lora_alpha,
             dropout=lora_dropout,
         )
-        if apply_lora_to_output
+        if apply_lora_to_output and decoder_lora
         else nn.Linear(embed_dim, vocab_size, bias=False)
     )

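To see the decoder hunk's effect in isolation: with decoder_lora=False and fusion_lora=False the builder should now contain no LoRA modules at all. A small sketch using the same hyperparameters as the unit test above; LoRALinear from torchtune.modules.peft is assumed to be the adapter class when use_dora is off.

from torchtune.models.llama3_2_vision._component_builders import (
    lora_llama3_2_vision_decoder,
)
from torchtune.modules.peft import LoRALinear

# Build the decoder with every LoRA flag off; after this fix the attention
# projections are plain nn.Linear rather than LoRA-wrapped layers.
decoder = lora_llama3_2_vision_decoder(
    decoder_lora=False,
    fusion_lora=False,
    lora_attn_modules=["q_proj", "k_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    vocab_size=32000,
    num_layers=4,
    fusion_interval=2,
    num_special_tokens=8,
    num_heads=16,
    num_kv_heads=8,
    embed_dim=128,
    max_seq_len=2048,
    encoder_max_seq_len=2020,
    rope_base=500_000,
    intermediate_dim=14336,
    lora_rank=8,
    lora_alpha=16,
    lora_dropout=0.0,
    use_dora=False,
    quantize_base=False,
)

# Before this commit this check would fail: self-attention always went
# through lora_llama3_attention regardless of decoder_lora.
assert not any(isinstance(m, LoRALinear) for m in decoder.modules())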
10 changes: 10 additions & 0 deletions torchtune/models/llama3_2_vision/_model_builders.py
@@ -167,6 +167,11 @@ def lora_llama3_2_vision_11b(
     decoder_type = LoRATrainable(decoder_trainable.lower())
     encoder_type = LoRATrainable(encoder_trainable.lower())
     fusion_type = LoRATrainable(fusion_trainable.lower())
+    assert LoRATrainable.FULL not in [
+        decoder_type,
+        encoder_type,
+        fusion_type,
+    ], "We've temporarily removed support for mixed LoRA + full finetuning. Please don't use the 'full' option; use llama3_2_vision_11b if you need full finetuning."
     encoder = lora_llama3_2_vision_encoder(
         encoder_lora=encoder_type == LoRATrainable.LORA,
         fusion_lora=fusion_type == LoRATrainable.LORA,
@@ -325,6 +330,11 @@ def lora_llama3_2_vision_90b(
     decoder_type = LoRATrainable(decoder_trainable.lower())
     encoder_type = LoRATrainable(encoder_trainable.lower())
     fusion_type = LoRATrainable(fusion_trainable.lower())
+    assert LoRATrainable.FULL not in [
+        decoder_type,
+        encoder_type,
+        fusion_type,
+    ], "We've temporarily removed support for mixed LoRA + full finetuning. Please don't use the 'full' option; use llama3_2_vision_90b if you need full finetuning."
     encoder = lora_llama3_2_vision_encoder(
         encoder_lora=encoder_type == LoRATrainable.LORA,
         fusion_lora=fusion_type == LoRATrainable.LORA,
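Finally, a hedged illustration of what the new guard means for callers (the import path and any defaults beyond the keywords visible in this diff are assumptions): requesting "full" for any component now fails fast instead of silently mixing full finetuning with LoRA.

from torchtune.models.llama3_2_vision import lora_llama3_2_vision_11b

# Mixing "full" with the LoRA builders is temporarily unsupported; the new
# assert rejects it and points at llama3_2_vision_11b for full finetuning.
try:
    lora_llama3_2_vision_11b(
        lora_attn_modules=["q_proj", "v_proj"],
        decoder_trainable="full",   # triggers the assert added above
        encoder_trainable="lora",
        fusion_trainable="lora",
    )
except AssertionError as err:
    print(err)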
