Merge pull request #8 from shadowpa0327/fix/missing_bias
[Bug Fix] Add missing bias in HeadwiseLowRankModule
shadowpa0327 authored Oct 3, 2024
2 parents 0b9cebd + d468e50 commit 4407fb5
Showing 10 changed files with 264 additions and 11 deletions.
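
The fix targets `HeadwiseLowRankModule`: when the original projection layer carries a bias, that bias has to be re-attached to the output-side factor of the low-rank split, otherwise it is silently dropped. A minimal sketch of the issue, using plain `torch` with hypothetical shapes rather than the repository's API:

```python
import torch
import torch.nn as nn

# Minimal sketch (hypothetical shapes, not the repository's API): a Linear
# layer y = W x + b replaced by a factorization W = L @ R only matches the
# original layer if the bias b is re-attached on the output side.
torch.manual_seed(0)
hidden_size, head_dim = 64, 16
old = nn.Linear(hidden_size, head_dim, bias=True)

# Full-rank SVD so the factorization is exact and only the bias can differ.
U, S, Vt = torch.linalg.svd(old.weight.detach(), full_matrices=False)
L = U * S.sqrt()                  # (head_dim, head_dim)
R = S.sqrt().unsqueeze(-1) * Vt   # (head_dim, hidden_size)

x = torch.randn(hidden_size)
print(torch.allclose(L @ (R @ x) + old.bias.detach(), old(x), atol=1e-4))  # True
print(torch.allclose(L @ (R @ x), old(x), atol=1e-4))                      # False: bias dropped
```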
6 changes: 3 additions & 3 deletions README.md
@@ -49,7 +49,7 @@ pip install -e 3rdparty/fast-hadamard-transform
We provide a script `compress.py` that performs rank search and low-rank decomposition to generate the low-rank projection matrices for compressing the KV-Cache. Here, we perform the decomposition with the proposed `G-LRD` method with a group size of 4 as an example.
```bash
python compress.py \
--model_id="meta-llama/Llama-2-7b-hf" \
--model_id=/Path/To/Pretrained/Model \
--calib_dataset wikitext2 \
--param_ratio_target 0.7 \
--search_method fisher_uniform \
@@ -58,7 +58,7 @@ python compress.py \
--use_cache
```

After executing the above command, a compressed model with decomposed low-rank projection matrices will be dumped into the `Llama-2-7b-hf_ratio-0.5_gs-4-fisher_uniform` directory. The dumped model is stored in the Hugging Face Transformers format.
After executing the above command, a compressed model with decomposed low-rank projection matrices will be dumped into the `{MODEL_NAME}-ratio-{TARGET_RATIO}_gs-{GROUP_SIZE}-{SEARCH_METHOD}-{DECOMPOSE_METHODS}` directory. The dumped model is stored in the Hugging Face Transformers format.

### Evaluation
With the compressed model dumped, we can evaluate its performance on various tasks. We provide scripts for perplexity evaluation, zero-shot evaluation, and LongBench. By default, we keep the compressed KV-Cache in fp16.
@@ -93,7 +93,7 @@ Before we start, please make sure the `lm-eval==0.4.2` library is installed.

To reproduce the results in our paper, simply execute:
```bash
CUDA_VISIBLE_DEVICES=0 python run_lm_eval.py --model_name_or_path "./Meta-Llama-3-8b-Instruct_ratio-0.7_gs-4-fisher_uniform" \
CUDA_VISIBLE_DEVICES=0 python run_lm_eval.py --model_name_or_path /Path/To/Palu/Model \
--tasks "openbookqa,hellaswag,piqa,arc_easy,arc_challenge,winogrande"
```

18 changes: 13 additions & 5 deletions compress.py
@@ -5,8 +5,8 @@
from utils import set_seed, dump_to_huggingface_repos, load_model_and_tokenizer
from palu.rank_search import rank_search
from tqdm import tqdm
from palu.decomposition import compress_model_whiten

from palu.decomposition import compress_model
from run_lm_eval import run_lm_eval_zero_shot
import os

def compress(args):
@@ -19,10 +19,10 @@ def compress(args):
# Step 1: Perform rank selection to get layer-wise compression rate
search_results, rank_sum, total_rank = rank_search(model, tokenizer, args)
# Step 2: Compress models
compress_model_whiten(model, tokenizer, args, torch.device("cuda"), search_results)
compress_model(model, tokenizer, args, args.device, search_results)

if args.dump_huggingface_model:
save_folder = f"{args.model_id.split('/')[-1]}_ratio-{args.param_ratio_target}_gs-{args.head_group_size}-{args.search_method}"
save_folder = f"{args.model_id.split('/')[-1]}_ratio-{args.param_ratio_target}_gs-{args.head_group_size}-{args.search_method}-{args.decompose_method}"
dump_to_huggingface_repos(model, tokenizer, save_folder, args)
logger.info(f"Huggingface model is saved to {save_folder}", fg="green")

@@ -114,11 +114,19 @@ def compress(args):
parser.add_argument(
"--search_method",
type=str,
default="STRS",
default="fisher_uniform",
choices=["fisher", "fisher_uniform", "uniform"],
help="Search method",
)

parser.add_argument(
'--decompose_method',
type=str,
default='whiten',
choices=['whiten', 'svd'],
help='Decomposition method'
)

args = parser.parse_args()

logger.remove()
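
With the new `--decompose_method` flag, the dump directory name also records the decomposition method. A small, illustrative sketch of how the `save_folder` f-string above resolves (the argument values below are hypothetical examples):

```python
# Illustrative values only -- substitute your own arguments.
model_id = "meta-llama/Llama-2-7b-hf"
param_ratio_target = 0.7
head_group_size = 4
search_method = "fisher_uniform"
decompose_method = "whiten"

save_folder = (
    f"{model_id.split('/')[-1]}_ratio-{param_ratio_target}"
    f"_gs-{head_group_size}-{search_method}-{decompose_method}"
)
print(save_folder)  # Llama-2-7b-hf_ratio-0.7_gs-4-fisher_uniform-whiten
```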
45 changes: 44 additions & 1 deletion palu/decomposition.py
@@ -191,6 +191,7 @@ def hook(module, input, output):
logger.info(f"Save the whiten scale matrix dict to: {cache_file}")

def compress_model_whiten(model, tokenizer, args, dev, selection_result):
logger.info("Compressing model with whiten decomposition...")
# NOTE(brian1009): Prepare whiten scaling matrix
get_whiten_scale_matrix(model, tokenizer, args, dev)
# Compress the model
@@ -222,4 +223,46 @@ def compress_model_whiten(model, tokenizer, args, dev, selection_result):
raw_linear,
selected_head_rank
)
setattr(info["father"], info["name"], head_wise_svd_linear)
setattr(info["father"], info["name"], head_wise_svd_linear)

def compress_model_svd(model, selection_result):
logger.info("Compressing model with svd decomposition...")
# Compress the model
module_dict = {name: module for name, module in model.named_modules()}
full_name_dict = {module: name for name, module in model.named_modules()}
linear_info = {}
modules = [model]
while len(modules) > 0:
submodule = modules.pop()
for name, raw_linear in submodule.named_children():
if isinstance(raw_linear, nn.Linear):
full_name = full_name_dict[raw_linear]
linear_info[raw_linear] = {
"father": submodule,
"name": name,
"full_name": full_name,
}
else:
modules.append(raw_linear)

logger.info(f"Start decompose the layer with selected ranks... #target layers: {len(selection_result.keys())}")
for layername, selected_head_rank in tqdm(selection_result.items()):
logger.debug(f"Decompose {layername} with ranks: {selected_head_rank}")
# set ratio
raw_linear = module_dict[layername]
info = linear_info[raw_linear]
print("head-wise svd", layername, raw_linear)
head_wise_svd_linear = HeadwiseLowRankModule.from_linear(
raw_linear,
selected_head_rank
)
setattr(info["father"], info["name"], head_wise_svd_linear)

# Wrapper for different decompose methods
def compress_model(model, tokenizer, args, dev, selection_result):
if args.decompose_method == "whiten":
compress_model_whiten(model, tokenizer, args, dev, selection_result)
elif args.decompose_method == "svd":
compress_model_svd(model, selection_result)
else:
raise ValueError(f"Decomposition method {args.decompose_method} is not supported.")
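
`compress_model_svd` relies on the same traversal bookkeeping as the whiten path: every `nn.Linear` is recorded together with its parent module so it can later be swapped out in place. A toy sketch of that pattern (illustrative names, not the repository's code):

```python
import torch.nn as nn

# Collect every nn.Linear with its parent ("father") so the layer can be
# replaced in place via setattr(parent, child_name, new_module).
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

linear_info, modules = {}, [model]
while modules:
    parent = modules.pop()
    for name, child in parent.named_children():
        if isinstance(child, nn.Linear):
            linear_info[child] = {"father": parent, "name": name}
        else:
            modules.append(child)

# Swap each recorded Linear (here with an Identity, just to show the mechanism).
for info in linear_info.values():
    setattr(info["father"], info["name"], nn.Identity())

print(model)  # both Linear layers are now Identity modules
```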
11 changes: 11 additions & 0 deletions palu/model/__init__.py
@@ -9,6 +9,13 @@
PaluMistralConfig,
PaluMistralForCausalLM
)

#qwen
from .svd_qwen import (
PaluQwen2Config,
PaluQwen2ForCausalLM
)

#modules
from .modules import (
HeadwiseLowRankModule
@@ -26,5 +33,9 @@
'mistral': {
'config': PaluMistralConfig,
'ModelForCausalLM': PaluMistralForCausalLM
},
'qwen2': {
'config': PaluQwen2Config,
'ModelForCausalLM': PaluQwen2ForCausalLM
}
}
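
For reference, a hedged sketch of how the extended registry could be queried; this assumes `AVAILABLE_MODELS` is the dictionary shown above and is importable from `palu.model`:

```python
from palu.model import AVAILABLE_MODELS

# Resolve the Palu classes for a given HF model_type string.
entry = AVAILABLE_MODELS["qwen2"]
config_cls = entry["config"]             # PaluQwen2Config
model_cls = entry["ModelForCausalLM"]    # PaluQwen2ForCausalLM
print(config_cls.__name__, model_cls.__name__)
```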
57 changes: 56 additions & 1 deletion palu/model/modules/svd_linear.py
@@ -33,6 +33,23 @@ def _per_head_whiten_decomposition_from_weight(weight, scaling_diag_matrix, rank

return L, R

def _per_head_decomposition_from_weight(weight, rank):
original_dtype = weight.dtype
# Get weight matrix decomposed
U, S, Vt = torch.linalg.svd(weight.to(torch.float32), full_matrices=False)

# Low rank approximation to the target rank
U = U[:, :rank]
S = S[:rank]
Vt = Vt[:rank, :]

sqrtSigma = torch.sqrt(torch.diag(S))
# Fuse the SVD components
L = torch.matmul(U, sqrtSigma).to(original_dtype)
R = torch.matmul(sqrtSigma, Vt).to(original_dtype)
assert torch.allclose(torch.matmul(L, R), weight, atol=1e-3), "SVD decomposition failed"
return L, R

class HeadwiseLowRankModule(nn.Module):
""" Headwise low rank module """

@@ -157,7 +174,10 @@ def from_linear_whiten(
):
new_module = HeadwiseLowRankModule(ranks, old_module.in_features, old_module.out_features, bias=old_module.bias is not None)
w = old_module.weight.data.reshape(len(ranks), -1, old_module.in_features)

# Handle the cases where the bias is not None
if old_module.bias is not None:
b = old_module.bias.data.reshape(len(ranks), -1)

wl = []
wr = []
for i in range(len(ranks)):
@@ -171,11 +191,46 @@
if new_module.U[i].weight.data.shape != wl[i].shape:
raise ValueError(f"{new_module.U[i].weight.data.shape} != {wl[i].shape}")
new_module.U[i].weight.data = wl[i].contiguous()
# Handle the cases where the bias is not None
if old_module.bias is not None:
new_module.U[i].bias.data = b[i]

# load to VT
# shape (sum(ranks), hidden_size)
VT_weight = torch.cat(wr, dim=0).contiguous()
assert new_module.VT.weight.data.shape == VT_weight.shape
new_module.VT.weight.data = VT_weight

return new_module

@staticmethod
def from_linear(
old_module: nn.Linear,
ranks: list,
):
new_module = HeadwiseLowRankModule(ranks, old_module.in_features, old_module.out_features, bias=old_module.bias is not None)
w = old_module.weight.data.reshape(len(ranks), -1, old_module.in_features)
if old_module.bias is not None:
b = old_module.bias.data.reshape(len(ranks), -1)
wl = []
wr = []
for i in range(len(ranks)):
l, r = _per_head_decomposition_from_weight(w[i], ranks[i])
# l: (head_dim, rank), r: (rank, hidden_size)
wl.append(l)
wr.append(r)

# load to U
for i in range(len(ranks)):
if new_module.U[i].weight.data.shape != wl[i].shape:
raise ValueError(f"{new_module.U[i].weight.data.shape} != {wl[i].shape}")
new_module.U[i].weight.data = wl[i].contiguous()
if old_module.bias is not None:
new_module.U[i].bias.data = b[i]
# load to VT
# shape (sum(ranks), hidden_size)
VT_weight = torch.cat(wr, dim=0).contiguous()
assert new_module.VT.weight.data.shape == VT_weight.shape
new_module.VT.weight.data = VT_weight

return new_module
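
The new plain-SVD helper fuses the singular values into both factors through `sqrt(S)`. A standalone sketch (hypothetical sizes, not the repository's helper) of what the truncated split yields: `L @ R` is the best rank-`r` approximation, and its Frobenius error equals the l2 norm of the discarded singular values:

```python
import torch

torch.manual_seed(0)
head_dim, hidden_size, rank = 128, 4096, 64

W = torch.randn(head_dim, hidden_size, dtype=torch.float32)
U, S, Vt = torch.linalg.svd(W, full_matrices=False)

sqrt_S = torch.sqrt(S[:rank])
L = U[:, :rank] * sqrt_S              # (head_dim, rank)
R = sqrt_S.unsqueeze(-1) * Vt[:rank]  # (rank, hidden_size)

err = torch.linalg.norm(W - L @ R)      # Frobenius norm of the residual
expected = torch.linalg.norm(S[rank:])  # l2 norm of the dropped singular values
print(torch.allclose(err, expected, rtol=1e-3))  # True (Eckart-Young)
```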
8 changes: 8 additions & 0 deletions palu/model/svd_qwen/__init__.py
@@ -0,0 +1,8 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Qwen2Tokenizer
from .configuration_palu_qwen import PaluQwen2Config
from .modeling_palu_qwen import PaluQwen2ForCausalLM

AutoConfig.register("paluqwen2", PaluQwen2Config)
AutoModelForCausalLM.register(PaluQwen2Config, PaluQwen2ForCausalLM)
AutoTokenizer.register(PaluQwen2Config, Qwen2Tokenizer)
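
Since the new config, model, and tokenizer classes are registered with the Auto classes at import time, a dumped Palu-Qwen2 checkpoint whose config declares `model_type: paluqwen2` can be loaded through the standard Auto API. A hedged sketch, with a placeholder path:

```python
import palu.model  # noqa: F401  -- importing the package runs the register() calls above
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "/Path/To/Palu/Qwen2/Model"  # placeholder
model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)
```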

61 changes: 61 additions & 0 deletions palu/model/svd_qwen/configuration_palu_qwen.py
@@ -0,0 +1,61 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class PaluQwen2Config(PretrainedConfig):
model_type = "paluqwen2"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
# [Palu]
head_wise_ranks=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window if use_sliding_window else None
self.max_window_layers = max_window_layers

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

# for avsd
self.head_wise_ranks = head_wise_ranks
64 changes: 64 additions & 0 deletions palu/model/svd_qwen/modeling_palu_qwen.py
@@ -0,0 +1,64 @@
from transformers import Qwen2ForCausalLM
import torch.nn as nn
from types import SimpleNamespace
from .configuration_palu_qwen import PaluQwen2Config
from ..modules.svd_linear import HeadwiseLowRankModule

class PaluQwen2ForCausalLM(Qwen2ForCausalLM):
config_class = PaluQwen2Config
def __init__(self, config: PaluQwen2Config):
super().__init__(config)
self.head_wise_ranks=config.head_wise_ranks

full_name_dict = {module: name for name, module in self.named_modules()}
linear_info = {}
modules = [self]
while len(modules) > 0:
submodule = modules.pop()
for name, raw_linear in submodule.named_children():
if isinstance(raw_linear, nn.Linear):
full_name = full_name_dict[raw_linear]
linear_info[raw_linear] = {
"father": submodule,
"name": name,
"full_name": full_name,
}
else:
modules.append(raw_linear)


for name, module in self.named_modules():
if name in self.head_wise_ranks:
info = linear_info[module]
new_layer = HeadwiseLowRankModule(
self.head_wise_ranks[name],
module.in_features,
module.out_features,
bias=module.bias is not None
)
setattr(info["father"], info["name"], new_layer)


@staticmethod
def get_kv_info(qwen2: Qwen2ForCausalLM, num_heads_in_lr_groups: int):
num_lr_groups = qwen2.config.num_attention_heads // num_heads_in_lr_groups
num_lr_kv_groups = qwen2.config.num_key_value_heads // num_heads_in_lr_groups
head_dim = qwen2.config.hidden_size // qwen2.config.num_attention_heads
lr_group_dims = head_dim * num_heads_in_lr_groups

if num_lr_groups * num_heads_in_lr_groups != qwen2.config.num_attention_heads:
raise ValueError(
f"num_heads must be divisible by num_heads_in_lr_groups (got `num_heads`: {qwen2.config.num_attention_heads}"
f" and `num_heads_in_lr_groups`: {num_heads_in_lr_groups})."
)

if num_lr_kv_groups * num_heads_in_lr_groups != qwen2.config.num_key_value_heads:
raise ValueError(
f"num_key_value_heads must be divisible by num_heads_in_lr_groups (got `num_key_value_heads`: {qwen2.config.num_key_value_heads}"
f" and `num_heads_in_lr_groups`: {num_heads_in_lr_groups})."
)

return SimpleNamespace(
num_lr_groups=num_lr_kv_groups,
lr_group_dims=lr_group_dims,
)
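
A worked example of `get_kv_info`'s arithmetic, assuming Qwen2-7B-like shapes (hidden_size 3584, 28 attention heads, 4 KV heads; these numbers are illustrative) and a low-rank group size of 4:

```python
from types import SimpleNamespace

# Illustrative config values -- substitute the actual model config.
hidden_size, num_attention_heads, num_key_value_heads = 3584, 28, 4
num_heads_in_lr_groups = 4

head_dim = hidden_size // num_attention_heads                     # 128
num_lr_groups = num_attention_heads // num_heads_in_lr_groups     # 7 (divisibility check passes)
num_lr_kv_groups = num_key_value_heads // num_heads_in_lr_groups  # 1 (divisibility check passes)
lr_group_dims = head_dim * num_heads_in_lr_groups                 # 512

# Mirrors the return value: the KV group count is what gets reported.
info = SimpleNamespace(num_lr_groups=num_lr_kv_groups, lr_group_dims=lr_group_dims)
print(info)  # namespace(num_lr_groups=1, lr_group_dims=512)
```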
2 changes: 1 addition & 1 deletion palu/rank_search.py
@@ -89,7 +89,7 @@ def rank_search(model: nn.Module, tokenizer, args):
target_model_class = AVAILABLE_MODELS[model.config.model_type]["ModelForCausalLM"]
total_rank = 0
select_result = {}
info = target_model_class.get_info(model, args.head_group_size)
info = target_model_class.get_kv_info(model, args.head_group_size)

for name, module in model.named_modules():
if "k_proj" in name or "v_proj" in name:
3 changes: 3 additions & 0 deletions utils.py
@@ -64,6 +64,9 @@ def dump_to_huggingface_repos(model, tokenizer, save_path, args):
elif "mistral" in model.config._name_or_path.lower():
config["model_type"] = "palumistral"
config['architectures'] = ['PaluMistralForCausalLM']
elif "qwen2" in model.config._name_or_path.lower():
config["model_type"] = "paluqwen2"
config['architectures'] = ['PaluQwen2ForCausalLM']
else:
raise NotImplementedError

