Merge pull request #8 from shadowpa0327/fix/missing_bias
[Bug Fix] Add missing bias in HeadwiseLowRankModule
Showing 10 changed files with 264 additions and 11 deletions.
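For context, the constructor call in the modeling diff below now passes `bias=module.bias is not None`, so the low-rank replacement keeps the original projection's bias term instead of silently dropping it. Below is a minimal sketch of a head-wise low-rank layer with that bias handling, assuming only the constructor signature visible in the diff (ranks, in_features, out_features, bias); it is illustrative and not the repository's actual HeadwiseLowRankModule.

import torch
import torch.nn as nn

class HeadwiseLowRankSketch(nn.Module):
    """Illustrative stand-in, not the repository's HeadwiseLowRankModule."""
    def __init__(self, ranks, in_features, out_features, bias=True):
        super().__init__()
        group_dim = out_features // len(ranks)
        # One (down, up) factor pair per head group, with rank r for group g.
        self.down = nn.ModuleList(nn.Linear(in_features, r, bias=False) for r in ranks)
        self.up = nn.ModuleList(nn.Linear(r, group_dim, bias=False) for r in ranks)
        # The point of this fix: keep a bias when the replaced nn.Linear had one.
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        out = torch.cat([u(d(x)) for d, u in zip(self.down, self.up)], dim=-1)
        return out if self.bias is None else out + self.bias

layer = HeadwiseLowRankSketch(ranks=[64, 64, 64, 64], in_features=1024, out_features=512)
y = layer(torch.randn(2, 1024))  # -> shape (2, 512)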
@@ -0,0 +1,8 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Qwen2Tokenizer
from .configuration_palu_qwen import PaluQwen2Config
from .modeling_palu_qwen import PaluQwen2ForCausalLM

AutoConfig.register("paluqwen2", PaluQwen2Config)
AutoModelForCausalLM.register(PaluQwen2Config, PaluQwen2ForCausalLM)
AutoTokenizer.register(PaluQwen2Config, Qwen2Tokenizer)
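With these registrations in place, the stock Auto classes can resolve a checkpoint whose config declares model_type "paluqwen2". A hedged usage sketch, assuming the package containing the register() calls above has already been imported and that the checkpoint path is a placeholder:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

ckpt = "./palu-qwen2-checkpoint"  # placeholder path to a saved PaluQwen2 checkpoint

config = AutoConfig.from_pretrained(ckpt)           # resolves to PaluQwen2Config
model = AutoModelForCausalLM.from_pretrained(ckpt)  # resolves to PaluQwen2ForCausalLM
tokenizer = AutoTokenizer.from_pretrained(ckpt)     # resolves to the registered Qwen2Tokenizer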
@@ -0,0 +1,61 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class PaluQwen2Config(PretrainedConfig):
    model_type = "paluqwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        # [Palu]
        head_wise_ranks=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # for avsd
        self.head_wise_ranks = head_wise_ranks
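head_wise_ranks is the only Palu-specific argument: a mapping from projection-module names to the per-head-group ranks used when those layers are replaced (see the modeling code below). A short sketch of the expected shape of the mapping; the module names and rank values here are made up for illustration, not tuned values:

# Assumes PaluQwen2Config from the file above is importable in the current package.
# Keys are module names as reported by model.named_modules(); values list the
# rank for each head group of that projection.
head_wise_ranks = {
    "model.layers.0.self_attn.k_proj": [512, 512, 512, 512],
    "model.layers.0.self_attn.v_proj": [512, 512, 512, 512],
}

config = PaluQwen2Config(head_wise_ranks=head_wise_ranks)
assert config.model_type == "paluqwen2"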
@@ -0,0 +1,64 @@
from transformers import Qwen2ForCausalLM
import torch.nn as nn
from types import SimpleNamespace
from .configuration_palu_qwen import PaluQwen2Config
from ..modules.svd_linear import HeadwiseLowRankModule

class PaluQwen2ForCausalLM(Qwen2ForCausalLM):
    config_class = PaluQwen2Config

    def __init__(self, config: PaluQwen2Config):
        super().__init__(config)
        self.head_wise_ranks = config.head_wise_ranks

        # Walk the module tree and record every nn.Linear together with its
        # parent module, so the layer can later be swapped out in place.
        full_name_dict = {module: name for name, module in self.named_modules()}
        linear_info = {}
        modules = [self]
        while len(modules) > 0:
            submodule = modules.pop()
            for name, raw_linear in submodule.named_children():
                if isinstance(raw_linear, nn.Linear):
                    full_name = full_name_dict[raw_linear]
                    linear_info[raw_linear] = {
                        "father": submodule,
                        "name": name,
                        "full_name": full_name,
                    }
                else:
                    modules.append(raw_linear)

        # Replace each projection listed in head_wise_ranks with a head-wise
        # low-rank module, forwarding the original layer's bias flag.
        for name, module in self.named_modules():
            if name in self.head_wise_ranks:
                info = linear_info[module]
                new_layer = HeadwiseLowRankModule(
                    self.head_wise_ranks[name],
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                )
                setattr(info["father"], info["name"], new_layer)

    @staticmethod
    def get_kv_info(qwen2: Qwen2ForCausalLM, num_heads_in_lr_groups: int):
        num_lr_groups = qwen2.config.num_attention_heads // num_heads_in_lr_groups
        num_lr_kv_groups = qwen2.config.num_key_value_heads // num_heads_in_lr_groups
        head_dim = qwen2.config.hidden_size // qwen2.config.num_attention_heads
        lr_group_dims = head_dim * num_heads_in_lr_groups

        if num_lr_groups * num_heads_in_lr_groups != qwen2.config.num_attention_heads:
            raise ValueError(
                f"num_heads must be divisible by num_heads_in_lr_groups (got `num_heads`: {qwen2.config.num_attention_heads}"
                f" and `num_heads_in_lr_groups`: {num_heads_in_lr_groups})."
            )

        if num_lr_kv_groups * num_heads_in_lr_groups != qwen2.config.num_key_value_heads:
            raise ValueError(
                f"num_key_value_heads must be divisible by num_heads_in_lr_groups (got `num_key_value_heads`: {qwen2.config.num_key_value_heads}"
                f" and `num_heads_in_lr_groups`: {num_heads_in_lr_groups})."
            )

        return SimpleNamespace(
            num_lr_groups=num_lr_kv_groups,
            lr_group_dims=lr_group_dims,
        )
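get_kv_info only reads the model's config, so it can be called on any Qwen2-family model to derive how many low-rank groups the key/value heads split into and the input dimension of each group. A hedged usage sketch with a tiny illustrative config (not a real checkpoint), just to exercise the arithmetic:

from transformers import Qwen2Config, Qwen2ForCausalLM

# Toy config: 8 attention heads, 8 KV heads, hidden_size 256 -> head_dim 32.
toy = Qwen2ForCausalLM(Qwen2Config(hidden_size=256, num_hidden_layers=2,
                                   num_attention_heads=8, num_key_value_heads=8,
                                   intermediate_size=512, vocab_size=1000))

kv_info = PaluQwen2ForCausalLM.get_kv_info(toy, num_heads_in_lr_groups=4)
print(kv_info.num_lr_groups)   # 8 // 4 = 2 key/value groups
print(kv_info.lr_group_dims)   # (256 // 8) * 4 = 128 dims per group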