SplitLLM is a Split Learning simulation framework designed for Large Language Models (LLMs). It enables flexible and scalable model fine-tuning under a split learning architecture. The framework is compatible with Hugging Face models.
SplitLLM supports extensible integration of privacy attack experiments, including mainstream data reconstruction attacks (DRAs) such as DLG, TAG, and LAMP. The proposed Bidirectional Semi-white-box Reconstruction (BiSR) attack is also demonstrated in the example.
conda create -n sfl python=3.11
conda activate sfl
conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
- Go to `sfl/config` and set `dataset_cache_dir`, `model_download_dir`, and `model_cache_dir` to your own paths.
- Run the following commands to download models. Some of them may require a Hugging Face access token with the appropriate permissions.
cd experiments/script
python model_download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python model_download.py --repo_id gpt2-large
python model_download.py --repo_id THUDM/chatglm3-6b
...
#python model_download.py --repo_id FacebookAI/roberta-large
#python model_download.py --repo_id google-bert/bert-large-uncased
#python model_download.py --repo_id google/flan-t5-base
#python model_download.py --repo_id google/flan-ul2-base
#python model_download.py --repo_id meta-llama/Meta-Llama-3-8B
#python model_download.py --repo_id lucyknada/microsoft_WizardLM-2-7B
#python model_download.py --repo_id lmsys/vicuna-7b-v1.5
#python model_download.py --repo_id tiiuae/falcon-7b-instruct
#python model_download.py --repo_id Salesforce/codegen25-7b-instruct_P
#python model_download.py --repo_id EleutherAI/gpt-j-6b
#python model_download.py --repo_id google/flan-ul2
#python model_download.py --repo_id google/vit-large-patch16-224
#python model_download.py --repo_id bigscience/bloomz-560m
#python model_download.py --repo_id state-spaces/mamba-1.4b-hf
- (Optional) Install `causal-conv1d` and `mamba-ssm` to use the parallelized implementation of Mamba:
pip install "causal-conv1d>=1.2.0"
pip install mamba-ssm
Note that without these packages, the default sequential implementation of Mamba is used.
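
To see which Mamba path to expect, the following is a minimal sketch that checks whether the two optional packages are importable. It assumes their import names are `mamba_ssm` and `causal_conv1d`; the helper `missing_fast_path_modules` is hypothetical and not part of SplitLLM.

```python
import importlib.util

# Assumed import names of the optional packages installed above.
FAST_PATH_MODULES = ("mamba_ssm", "causal_conv1d")


def missing_fast_path_modules(modules=FAST_PATH_MODULES):
    """Return the names of top-level modules that cannot be found in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = missing_fast_path_modules()
if missing:
    print("Sequential Mamba implementation expected; missing:", ", ".join(missing))
else:
    print("Both optional packages found; the parallelized implementation should be usable.")
```

The check only confirms that the packages are installed, not that their CUDA kernels work on your GPU.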
cd $dataset_cache_dir
git clone https://huggingface.co/datasets/wikitext.git
git clone https://huggingface.co/datasets/piqa.git
git clone https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K.git
git clone https://huggingface.co/datasets/knkarthick/dialogsum.git
git clone https://huggingface.co/datasets/gsm8k.git
git clone https://huggingface.co/datasets/imdb.git
git clone https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese.git
git clone https://huggingface.co/datasets/frgfm/imagewoof.git
git clone https://huggingface.co/datasets/SetFit/qnli.git
git clone https://huggingface.co/datasets/linxinyuan/cola.git
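
After cloning, it can help to confirm that every dataset directory is present before starting an experiment. The sketch below assumes each `git clone` created a directory named after the repository (for example `wikitext`); the helper `missing_datasets` is hypothetical and not part of SplitLLM.

```python
import tempfile
from pathlib import Path

# Directory names that the clone commands above are assumed to create.
EXPECTED_DATASETS = (
    "wikitext", "piqa", "CodeAlpaca_20K", "dialogsum", "gsm8k",
    "imdb", "HC3-Chinese", "imagewoof", "qnli", "cola",
)


def missing_datasets(cache_dir, expected=EXPECTED_DATASETS):
    """Return the expected dataset directories that do not exist under cache_dir."""
    root = Path(cache_dir)
    return [name for name in expected if not (root / name).is_dir()]


# Small self-contained demo on a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "wikitext").mkdir()
    demo_missing = missing_datasets(tmp, expected=("wikitext", "piqa"))

print(demo_missing)  # ['piqa']
```

In practice you would call `missing_datasets` with the path you configured as `dataset_cache_dir`.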
- Run the demo script `experiments/scripts/pipeline/demo_bisr.sh`. Note that the script requires wandb to be installed and configured.
In split learning (SL), a model is divided into three parts:

| Bottom Layers | Trunk Layers (Adapters) | Top Layers |
|---|---|---|

where the Bottom Layers and Top Layers are the input and output ends of the model, and the Trunk Layers are the middle part.
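
As an illustration of the three-part division, here is a minimal sketch that partitions layer indices by two cut points. The names `SplitPoints` and `partition_layers`, the 12-layer model, and the chosen cut points are assumptions for this example, not SplitLLM's actual API or defaults.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SplitPoints:
    bottom_end: int  # layers [0, bottom_end) form the Bottom Layers
    trunk_end: int   # layers [bottom_end, trunk_end) form the Trunk Layers


def partition_layers(num_layers: int, split: SplitPoints) -> dict[str, list[int]]:
    """Assign each layer index to the bottom, trunk, or top part."""
    if not 0 < split.bottom_end < split.trunk_end < num_layers:
        raise ValueError("need 0 < bottom_end < trunk_end < num_layers")
    indices = list(range(num_layers))
    return {
        "bottom": indices[: split.bottom_end],
        "trunk": indices[split.bottom_end : split.trunk_end],
        "top": indices[split.trunk_end :],
    }


parts = partition_layers(12, SplitPoints(bottom_end=2, trunk_end=10))
print(parts["bottom"], parts["top"])  # [0, 1] [10, 11]
```

The validation forces all three parts to be non-empty, which matches the description above.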
To simulate Split Federated Learning (SFL), the implementation does not physically split the model. Instead, it maintains the parameters of each part of the model independently. Client training is simulated serially, without actual Client parallelism.
The simulation process is as follows:
1. At the start of a round of federated learning, select Clients 0, 1, and 2.
2. Load the model parameters of Client 0 (including bottom and top layers, and its corresponding trunk on the Server) from disk into the GPU model.
3. Client 0 performs local training, updating all model parameters in the same way as centralized learning.
4. Once Client 0 completes training, save the model parameters to disk.
5. Load the model parameters of Client 1 from disk into the GPU model.
6. ...
7. At the end of the federated learning round, aggregate the trunk parameters corresponding to all clients on disk to obtain the average trunk. Then, update the trunk parameters of all clients to the average trunk.
8. Repeat steps 1-7.
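
One round of the process above can be sketched in plain Python. This is a toy illustration, not SplitLLM's implementation: parameters are single floats stored as JSON files, `local_train` is a stand-in for real training, and all function names are hypothetical.

```python
import json
import tempfile
from pathlib import Path

# Per-client parameters: {"bottom": {...}, "trunk": {...}, "top": {...}}
Params = dict[str, dict[str, float]]


def save_params(root: Path, client_id: int, params: Params) -> None:
    (root / f"client_{client_id}.json").write_text(json.dumps(params))


def load_params(root: Path, client_id: int) -> Params:
    return json.loads((root / f"client_{client_id}.json").read_text())


def local_train(params: Params, client_id: int) -> Params:
    """Toy stand-in for local training: shift every parameter by a client-specific step."""
    step = 0.1 * (client_id + 1)
    return {part: {k: v - step for k, v in weights.items()} for part, weights in params.items()}


def average_trunks(trunks: list[dict[str, float]]) -> dict[str, float]:
    """Element-wise mean of the trunk parameters of all selected clients."""
    return {k: sum(t[k] for t in trunks) / len(trunks) for k in trunks[0]}


def run_round(root: Path, client_ids: list[int]) -> dict[str, float]:
    # Serial simulation: load one client, train it, save it, then move to the next.
    for cid in client_ids:
        save_params(root, cid, local_train(load_params(root, cid), cid))
    # Aggregate the trunks stored on disk and write the average back to every client.
    avg = average_trunks([load_params(root, cid)["trunk"] for cid in client_ids])
    for cid in client_ids:
        params = load_params(root, cid)
        params["trunk"] = avg
        save_params(root, cid, params)
    return avg


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    clients = [0, 1, 2]
    for cid in clients:
        save_params(root, cid, {"bottom": {"w": 1.0}, "trunk": {"w": 1.0}, "top": {"w": 1.0}})
    avg_trunk = run_round(root, clients)
    final = {cid: load_params(root, cid) for cid in clients}

print(round(avg_trunk["w"], 3))
```

After the round, every client shares the same averaged trunk, while the bottom and top parameters stay client-specific. Only one client's parameters need to be in memory at a time, which is the point of training Clients serially.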