This repository contains the implementation of the paper Sparse is Enough in Fine-tuning Pre-trained Large Language Models, together with an introduction to the general usage of SIFT for different needs.
Sparse is Enough in Fine-tuning Pre-trained Large Language Models
Weixi Song*, Zuchao Li*, Lefei Zhang, Hai Zhao, Bo Du
Paper: https://arxiv.org/abs/2312.11875
In this work, we present a component-sparse and memory-efficient updating scheme (SIFT). Inspired by the memory-efficient SGD implementation in LOMO, we implement the sparse updates by injecting hooks into the backward propagation. See our paper for more details. The main code of SIFT is in sift.py.
With this method, for x% sparse updates, the memory consumption of gradients and optimizer states is reduced to x% of the original. Combined with techniques such as mixed-precision training and gradient checkpointing, it becomes possible to fine-tune a 7B model on a single RTX 3090 24GB.
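As a rough sketch of the mechanism (illustrative only, not the actual sift.py code; the weight size, the random index choice and the variable names below are assumptions), a backward hook can copy just the selected gradient components into a small buffer, so the optimizer only ever sees the sparse part:

import torch

## a dense weight whose updates are restricted to k components (sketch only)
p = torch.nn.Parameter(torch.randn(1024, 1024))
k = int(0.01 * p.numel())                      ## e.g. 1% of the components
idx = torch.randperm(p.numel())[:k]            ## placeholder selection for this sketch
sparse_grad = torch.zeros(k)                   ## what the optimizer would actually see

def hook(grad):
    ## keep only the selected components of the dense gradient
    sparse_grad.copy_(grad.reshape(-1)[idx])
    return grad

p.register_hook(hook)

loss = (p * p).sum()
loss.backward()                                ## the hook fires here and fills sparse_grad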
We provide several use cases in Natural Language Processing, and SIFT can be applied to other areas in the same way. See exp for the experiments on the GLUE benchmark and the instruction-tuning task. The experiments are built on the original repositories of Transformers, Alpaca and MMLU. The HumanEval evaluation is conducted with code-eval. Thanks to these great works.
git clone [email protected]:song-wx/SIFT.git
cd SIFT
pip install .
Please solve the dependency issues as needed.
Note: The current implementation only considers training on a single GPU. If you are interested in multi-GPU training, please modify the code to fit your needs.
Step 1: After initializing your model, run the following code to specify the parameters to be updated sparsely. Set sparse_module and sparse_rate to customize the sparse training; you can also specify modules to be updated normally (i.e. densely) by setting exception, as shown after the snippet below.
## initialize your model
model = ...
## initialize SIFT
from sift import SIFT
sift = SIFT(model, sparse_rate=0.01,
            sparse_module=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
            grad_acc=gradient_accumulation,
            gradient_checkpointing=gradient_checkpointing)
## you can print the actual trainable numbers in SIFT
sift.print_trainable_parameters()
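For example, if certain modules should remain densely trainable while the listed projections are updated sparsely, pass their names via exception (the module name below is a placeholder; use the names from your own model):

## example: keep a head densely trainable via exception (name is a placeholder)
sift = SIFT(model, sparse_rate=0.01,
            sparse_module=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
            exception=['lm_head'],
            grad_acc=gradient_accumulation,
            gradient_checkpointing=gradient_checkpointing)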
Step 2: Initialize the optimizer with the actual trainable parameters in SIFT, obtained from sift.named_parameters_in_optimizer().
## example
import torch

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in sift.named_parameters_in_optimizer() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in sift.named_parameters_in_optimizer() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)
Step 3: Run the training loop as usual with model and optimizer.
## if use Trainer
trainer = Trainer(model=model, optimizers=(optimizer, None), ...)
trainer.train()
## if using a bare training loop, everything is the same as the normal training process.
for i, batch in enumerate(dataloader):
    output = model(**batch)
    loss = ...
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
SIFT essentially creates an additional sparse parameter sparse_param for each parameter p that needs to be sparsely updated, represented by its indices sparse_param.idx and its values sparse_param.data. After initializing SIFT, you can get the sparse parameter sparse_param of a target parameter p with name n through the dict sift.sparse_mapping[n].
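For instance, assuming the model and sift objects from the snippets above, you can list each sparsely updated parameter together with the shapes of its indices and values:

## inspect the sparse parameters created by SIFT
for n, p in model.named_parameters():
    if n in sift.sparse_mapping:
        sparse_param = sift.sparse_mapping[n]
        print(n, sparse_param.idx.shape, sparse_param.data.shape)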
In our paper, we propose a gradient-based selection method, motivated by our finding that the gradients of the pre-trained model follow a quasi-sparse distribution. We determine the indices as the components whose absolute gradients over the first few batches fall in the top x%.
sparse_idx = torch.flatten(abs(grad)).topk(sparse_param.train_num).indices.cpu().numpy()
sparse_param.idx = np.stack(np.unravel_index(sparse_idx, p.shape))
We compare the efficiency of this gradient-based method with LoRA and random selection under different budgets of trainable parameters. You can modify the above code in sift.py to customize the index selection.
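For example, a purely random selection (one of the baselines above) can be dropped in while keeping the same idx format; this is a minimal sketch that reuses the variables p and sparse_param from the snippet above:

## example customization: random index selection instead of the gradient-based one
import numpy as np
import torch

sparse_idx = torch.randperm(p.numel())[:sparse_param.train_num].numpy()
sparse_param.idx = np.stack(np.unravel_index(sparse_idx, p.shape))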
Because SIFT merges sparse_param into the original p inside the hook to ensure correct forward propagation (as in the following code), the final updated parameters are the original parameters p. If you want to store them in a memory-efficient way, you can save only the components of p selected by sparse_param.idx; otherwise the complete p is saved.
## update the initial param sparsely
delta = p.data + torch.sparse_coo_tensor(sparse_param.idx, sparse_param, p.shape).to(p)
p.data.copy_(delta)
sparse_param.zero_()
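If you prefer the memory-efficient option, a possible sketch is to save only the updated components together with their indices and write them back into a freshly loaded base model (the file name and the dict layout are assumptions for the example; by default the complete p is saved):

## sketch: save only the updated components of each sparsely trained p
import torch

sparse_state = {}
for n, p in model.named_parameters():
    if n in sift.sparse_mapping:
        idx = torch.as_tensor(sift.sparse_mapping[n].idx)    ## shape (ndim, k)
        sparse_state[n] = {'idx': idx, 'values': p.data[tuple(idx)].cpu()}
torch.save(sparse_state, 'sift_sparse_components.pt')        ## placeholder file name

## later: reload the base model and write the trained components back
sparse_state = torch.load('sift_sparse_components.pt')
for n, p in model.named_parameters():
    if n in sparse_state:
        idx = sparse_state[n]['idx']
        p.data[tuple(idx)] = sparse_state[n]['values'].to(p)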
@inproceedings{
song2024sparse,
title={Sparse is Enough in Fine-tuning Pre-trained Large Language Models},
author={Weixi Song and Zuchao Li and Lefei Zhang and Hai Zhao and Bo Du},
booktitle={Forty-first International Conference on Machine Learning},
year={2024},
url={https://openreview.net/forum?id=10hu2D3hAg}
}