diff --git a/docs/source/developer_guides/torch_compile.md b/docs/source/developer_guides/torch_compile.md index 13c81a4f03..8d88f758b9 100644 --- a/docs/source/developer_guides/torch_compile.md +++ b/docs/source/developer_guides/torch_compile.md @@ -43,6 +43,7 @@ The following adapters were tested successfully: - LoRA + DoRA - OFT - VeRA +- HRA The following adapters **don't work** correctly for training or inference when using `torch.compile`: diff --git a/examples/hra_dreambooth/README.md b/examples/hra_dreambooth/README.md new file mode 100644 index 0000000000..11530edc3b --- /dev/null +++ b/examples/hra_dreambooth/README.md @@ -0,0 +1,98 @@ + + +# DreamBooth fine-tuning with HRA + +This guide demonstrates how to use the Householder reflection adaptation (HRA) method to fine-tune Dreambooth with the `stabilityai/stable-diffusion-2-1` model. + +HRA provides a new perspective connecting LoRA to OFT and achieves encouraging performance in various downstream tasks. +HRA adapts a pre-trained model by multiplying each frozen weight matrix with a chain of r learnable Householder reflections (HRs). +HRA can be interpreted as either an OFT adapter or an adaptive LoRA. +Consequently, it harnesses the advantages of both strategies, reducing parameters and computation costs while penalizing the loss of pre-training knowledge. +For further details on HRA, please consult the [original HRA paper](https://arxiv.org/abs/2405.17484). + +In this guide we provide a Dreambooth fine-tuning script that is available in [PEFT's GitHub repo examples](https://github.com/huggingface/peft/tree/main/examples/hra_dreambooth). This implementation is adapted from [peft's boft_dreambooth](https://github.com/huggingface/peft/tree/main/examples/boft_dreambooth). + +You can try it out and fine-tune on your custom images. 
+ +## Set up your environment + +Start by cloning the PEFT repository: + +```bash +git clone --recursive https://github.com/huggingface/peft +``` + +Navigate to the directory containing the training scripts for fine-tuning Dreambooth with HRA: + +```bash +cd peft/examples/hra_dreambooth +``` + +Set up your environment: install PEFT, and all the required libraries. At the time of writing this guide we recommend installing PEFT from source. The following environment setup should work on A100 and H100: + +```bash +conda create --name peft python=3.10 +conda activate peft +conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia +conda install xformers -c xformers +pip install -r requirements.txt +pip install git+https://github.com/huggingface/peft +``` + +## Download the data + +The [dreambooth](https://github.com/google/dreambooth) dataset is cloned automatically into the following structure the first time you run the training script. + +``` +hra_dreambooth +├── data +│ └── dreambooth +│ └── dataset +│ ├── backpack +│ └── backpack_dog +│ ... +``` + +You can also put your custom images into `hra_dreambooth/data/dreambooth/dataset`. + +## Fine-tune Dreambooth with HRA + +```bash +class_idx=0 +bash ./train_dreambooth.sh $class_idx +``` + +where `$class_idx` selects one of the 30 subjects, indexed from 0 to 29. + +Launch the training script with `accelerate` and pass hyperparameters, as well as HRA-specific arguments to it such as: + +- `use_hra`: Enables HRA in the training script. +- `hra_r`: the number of HRs (i.e., r) across different layers, expressed in `int`. +As r increases, the number of trainable parameters increases, which generally leads to improved performance. +However, this also results in higher memory consumption and longer computation times. +Therefore, r is usually set to 8. +**Note**, please set r to an even number to avoid potential issues during initialization. 
+- `hra_apply_GS`: Applies Gram-Schmidt orthogonalization. Default is `false`. +- `hra_bias`: specify if the `bias` parameters should be trained. Can be `none`, `all` or `hra_only`. + +If you are running this script on Windows, you may need to set the `--num_dataloader_workers` to 0. + +To learn more about DreamBooth fine-tuning with prior-preserving loss, check out the [Diffusers documentation](https://huggingface.co/docs/diffusers/training/dreambooth#finetuning-with-priorpreserving-loss). + +## Generate images with the fine-tuned model + +To generate images with the fine-tuned model, simply run the jupyter notebook `dreambooth_inference.ipynb` for visualization with `jupyter notebook` under `./examples/hra_dreambooth`. diff --git a/examples/hra_dreambooth/a_purple_qwe_backpack.png b/examples/hra_dreambooth/a_purple_qwe_backpack.png new file mode 100644 index 0000000000..c2784a37c6 Binary files /dev/null and b/examples/hra_dreambooth/a_purple_qwe_backpack.png differ diff --git a/examples/hra_dreambooth/dreambooth_inference.ipynb b/examples/hra_dreambooth/dreambooth_inference.ipynb new file mode 100644 index 0000000000..9cab0d0d24 --- /dev/null +++ b/examples/hra_dreambooth/dreambooth_inference.ipynb @@ -0,0 +1,221 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 19, + "id": "acab479f", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from PIL import Image\n", + "\n", + "import torch\n", + "from accelerate.logging import get_logger\n", + "from diffusers import StableDiffusionPipeline\n", + "from diffusers.utils import check_min_version\n", + "\n", + "from peft import PeftModel\n", + "\n", + "\n", + "# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\n", + "check_min_version(\"0.10.0.dev0\")\n", + "\n", + "logger = get_logger(__name__)\n", + "\n", + "MODEL_NAME = \"stabilityai/stable-diffusion-2-1\"\n", + "\n", + "PEFT_TYPE=\"hra\"\n", + "HRA_R=8\n", + "SELECTED_SUBJECT=\"backpack\"\n", + "EPOCH_IDX = 1000\n", + "\n", + "PROJECT_NAME=f\"dreambooth_{PEFT_TYPE}\"\n", + "RUN_NAME=f\"{SELECTED_SUBJECT}_{PEFT_TYPE}_{HRA_R}\"\n", + "OUTPUT_DIR=f\"./data/output/{PEFT_TYPE}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "06cfd506", + "metadata": {}, + "outputs": [], + "source": [ + "def get_hra_sd_pipeline(\n", + " ckpt_dir, base_model_name_or_path=None, epoch=int, dtype=torch.float32, device=\"cuda\", adapter_name=\"default\"\n", + "):\n", + "\n", + " if base_model_name_or_path is None:\n", + " raise ValueError(\"Please specify the base model name or path\")\n", + "\n", + " pipe = StableDiffusionPipeline.from_pretrained(\n", + " base_model_name_or_path, torch_dtype=dtype, requires_safety_checker=False\n", + " ).to(device)\n", + " \n", + " load_adapter(pipe, ckpt_dir, epoch, adapter_name)\n", + "\n", + " if dtype in (torch.float16, torch.bfloat16):\n", + " pipe.unet.half()\n", + " pipe.text_encoder.half()\n", + "\n", + " pipe.to(device)\n", + " return pipe\n", + "\n", + "\n", + "def load_adapter(pipe, ckpt_dir, epoch, adapter_name=\"default\"):\n", + " \n", + " unet_sub_dir = os.path.join(ckpt_dir, f\"unet/{epoch}\", adapter_name)\n", + " text_encoder_sub_dir = os.path.join(ckpt_dir, f\"text_encoder/{epoch}\", adapter_name)\n", + " \n", + " if isinstance(pipe.unet, PeftModel):\n", + " pipe.unet.load_adapter(unet_sub_dir, adapter_name=adapter_name)\n", + " else:\n", + " pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)\n", + " \n", + " if os.path.exists(text_encoder_sub_dir):\n", + " if isinstance(pipe.text_encoder, PeftModel):\n", + " pipe.text_encoder.load_adapter(text_encoder_sub_dir, adapter_name=adapter_name)\n", + " else:\n", 
+ " pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)\n", + " \n", + "\n", + "def set_adapter(pipe, adapter_name):\n", + " pipe.unet.set_adapter(adapter_name)\n", + " if isinstance(pipe.text_encoder, PeftModel):\n", + " pipe.text_encoder.set_adapter(adapter_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "98a0d8ac", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"a purple qwe backpack.\"\n", + "negative_prompt = \"low quality, blurry, unfinished\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d4e888d2", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 14.47it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.72 s, sys: 495 ms, total: 2.22 s\n", + "Wall time: 2.28 s\n" + ] + } + ], + "source": [ + "%%time\n", + "pipe = get_hra_sd_pipeline(OUTPUT_DIR, MODEL_NAME, EPOCH_IDX, adapter_name=RUN_NAME)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "f1c1a1c0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/50 [00:00" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%time\n", + "image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]\n", + "image" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "60fa38d2", + "metadata": {}, + "outputs": [ + { + "data": { + "image/jpeg": "", + "image/png": "", + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This is an example.\n", + "example_image = Image.open(\"./a_purple_qwe_backpack.png\")\n", + "example_image" + ] + } + ], + "metadata": { + 
"kernelspec": { + "display_name": "llama", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/hra_dreambooth/requirements.txt b/examples/hra_dreambooth/requirements.txt new file mode 100644 index 0000000000..3cc487940f --- /dev/null +++ b/examples/hra_dreambooth/requirements.txt @@ -0,0 +1,13 @@ +transformers==4.36.2 +accelerate==0.25.0 +evaluate +tqdm +datasets==2.16.1 +diffusers==0.17.1 +Pillow +huggingface_hub +safetensors +nb_conda_kernels +ipykernel +ipywidgets +wandb==0.16.1 \ No newline at end of file diff --git a/examples/hra_dreambooth/train_dreambooth.py b/examples/hra_dreambooth/train_dreambooth.py new file mode 100644 index 0000000000..fdf15b9d92 --- /dev/null +++ b/examples/hra_dreambooth/train_dreambooth.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The implementation is based on "Bridging The Gap between Low-rank and Orthogonal +# Adaptation via Householder Reflection Adaptation" (https://arxiv.org/abs/2405.17484). 
+ +import hashlib +import itertools +import logging +import math +import os +from contextlib import nullcontext +from pathlib import Path + +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import Repository +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from utils.args_loader import ( + get_full_repo_name, + import_model_class_from_model_name_or_path, + parse_args, +) +from utils.dataset import DreamBoothDataset, PromptDataset, collate_fn +from utils.tracemalloc import TorchTracemalloc, b2mb + +from peft import HRAConfig, get_peft_model + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.16.0.dev0") + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = ["to_q", "to_v", "to_k", "query", "value", "key", "to_out.0", "add_k_proj", "add_v_proj"] +TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] + + +def save_adaptor(accelerator, step, unet, text_encoder, args): + unwarpped_unet = accelerator.unwrap_model(unet) + unwarpped_unet.save_pretrained( + os.path.join(args.output_dir, f"unet/{step}"), state_dict=accelerator.get_state_dict(unet) + ) + if args.train_text_encoder: + unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) + unwarpped_text_encoder.save_pretrained( + os.path.join(args.output_dir, f"text_encoder/{step}"), + state_dict=accelerator.get_state_dict(text_encoder), + ) + + +def main(args): + validation_prompts = list(filter(None, args.validation_prompt[0].split("."))) + + logging_dir = Path(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to if args.report_to != "none" else None, + project_dir=accelerator_project_config, + ) + if args.report_to == "wandb": + import wandb + + args.wandb_project_name = args.project_name + args.wandb_run_name = args.run_name + wandb_init = { + "wandb": { + "name": args.wandb_run_name, + "mode": "online", + } + } + + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + global_seed = hash(args.run_name) % (2**32) + set_seed(global_seed) + + # Generate class images if prior preservation is enabled. 
+ if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) # noqa: 
F841 + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + if args.use_hra: + config = HRAConfig( + r=args.hra_r, + apply_GS=args.hra_apply_GS, + target_modules=UNET_TARGET_MODULES, + bias=args.hra_bias, + ) + unet = get_peft_model(unet, config, adapter_name=args.run_name) + unet.print_trainable_parameters() + + vae.requires_grad_(False) + unet.train() + + if args.train_text_encoder and args.use_hra: + config = HRAConfig( + r=args.hra_r, + apply_GS=args.hra_apply_GS, + target_modules=UNET_TARGET_MODULES, + bias=args.hra_bias, + ) + text_encoder = get_peft_model(text_encoder, config, adapter_name=args.run_name) + text_encoder.print_trainable_parameters() + text_encoder.train() + else: + 
text_encoder.requires_grad_(False) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + # below fails when using hra so commenting it out + if args.train_text_encoder and not args.use_hra: + text_encoder.gradient_checkpointing_enable() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
+ ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = [param for param in unet.parameters() if param.requires_grad] + + if args.train_text_encoder: + params_to_optimize += [param for param in text_encoder.parameters() if param.requires_grad] + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Download the official dreambooth dataset from the official repository: https://github.com/google/dreambooth.git + data_path = os.path.join(os.getcwd(), "data", "dreambooth") + if not os.path.exists(data_path): + os.makedirs(os.path.join(os.getcwd(), "data"), exist_ok=True) + os.system(f"git clone https://github.com/google/dreambooth.git '{data_path}'") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.num_dataloader_workers, + ) + + # Scheduler and math around the number of training steps. 
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + if args.report_to == "wandb": + accelerator.init_trackers(args.wandb_project_name, config=vars(args), init_kwargs=wandb_init) + else: + accelerator.init_trackers(args.project_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + if args.train_text_encoder: + text_encoder.train() + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + + with TorchTracemalloc() if not args.no_tracemalloc else nullcontext() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if 
args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + if global_step % args.checkpointing_steps == 0 and global_step != 0: + if accelerator.is_main_process: + save_adaptor(accelerator, global_step, unet, text_encoder, args) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if ( + args.validation_prompt is not None + and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0 + and global_step > 10 + ): + unet.eval() + + logger.info( + f"Running validation... \n Generating {len(validation_prompts)} images with prompt:" + f" {validation_prompts[0]}, ......" 
+ ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + ) + # set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + if args.seed is not None: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + + images = [] + val_img_dir = os.path.join( + args.output_dir, + f"validation/{global_step}", + args.run_name, + ) + os.makedirs(val_img_dir, exist_ok=True) + + for val_promot in validation_prompts: + image = pipeline(val_promot, num_inference_steps=50, generator=generator).images[0] + image.save(os.path.join(val_img_dir, f"{'_'.join(val_promot.split(' '))}.png"[1:])) + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + if global_step >= args.max_train_steps: + break + + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + if not args.no_tracemalloc: + accelerator.print(f"GPU Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"GPU Memory consumed at the end of the train (end-begin): 
#!/usr/bin/env bash
# Fine-tune Stable Diffusion 2.1 with DreamBooth + HRA on one subject of the
# DreamBooth dataset.
#
# Usage: bash ./train_dreambooth.sh <class_idx>
#   class_idx  zero-based index into SUBJECT_NAMES / CLASS_TOKENs below.

set -euo pipefail

if [[ $# -lt 1 || ! $1 =~ ^[0-9]+$ ]]; then
    echo "Usage: $0 <class_idx>" >&2
    exit 1
fi
# Force base 10 so an argument such as "08" is not parsed as invalid octal.
CLASS_IDX=$((10#$1))

# Define the UNIQUE_TOKEN, CLASS_TOKENs, and SUBJECT_NAMES
UNIQUE_TOKEN="qwe"

SUBJECT_NAMES=(
    "backpack" "backpack_dog" "bear_plushie" "berry_bowl" "can"
    "candle" "cat" "cat2" "clock" "colorful_sneaker"
    "dog" "dog2" "dog3" "dog5" "dog6"
    "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie"
    "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon"
    "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie"
)

CLASS_TOKENs=(
    "backpack" "backpack" "stuffed animal" "bowl" "can"
    "candle" "cat" "cat" "clock" "sneaker"
    "dog" "dog" "dog" "dog" "dog"
    "dog" "dog" "toy" "boot" "stuffed animal"
    "toy" "glasses" "toy" "toy" "cartoon"
    "toy" "sneaker" "teapot" "vase" "stuffed animal"
)

if (( CLASS_IDX >= ${#SUBJECT_NAMES[@]} )); then
    echo "class_idx must be between 0 and $(( ${#SUBJECT_NAMES[@]} - 1 )), got ${CLASS_IDX}" >&2
    exit 1
fi

CLASS_TOKEN=${CLASS_TOKENs[$CLASS_IDX]}
SELECTED_SUBJECT=${SUBJECT_NAMES[$CLASS_IDX]}

# Scene descriptions shared by every subject.
SCENE_SUFFIXES=(
    "in the jungle"
    "in the snow"
    "on the beach"
    "on a cobblestone street"
    "on top of pink fabric"
    "on top of a wooden floor"
    "with a city in the background"
    "with a mountain in the background"
    "with a blue house in the background"
    "on top of a purple rug in a forest"
)

if [[ $CLASS_IDX =~ ^(0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29)$ ]]; then
    # Indices listed above get background / placement scenes.
    SCENE_SUFFIXES+=(
        "with a wheat field in the background"
        "with a tree and autumn leaves in the background"
        "with the Eiffel Tower in the background"
        "floating on top of water"
        "floating in an ocean of milk"
        "on top of green grass with sunflowers around it"
        "on top of a mirror"
        "on top of the sidewalk in a crowded street"
        "on top of a dirt road"
        "on top of a white rug"
    )
else
    # Remaining indices (the cat / dog subjects) get clothing / outfit scenes.
    SCENE_SUFFIXES+=(
        "wearing a red hat"
        "wearing a santa hat"
        "wearing a rainbow scarf"
        "wearing a black top hat and a monocle"
        "in a chef outfit"
        "in a firefighter outfit"
        "in a police outfit"
        "wearing pink glasses"
        "wearing a yellow shirt"
        "in a purple wizard outfit"
    )
fi

# Property edits appended after the scenes for every subject.
ADJECTIVES=("red" "purple" "shiny" "wet" "cube shaped")

# PROMPT_LIST holds the validation prompts (with the unique token and a
# trailing period); prompt_test_list holds the same prompts without the token.
# NOTE(review): prompt_test_list is not referenced again in this script.
PROMPT_LIST=()
prompt_test_list=()
for suffix in "${SCENE_SUFFIXES[@]}"; do
    PROMPT_LIST+=("a ${UNIQUE_TOKEN} ${CLASS_TOKEN} ${suffix}.")
    prompt_test_list+=("a ${CLASS_TOKEN} ${suffix}")
done
for adjective in "${ADJECTIVES[@]}"; do
    PROMPT_LIST+=("a ${adjective} ${UNIQUE_TOKEN} ${CLASS_TOKEN}.")
    prompt_test_list+=("a ${adjective} ${CLASS_TOKEN}")
done

# All validation prompts are passed as ONE space-joined string.
VALIDATION_PROMPT="${PROMPT_LIST[*]}"
INSTANCE_PROMPT="a photo of ${UNIQUE_TOKEN} ${CLASS_TOKEN}"
CLASS_PROMPT="a photo of ${CLASS_TOKEN}"

export MODEL_NAME="stabilityai/stable-diffusion-2-1"

PEFT_TYPE="hra"
HRA_R=8

export PROJECT_NAME="dreambooth_${PEFT_TYPE}"
export RUN_NAME="${SELECTED_SUBJECT}_${PEFT_TYPE}_${HRA_R}"
export INSTANCE_DIR="./data/dreambooth/dataset/${SELECTED_SUBJECT}"
# CLASS_TOKEN may contain a space ("stuffed animal"), so keep this quoted below.
export CLASS_DIR="./data/class_data/${CLASS_TOKEN}"
export OUTPUT_DIR="./data/output/${PEFT_TYPE}"


accelerate launch train_dreambooth.py \
    --pretrained_model_name_or_path="$MODEL_NAME" \
    --instance_data_dir="$INSTANCE_DIR" \
    --class_data_dir="$CLASS_DIR" \
    --output_dir="$OUTPUT_DIR" \
    --project_name="$PROJECT_NAME" \
    --run_name="$RUN_NAME" \
    --with_prior_preservation \
    --prior_loss_weight=1.0 \
    --instance_prompt="$INSTANCE_PROMPT" \
    --validation_prompt="$VALIDATION_PROMPT" \
    --class_prompt="$CLASS_PROMPT" \
    --resolution=512 \
    --train_batch_size=1 \
    --num_dataloader_workers=2 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --num_class_images=200 \
    --use_hra \
    --hra_r="$HRA_R" \
    --hra_bias="hra_only" \
    --learning_rate=5e-3 \
    --max_train_steps=510 \
    --checkpointing_steps=200 \
    --validation_steps=200 \
    --enable_xformers_memory_efficient_attention \
    --report_to="none"
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named in a checkpoint's ``text_encoder`` config.

    Args:
        pretrained_model_name_or_path: Local path or hub identifier of the pretrained pipeline.
        revision: Revision of the pretrained model to read the config from.

    Returns:
        ``CLIPTextModel`` or ``RobertaSeriesModelWithTransformation``, depending on the first entry of the
        config's ``architectures`` list.

    Raises:
        ValueError: If the architecture is neither of the two supported classes.
    """
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    # The model classes are imported lazily so that only the required one is loaded.
    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    raise ValueError(f"{model_class} is not supported.")


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the ``<namespace>/<model_id>`` repository name.

    Args:
        model_id: Name of the repository without namespace.
        organization: Namespace to use. When ``None``, the user name returned by ``whoami`` is used.
        token: Hub token used for the ``whoami`` lookup. When ``None``, ``HfFolder.get_token()`` is used.

    Returns:
        The namespaced repository name.
    """
    if organization is not None:
        # No user lookup (and therefore no token) is needed when the namespace is explicit.
        return f"{organization}/{model_id}"
    if token is None:
        token = HfFolder.get_token()
    username = whoami(token)["name"]
    return f"{username}/{model_id}"


def parse_args(input_args=None):
    """Parse the command-line arguments of the DreamBooth HRA training script.

    Args:
        input_args: Optional list of argument strings. When ``None``, ``sys.argv[1:]`` is parsed.

    Returns:
        The parsed ``argparse.Namespace``. ``local_rank`` is overridden by the ``LOCAL_RANK`` environment
        variable when that variable is set to a different value.

    Raises:
        ValueError: If ``--with_prior_preservation`` is given without ``--class_data_dir`` or
            ``--class_prompt``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a Dreambooth training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--validation_prompt",
        nargs="+",
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=500,
        help=(
            "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")

    parser.add_argument(
        "--set_grads_to_none",
        action="store_true",
        help=(
            "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
            " behaviors, so disable this argument if it causes any problems. More info:"
            " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
        ),
    )

    # hra args
    parser.add_argument("--use_hra", action="store_true", help="Whether to use HRA for parameter efficient tuning.")
    parser.add_argument("--hra_r", type=int, default=8, help="The rank of HRA across different layers.")
    parser.add_argument(
        "--hra_apply_GS", default=False, action="store_true", help="Whether to apply Gram-Schmidt orthogonalization."
    )
    parser.add_argument(
        "--hra_bias",
        type=str,
        default="none",
        # Reject unsupported values at parse time instead of deep inside model setup.
        choices=["none", "all", "hra_only"],
        help="Bias type for HRA. Can be 'none', 'all' or 'hra_only', only used if use_hra is True.",
    )
    parser.add_argument(
        "--num_dataloader_workers", type=int, default=1, help="Num of workers for the training dataloader."
    )
    parser.add_argument(
        "--no_tracemalloc",
        default=False,
        action="store_true",
        help="Flag to stop memory allocation tracing during training. This could speed up training on Windows.",
    )

    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--project_name",
        type=str,
        default=None,
        help=("The project name for log tracking"),
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help=("The run name for log tracking"),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"wandb"`'
            ' (default), `"tensorboard"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--wandb_key",
        type=str,
        default=None,
        help=("If report to option is set to wandb, api-key for wandb used for login to wandb "),
    )
    parser.add_argument(
        "--wandb_project_name",
        type=str,
        default=None,
        help=("If report to option is set to wandb, project name in wandb for log tracking "),
    )
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default=None,
        help=("If report to option is set to wandb, run name in wandb for log tracking "),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU.  Default to  fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    # argparse falls back to sys.argv[1:] when given None, so no branching is needed.
    args = parser.parse_args(input_args)

    # The LOCAL_RANK environment variable wins over the CLI flag when both are present and differ.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Sanity checks
    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Args:
        instance_data_root: Directory whose entries are the instance images.
        instance_prompt: Prompt tokenized for every instance image.
        tokenizer: Callable tokenizer exposing ``model_max_length`` and returning an object with ``input_ids``.
        class_data_root: Optional directory of class images; it is created when it does not exist.
        class_prompt: Prompt tokenized for every class image.
        size: Side length the images are resized and cropped to.
        center_crop: Use a center crop instead of a random crop.

    Raises:
        ValueError: If ``instance_data_root`` does not exist or is empty, or if ``class_data_root`` is given but
            contains no entries.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(self.instance_data_root.iterdir())
        self.num_instance_images = len(self.instance_images_path)
        if self.num_instance_images == 0:
            # Fail here instead of with a ZeroDivisionError from `index % 0` in __getitem__.
            raise ValueError(f"No instance images found in {self.instance_data_root}.")
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            if self.num_class_images == 0:
                # Same reasoning as above: every __getitem__ call would divide by zero.
                raise ValueError(f"No class images found in {self.class_data_root}.")
            # Iterate over the larger of the two sets; the smaller one wraps around via modulo.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded RGB copy, closing the underlying file handle."""
        with Image.open(path) as image:
            return image.convert("RGB")

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` padded/truncated to the tokenizer's maximum length and return its input ids."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def __getitem__(self, index):
        """Return the transformed instance image and prompt ids, plus the class pair when configured."""
        example = {}
        instance_image = self._load_rgb(self.instance_images_path[index % self.num_instance_images])
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root:
            class_image = self._load_rgb(self.class_images_path[index % self.num_class_images])
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example


def collate_fn(examples, with_prior_preservation=False):
    """Merge dataset examples into one batch of ``input_ids`` and ``pixel_values``.

    Args:
        examples: Items produced by ``DreamBoothDataset.__getitem__``.
        with_prior_preservation: When true, the class examples are appended after the instance examples.

    Returns:
        A dict with ``input_ids`` (prompt ids concatenated along dim 0) and ``pixel_values`` (stacked float images).
    """
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    return {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Every item carries the same prompt; the index lets the consumer tell samples apart.
        return {"prompt": self.prompt, "index": index}
# Converting Bytes to Megabytes
def b2mb(x):
    """Convert a byte count to whole mebibytes (2**20 bytes), truncating toward zero."""
    return int(x / 2**20)


# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    """Context manager that records CUDA and CPU (RSS) memory usage of the enclosed block.

    After the block exits the following attributes are available:
    ``begin``/``end``/``peak`` (CUDA bytes), ``used``/``peaked`` (CUDA MB relative to ``begin``),
    ``cpu_begin``/``cpu_end``/``cpu_peak`` (RSS bytes) and ``cpu_used``/``cpu_peaked`` (RSS MB relative to
    ``cpu_begin``).
    """

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated; this resets the peak gauge to zero.
        torch.cuda.reset_peak_memory_stats()
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        # Create the attribute before the thread starts so __exit__ can never observe it missing.
        self.cpu_peak = -1
        self.peak_monitoring = True
        self._peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self._peak_monitor_thread.daemon = True
        self._peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Poll the RSS in a tight loop and keep the maximum until monitoring is switched off."""
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        # The loop samples before it checks the flag, so after join() cpu_peak holds at least one real
        # measurement and is no longer being written concurrently.
        self._peak_monitor_thread.join()

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
@dataclass
class HRAConfig(PeftConfig):
    """
    This is the configuration class to store the configuration of a [`HRAModel`].

    Args:
        r (`int`):
            The rank of HRA across different layers. It is best to set 'r' to an even number; otherwise, the default
            initialization method will not work.
        apply_GS (`bool`):
            Whether to apply Gram-Schmidt orthogonalization.
        target_modules (`Optional[Union[List[str], str]]`):
            The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
            names will be replaced. When passing a string, a regex match will be performed. When passing a list of
            strings, either an exact match will be performed or it is checked if the name of the module ends with any
            of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, excluding
            the output layer. If this is not specified, modules will be chosen according to the model architecture. If
            the architecture is not known, an error will be raised -- in this case, you should specify the target
            modules manually.
        init_weights (`bool`):
            Whether to perform initialization of HRA weights.
        layers_to_transform (`Union[List[int], int]`):
            The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
            that are specified in this list. If a single integer is passed, it will apply the transformations on the
            layer at this index.
        layers_pattern (`str`):
            The layer pattern name, used only if `layers_to_transform` is different from `None`.
        bias (`str`):
            Bias type for HRA. Can be 'none', 'all' or 'hra_only'.
        modules_to_save (`List[str]`):
            List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
    """

    r: int = field(
        default=8,
        metadata={
            "help": "The rank of HRA across different layers.",
            "note": "It is best to set 'r' to an even number; otherwise, the default initialization method will not work.",
        },
    )
    apply_GS: bool = field(
        default=False,
        metadata={"help": "Whether to apply Gram-Schmidt orthogonalization or not."},
    )
    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with HRA.",
            "example": "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' ",
        },
    )
    init_weights: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to initialize the weights of the HRA layers with their default initialization. Don't change "
                "this setting, except if you know exactly what you're doing."
            ),
        },
    )
    layers_to_transform: Optional[Union[List[int], int]] = field(
        default=None,
        metadata={
            "help": "The layer indexes to transform, if this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
        },
    )
    layers_pattern: Optional[str] = field(
        default=None,
        metadata={
            "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
        },
    )
    bias: str = field(default="none", metadata={"help": "Bias type for HRA. Can be 'none', 'all' or 'hra_only'"})
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of modules apart from HRA layers to be set as trainable and saved in the final checkpoint. "
            "For example, in Sequence Classification or Token Classification tasks, "
            "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
        },
    )

    def __post_init__(self):
        """Set the PEFT type, normalize a list of target modules to a set, and validate option combinations.

        Raises:
            ValueError: If `target_modules` is a string (regex) and `layers_to_transform` or `layers_pattern` is
                also given.
        """
        self.peft_type = PeftType.HRA
        # A list is converted to a set; a string (regex) or None is kept unchanged.
        if isinstance(self.target_modules, list):
            self.target_modules = set(self.target_modules)

        if isinstance(self.target_modules, str):
            # if target_modules is a regex expression, then layers_to_transform should be None
            if self.layers_to_transform is not None:
                raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")

            # if target_modules is a regex expression, then layers_pattern should be None
            if self.layers_pattern is not None:
                raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
+
+import math
+import warnings
+from typing import Any, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+
+
+class HRALayer(BaseTunerLayer):
+    """
+    State and helpers shared by the HRA adapter layers wrapping `nn.Linear` and `nn.Conv2d`.
+
+    Per adapter it stores the rank (`hra_r`), the Gram-Schmidt flag (`hra_apply_GS`) and the trainable reflection
+    vectors (`hra_u`, one column per Householder reflection).
+    """
+
+    # All names of layers that may contain (trainable) adapter weights
+    adapter_layer_names = ("hra_u",)
+    # All names of other parameters that may contain adapter-related parameters
+    other_param_names = ("hra_r", "hra_apply_GS")
+
+    def __init__(self, base_layer: nn.Module, **kwargs) -> None:
+        self.base_layer = base_layer
+        self.hra_r = {}
+        self.hra_apply_GS = {}
+        self.hra_u = nn.ParameterDict({})
+        # Mark the weight as unmerged
+        self._disable_adapters = False
+        self.merged_adapters = []
+        self.kwargs = kwargs
+
+        base_layer = self.get_base_layer()
+        if isinstance(base_layer, nn.Linear):
+            self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
+        elif isinstance(base_layer, nn.Conv2d):
+            self.in_features, self.out_features = base_layer.in_channels, base_layer.out_channels
+        else:
+            raise ValueError(f"Unsupported layer type {type(base_layer)}")
+
+    def update_layer(
+        self,
+        adapter_name: str,
+        r: int,
+        apply_GS: bool,
+        init_weights: bool,
+        **kwargs,
+    ) -> None:
+        """Internal function to create hra adapter
+
+        Args:
+            adapter_name (`str`): Name for the adapter to add.
+            r (`int`): Rank for the added adapter.
+            init_weights (`bool`): Whether to initialize weights.
+            apply_GS (`bool`): Whether to apply Gram-Schmidt orthogonalization or not.
+
+        Raises:
+            ValueError: If `r` is not positive.
+            TypeError: If the base layer is neither `nn.Linear` nor `nn.Conv2d`.
+        """
+        if r <= 0:
+            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
+
+        self.hra_r[adapter_name] = r
+        self.hra_apply_GS[adapter_name] = apply_GS
+
+        # Determine shape of HRA weights
+        base_layer = self.get_base_layer()
+        if isinstance(base_layer, nn.Linear):
+            self.hra_u[adapter_name] = nn.Parameter(torch.empty(self.in_features, r), requires_grad=True)
+        elif isinstance(base_layer, nn.Conv2d):
+            # One row per entry of a flattened (in_channels, kernel_h, kernel_w) filter; using both kernel
+            # dimensions keeps the shape correct for non-square kernels.
+            kernel_h, kernel_w = base_layer.kernel_size
+            self.hra_u[adapter_name] = nn.Parameter(
+                torch.empty(self.in_features * kernel_h * kernel_w, r),
+                requires_grad=True,
+            )
+        else:
+            raise TypeError(f"HRA is not implemented for base layers of type {type(base_layer).__name__}")
+
+        # Initialize weights
+        if init_weights:
+            self.reset_hra_parameters(adapter_name)
+        else:
+            self.reset_hra_parameters_random(adapter_name)
+
+        # Move new weights to device
+        self._move_adapter_to_device_of_base_layer(adapter_name)
+        self.set_adapter(self.active_adapters)
+
+    def reset_hra_parameters(self, adapter_name: str):
+        """
+        Default initialization: every reflection vector appears twice in a row (symmetric initialization).
+
+        This needs an even `r`; for an odd `r` a warning is emitted and all columns are drawn independently.
+        """
+        if self.hra_r[adapter_name] % 2 != 0:
+            warnings.warn("The symmetric initialization can NOT be performed when r is odd!")
+            nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
+        else:
+            shape = self.hra_u[adapter_name].shape
+            half_u = torch.zeros(shape[0], shape[1] // 2)
+            nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
+            self.hra_u[adapter_name] = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1))
+
+    def reset_hra_parameters_random(self, adapter_name: str):
+        """Initialize all reflection vectors independently (Kaiming uniform)."""
+        nn.init.kaiming_uniform_(self.hra_u[adapter_name], a=math.sqrt(5))
+
+    def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
+        """
+        Build the square matrix that right-multiplies the (flattened) base weight for `adapter_name`.
+
+        Args:
+            adapter_name (`str`): Name of the adapter whose `hra_u` vectors are used.
+            reverse (`bool`): Apply the Householder reflections in reverse order; `unmerge` uses this to undo a
+                previous `merge`. It is ignored when `apply_GS` is set, where the result `I - 2 U U^T` is symmetric.
+
+        Returns:
+            `torch.Tensor`: Matrix of shape `(hra_u.shape[0], hra_u.shape[0])` on the device and in the dtype of
+            `hra_u`.
+        """
+        rank = self.hra_r[adapter_name]
+        apply_GS = self.hra_apply_GS[adapter_name]
+        opt_u = self.hra_u[adapter_name]
+        shape = opt_u.shape
+
+        if apply_GS:
+            # Orthonormalize the columns of `hra_u` one after the other (Gram-Schmidt).
+            weight = [(opt_u[:, 0] / opt_u[:, 0].norm()).view(-1, 1)]
+            for i in range(1, rank):
+                ui = opt_u[:, i].view(-1, 1)
+                for j in range(i):
+                    ui = ui - (weight[j].t() @ ui) * weight[j]
+                weight.append((ui / ui.norm()).view(-1, 1))
+            weight = torch.cat(weight, dim=1)
+            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * weight @ weight.t()
+
+        else:
+            opt_u = opt_u / opt_u.norm(dim=0)
+            weight = torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype)
+            indices = range(rank - 1, -1, -1) if reverse else range(rank)
+
+            for i in indices:
+                ui = opt_u[:, i].view(-1, 1)
+                # W @ (I - 2 u u^T) == W - 2 (W u) u^T: same product without building the d x d reflection.
+                weight = weight - 2 * (weight @ ui) @ ui.t()
+
+        return weight
+
+    def scale_layer(self, scale: float) -> None:
+        """Scaling is not supported for HRA; only warns for active HRA adapters when `scale != 1`."""
+        if scale == 1:
+            return
+
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.hra_u.keys():
+                continue
+
+            warnings.warn("Scaling operation for HRA not supported! Automatically set scale to 1.")
+
+    def unscale_layer(self, scale=None) -> None:
+        """Unscaling is not supported for HRA; only warns for active HRA adapters."""
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.hra_u.keys():
+                continue
+
+            warnings.warn("Unscaling operation for HRA not supported! Keeping scale at 1.")
+
+
+class HRALinear(nn.Module, HRALayer):
+    """
+    HRA implemented in a dense layer.
+    """
+
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        r: int = 0,
+        apply_GS: bool = False,
+        init_weights: Union[bool, str] = True,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        HRALayer.__init__(self, base_layer, **kwargs)
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, apply_GS, init_weights, **kwargs)
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If `True`, the merged weights are checked for NaNs/infs before they replace the base weights. This is
+                useful if you want to check if the merge operation will produce NaNs. Defaults to `False`.
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
+                Defaults to `None`.
+        """
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        for active_adapter in adapter_names:
+            if active_adapter in self.hra_u.keys():
+                base_layer = self.get_base_layer()
+                orig_weight = base_layer.weight.data
+                # cast so that adapters kept in another precision than the base weight can still be merged
+                delta_weight = self.get_delta_weight(active_adapter).to(orig_weight.dtype)
+                # torch.mm allocates a new tensor, the base weight stays untouched until the assignment below
+                merged_weight = torch.mm(orig_weight, delta_weight)
+
+                if safe_merge and not torch.isfinite(merged_weight).all():
+                    raise ValueError(
+                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
+                    )
+
+                base_layer.weight.data = merged_weight
+                self.merged_adapters.append(active_adapter)
+
+    def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+        while len(self.merged_adapters) > 0:
+            # undo the merges in the opposite order in which they were applied
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter in self.hra_u.keys():
+                base_layer = self.get_base_layer()
+                orig_weight = base_layer.weight.data
+                delta_weight = self.get_delta_weight(active_adapter, reverse=True).to(orig_weight.dtype)
+                base_layer.weight.data = torch.mm(orig_weight, delta_weight)
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        """
+        Apply the layer; the output is cast back to the dtype of `x`.
+
+        With adapters disabled the (unmerged) base layer is used, with merged adapters the base layer is used as is,
+        otherwise the base weight is right-multiplied by the matrix of every active HRA adapter.
+        """
+        previous_dtype = x.dtype
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        else:
+            orig_weight = self.get_base_layer().weight.data
+            # accumulate in the dtype of the base weight so that the matmuls below never mix dtypes
+            new_weight = torch.eye(self.in_features, device=x.device, dtype=orig_weight.dtype)
+
+            for active_adapter in self.active_adapters:
+                if active_adapter not in self.hra_u.keys():
+                    continue
+                delta_weight = self.get_delta_weight(active_adapter)
+                new_weight = torch.mm(new_weight, delta_weight.to(new_weight.dtype))
+
+            x = x.to(orig_weight.dtype)
+            new_weight = torch.mm(orig_weight, new_weight)
+
+            result = F.linear(input=x, weight=new_weight, bias=self.base_layer.bias)
+
+        result = result.to(previous_dtype)
+        return result
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "hra." + rep
+
+
+class HRAConv2d(nn.Module, HRALayer):
+    """HRA implemented in Conv2d layer"""
+
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        r: int = 0,
+        apply_GS: bool = False,
+        init_weights: Union[bool, str] = True,
+        **kwargs,
+    ):
+        super().__init__()
+        # forward the extra kwargs exactly like HRALinear does
+        HRALayer.__init__(self, base_layer, **kwargs)
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, apply_GS, init_weights, **kwargs)
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
+        """
+        Merge the active adapter weights into the base weights
+
+        Args:
+            safe_merge (`bool`, *optional*):
+                If `True`, the merged weights are checked for NaNs/infs before they replace the base weights. This is
+                useful if you want to check if the merge operation will produce NaNs. Defaults to `False`.
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
+                Defaults to `None`.
+        """
+        adapter_names = check_adapters_to_merge(self, adapter_names)
+        if not adapter_names:
+            # no adapter to merge
+            return
+
+        for active_adapter in adapter_names:
+            if active_adapter in self.hra_u.keys():
+                base_layer = self.get_base_layer()
+                orig_weight = base_layer.weight.data
+                delta_weight = self.get_delta_weight(active_adapter).to(orig_weight.dtype)
+                # flatten every filter to one row, transform, then restore the 4D filter shape
+                merged_weight = torch.mm(orig_weight.view(self.out_features, -1), delta_weight)
+                merged_weight = merged_weight.view(orig_weight.shape)
+
+                if safe_merge and not torch.isfinite(merged_weight).all():
+                    raise ValueError(
+                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
+                    )
+
+                base_layer.weight.data = merged_weight
+                self.merged_adapters.append(active_adapter)
+
+    def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
+        if not self.merged:
+            warnings.warn("Already unmerged. Nothing to do.")
+            return
+        while len(self.merged_adapters) > 0:
+            # undo the merges in the opposite order in which they were applied
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter in self.hra_u.keys():
+                base_layer = self.get_base_layer()
+                orig_weight = base_layer.weight.data
+                delta_weight = self.get_delta_weight(active_adapter, reverse=True).to(orig_weight.dtype)
+                unmerged_weight = torch.mm(orig_weight.view(self.out_features, -1), delta_weight)
+                base_layer.weight.data = unmerged_weight.view(orig_weight.shape)
+
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
+        """
+        Apply the layer; the output is cast back to the dtype of `x`.
+
+        With adapters disabled the (unmerged) base layer is used, with merged adapters the base layer is used as is,
+        otherwise the flattened filters are right-multiplied by the matrix of every active HRA adapter and the
+        convolution is run with the stride, padding and dilation of the base layer.
+        """
+        previous_dtype = x.dtype
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        else:
+            base_layer = self.get_base_layer()
+            orig_weight = base_layer.weight.data
+            # one row per output filter: (out_channels, in_channels * kernel_h * kernel_w)
+            flat_weight = orig_weight.view(self.out_features, -1)
+            # accumulate in the dtype of the base weight so that the matmuls below never mix dtypes
+            new_weight = torch.eye(flat_weight.shape[1], device=x.device, dtype=flat_weight.dtype)
+            for active_adapter in self.active_adapters:
+                if active_adapter not in self.hra_u.keys():
+                    continue
+                delta_weight = self.get_delta_weight(active_adapter)
+                new_weight = torch.mm(new_weight, delta_weight.to(new_weight.dtype))
+
+            x = x.to(orig_weight.dtype)
+
+            new_weight = torch.mm(flat_weight, new_weight).view(orig_weight.shape)
+
+            # pass the full tuples (not only their first entry) so asymmetric settings are honoured
+            result = F.conv2d(
+                input=x,
+                weight=new_weight,
+                bias=base_layer.bias,
+                stride=base_layer.stride,
+                padding=base_layer.padding,
+                dilation=base_layer.dilation,
+            )
+
+        result = result.to(previous_dtype)
+        return result
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "hra." + rep
diff --git a/src/peft/tuners/hra/model.py b/src/peft/tuners/hra/model.py
new file mode 100644
index 0000000000..64ad71d074
--- /dev/null
+++ b/src/peft/tuners/hra/model.py
@@ -0,0 +1,337 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import asdict
+from enum import Enum
+from typing import List, Optional
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
+from peft.utils import (
+    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+    ModulesToSaveWrapper,
+    _get_submodules,
+)
+
+from .config import HRAConfig
+from .layer import HRAConv2d, HRALayer, HRALinear
+
+
+class HRAModel(BaseTuner):
+    """
+    Creates Householder reflection adaptation (HRA) model from a pretrained model. The method is described in
+    https://arxiv.org/abs/2405.17484
+
+    Args:
+        model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached.
+        config ([`HRAConfig`]): The configuration of the HRA model.
+        adapter_name (`str`): The name of the adapter, defaults to `"default"`.
+
+    Returns:
+        `torch.nn.Module`: The HRA model.
+
+    Example:
+        ```py
+        >>> from diffusers import StableDiffusionPipeline
+        >>> from peft import HRAModel, HRAConfig
+
+        >>> config_te = HRAConfig(
+        ...     r=8,
+        ...     target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
+        ...     init_weights=True,
+        ... )
+        >>> config_unet = HRAConfig(
+        ...     r=8,
+        ...     target_modules=[
+        ...         "proj_in",
+        ...         "proj_out",
+        ...         "to_k",
+        ...         "to_q",
+        ...         "to_v",
+        ...         "to_out.0",
+        ...         "ff.net.0.proj",
+        ...         "ff.net.2",
+        ...     ],
+        ...     init_weights=True,
+        ... )
+
+        >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> model.text_encoder = HRAModel(model.text_encoder, config_te, "default")
+        >>> model.unet = HRAModel(model.unet, config_unet, "default")
+        ```
+
+    **Attributes**:
+        - **model** ([`~torch.nn.Module`]) -- The model to be adapted.
+        - **peft_config** ([`HRAConfig`]): The configuration of the HRA model.
+    """
+
+    prefix: str = "hra_"
+
+    def _check_new_adapter_config(self, config: HRAConfig) -> None:
+        """
+        A helper method to check the config when a new adapter is being added.
+
+        Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
+
+        """
+        # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
+        # does not fully correspond to the error message.
+        if (len(self.peft_config) > 1) and (config.bias != "none"):
+            raise ValueError(
+                f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
+                "set bias to 'none' for all adapters."
+            )
+
+    @staticmethod
+    def _check_target_module_exists(hra_config, key):
+        return check_target_module_exists(hra_config, key)
+
+    def _create_and_replace(
+        self,
+        hra_config,
+        adapter_name,
+        target,
+        target_name,
+        parent,
+        current_key,
+        **optional_kwargs,
+    ):
+        """Wrap `target` in a new HRA layer, or add the adapter to it if it already is an `HRALayer`."""
+        if current_key is None:
+            raise ValueError("Current Key shouldn't be `None`")
+
+        bias = hasattr(target, "bias") and target.bias is not None
+        kwargs = {
+            "r": hra_config.r,
+            "apply_GS": hra_config.apply_GS,
+            "init_weights": hra_config.init_weights,
+        }
+        kwargs["bias"] = bias
+
+        # If it is not a HRALayer, create a new module, else update it with new adapters
+        if not isinstance(target, HRALayer):
+            new_module = self._create_new_module(hra_config, adapter_name, target, **kwargs)
+            if adapter_name not in self.active_adapters:
+                # adding an additional adapter: it is not automatically trainable
+                new_module.requires_grad_(False)
+            self._replace_module(parent, target_name, new_module, target)
+        else:
+            target.update_layer(
+                adapter_name,
+                r=hra_config.r,
+                apply_GS=hra_config.apply_GS,
+                init_weights=hra_config.init_weights,
+            )
+
+    def _replace_module(self, parent, child_name, new_module, child):
+        """Install `new_module` as `parent.<child_name>` and move its HRA submodules to the device of `child`."""
+        setattr(parent, child_name, new_module)
+        # It's not necessary to set requires_grad here, as that is handled by
+        # _mark_only_adapters_as_trainable
+
+        # child layer wraps the original module, unpack it
+        if hasattr(child, "base_layer"):
+            child = child.base_layer
+
+        if not hasattr(new_module, "base_layer"):
+            new_module.weight = child.weight
+            if hasattr(child, "bias"):
+                new_module.bias = child.bias
+
+        if getattr(child, "state", None) is not None:
+            if hasattr(new_module, "base_layer"):
+                new_module.base_layer.state = child.state
+            else:
+                new_module.state = child.state
+            new_module.to(child.weight.device)
+
+        # dispatch to correct device
+        for name, module in new_module.named_modules():
+            if self.prefix in name:
+                module.to(child.weight.device)
+
+    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+        """Freeze every parameter without the `hra_` prefix, then re-enable biases as requested by each config."""
+        for n, p in model.named_parameters():
+            if self.prefix not in n:
+                p.requires_grad = False
+
+        for active_adapter in self.active_adapters:
+            bias = self.peft_config[active_adapter].bias
+            if bias == "none":
+                continue
+
+            if bias == "all":
+                for n, p in model.named_parameters():
+                    if "bias" in n:
+                        p.requires_grad = True
+            elif bias == "hra_only":
+                for name, m in model.named_modules():
+                    if isinstance(m, HRALayer) and hasattr(m, "bias") and m.bias is not None:
+                        m.bias.requires_grad = True
+            else:
+                raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
+
+    @staticmethod
+    def _create_new_module(hra_config, adapter_name, target, **kwargs):
+        """Return an `HRALinear` or `HRAConv2d` wrapping `target`, depending on the type of its base layer."""
+        if isinstance(target, BaseTunerLayer):
+            target_base_layer = target.get_base_layer()
+        else:
+            target_base_layer = target
+
+        if isinstance(target_base_layer, torch.nn.Linear):
+            new_module = HRALinear(target, adapter_name, **kwargs)
+        elif isinstance(target_base_layer, torch.nn.Conv2d):
+            new_module = HRAConv2d(target, adapter_name, **kwargs)
+        else:
+            raise ValueError(
+                f"Target module {target} is not supported. "
+                "Currently, only `torch.nn.Linear` and `torch.nn.Conv2d` are supported."
+            )
+
+        return new_module
+
+    def __getattr__(self, name: str):
+        """Forward missing attributes to the wrapped module."""
+        try:
+            return super().__getattr__(name)  # defer to nn.Module's logic
+        except AttributeError:
+            if name == "base_model":
+                raise
+            return getattr(self.model, name)
+
+    def get_peft_config_as_dict(self, inference: bool = False):
+        """
+        Return the configs of all adapters as plain dicts, keyed by adapter name.
+
+        Enum members are replaced by their `.value`; with `inference=True` every config gets `inference_mode=True`.
+        """
+        config_dict = {}
+        for key, value in self.peft_config.items():
+            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
+            if inference:
+                config["inference_mode"] = True
+            config_dict[key] = config
+        # return the mapping of all adapters, not only the config of the last adapter of the loop
+        return config_dict
+
+    def _set_adapter_layers(self, enabled=True):
+        for module in self.model.modules():
+            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
+                module.enable_adapters(enabled)
+
+    def enable_adapter_layers(self):
+        self._set_adapter_layers(enabled=True)
+
+    def disable_adapter_layers(self):
+        for active_adapter in self.active_adapters:
+            val = self.peft_config[active_adapter].bias
+            if val != "none":
+                msg = (
+                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
+                    "output as the the base model would without adaption."
+                )
+                warnings.warn(msg)
+        self._set_adapter_layers(enabled=False)
+
+    def set_adapter(self, adapter_name):
+        """Activate `adapter_name` on every HRA layer, unmerging layers that are currently merged."""
+        for module in self.model.modules():
+            if isinstance(module, HRALayer):
+                if module.merged:
+                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
+                    module.unmerge()
+                module.set_adapter(adapter_name)
+        self.active_adapter = adapter_name
+
+    @staticmethod
+    def _prepare_adapter_config(peft_config, model_config):
+        """Fill in `target_modules` from the LoRA default mapping of the model type if none were given."""
+        if peft_config.target_modules is None:
+            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
+                raise ValueError("Please specify `target_modules` in `peft_config`")
+            peft_config.target_modules = set(
+                TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            )
+        return peft_config
+
+    def _unload_and_optionally_merge(
+        self,
+        merge=True,
+        progressbar: bool = False,
+        safe_merge: bool = False,
+        adapter_names: Optional[List[str]] = None,
+    ):
+        """Replace every adapter layer by its base layer, merging the adapter weights first if `merge` is set."""
+        self._unloading_checks(adapter_names)
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        desc = "Unloading " + ("and merging " if merge else "") + "model"
+        for key in tqdm(key_list, disable=not progressbar, desc=desc):
+            try:
+                parent, target, target_name = _get_submodules(self.model, key)
+            except AttributeError:
+                continue
+
+            if hasattr(target, "base_layer"):
+                if merge:
+                    target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                self._replace_module(parent, target_name, target.get_base_layer(), target)
+            elif isinstance(target, ModulesToSaveWrapper):
+                # save any additional trainable modules part of `modules_to_save`
+                setattr(parent, target_name, target.modules_to_save[target.active_adapter])
+
+        return self.model
+
+    def delete_adapter(self, adapter_name: str) -> None:
+        """
+        Deletes an existing adapter.
+
+        Args:
+            adapter_name (str): Name of the adapter to be deleted.
+        """
+        if adapter_name not in list(self.peft_config.keys()):
+            raise ValueError(f"Adapter {adapter_name} does not exist")
+        del self.peft_config[adapter_name]
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        new_adapter = None
+        for key in key_list:
+            _, target, _ = _get_submodules(self.model, key)
+            if isinstance(target, HRALayer):
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        self.active_adapter = new_adapter or []
+
+    def merge_and_unload(
+        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[List[str]] = None
+    ) -> torch.nn.Module:
+        r"""
+        This method merges the HRA layers into the base model. This is needed if someone wants to use the base model as
+        a standalone model.
+
+        Args:
+            progressbar (`bool`):
+                whether to show a progressbar indicating the unload and merge process
+            safe_merge (`bool`):
+                whether to activate the safe merging check to check if there is any potential Nan in the adapter
+                weights
+            adapter_names (`List[str]`, *optional*):
+                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+                to `None`.
+
+        """
+        return self._unload_and_optionally_merge(
+            progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+        )
+
+    def unload(self) -> torch.nn.Module:
+        """
+        Gets back the base model by removing all the hra modules without merging. This gives back the original base
+        model.
+        """
+        return self._unload_and_optionally_merge(merge=False)
diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py index c678b0b80e..ed82c1d724 100644 --- a/src/peft/utils/peft_types.py +++ b/src/peft/utils/peft_types.py @@ -41,6 +41,7 @@ class PeftType(str, enum.Enum): - LN_TUNING - VERA - FOURIERFT + - HRA """ PROMPT_TUNING = "PROMPT_TUNING" @@ -60,6 +61,7 @@ class PeftType(str, enum.Enum): VERA = "VERA" FOURIERFT = "FOURIERFT" XLORA = "XLORA" + HRA = "HRA" class TaskType(str, enum.Enum): diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py index 292c036bad..cd64c9a06e 100644 --- a/src/peft/utils/save_and_load.py +++ b/src/peft/utils/save_and_load.py @@ -176,12 +176,12 @@ def renamed_dora_weights(k): ) to_return["base_model.vera_A." + adapter_name] = state_dict["base_model.vera_A." + adapter_name] to_return["base_model.vera_B." + adapter_name] = state_dict["base_model.vera_B." + adapter_name] - elif config.peft_type == PeftType.FOURIERFT: to_return = {k: state_dict[k] for k in state_dict if "fourierft_" in k} - elif config.peft_type == PeftType.XLORA: to_return = {k: state_dict[k] for k in state_dict if "internal_xlora_classifier" in k} + elif config.peft_type == PeftType.HRA: + to_return = {k: state_dict[k] for k in state_dict if "hra_" in k} else: raise ValueError(f"Unknown PEFT type passed: {config.peft_type}") @@ -320,6 +320,7 @@ def set_peft_model_state_dict( PeftType.BOFT, PeftType.VERA, PeftType.FOURIERFT, + PeftType.HRA, ): peft_model_state_dict = {} parameter_prefix = { @@ -334,6 +335,7 @@ def set_peft_model_state_dict( PeftType.LN_TUNING: "ln_tuning_", PeftType.VERA: "vera_lambda_", PeftType.FOURIERFT: "fourierft_", + PeftType.HRA: "hra_", }[config.peft_type] for k, v in state_dict.items(): if parameter_prefix in k: diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 6039d7d850..1d9193b126 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -18,10 +18,13 @@ import
pytest import torch import torch.nn.functional as F +from datasets import load_dataset from parameterized import parameterized from torch import nn from transformers import ( + AutoImageProcessor, AutoModelForCausalLM, + AutoModelForImageClassification, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, @@ -36,6 +39,7 @@ AdaLoraConfig, AdaptionPromptConfig, BOFTConfig, + HRAConfig, IA3Config, LNTuningConfig, LoHaConfig, @@ -1088,6 +1092,7 @@ def test_8bit_dora_merging(self): @pytest.mark.single_gpu_tests def test_dora_ephemeral_gpu_offload(self): torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained( "facebook/opt-125m", torch_dtype=torch.float32, @@ -1140,6 +1145,7 @@ def test_dora_ephemeral_gpu_offload(self): @pytest.mark.multi_gpu_tests def test_dora_ephemeral_gpu_offload_multigpu(self): torch.manual_seed(0) + model = AutoModelForCausalLM.from_pretrained( "facebook/opt-125m", torch_dtype=torch.float32, @@ -1164,6 +1170,79 @@ def test_dora_ephemeral_gpu_offload_multigpu(self): layer.lora_A, layer.lora_B = la, lb layer.dora_init(layer.active_adapter[0]) # should not raise an error + def test_apply_GS_hra_inference(self): + # check for different result with and without apply_GS + model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torch_dtype=torch.float32, + ).eval() + + torch.manual_seed(0) + config_hra = HRAConfig(r=8, init_weights=True, apply_GS=False) + model = get_peft_model(model, config_hra).eval() + + random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device) + logits_hra = model(random_input).logits + + model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torch_dtype=torch.float32, + ) + torch.manual_seed(0) + config_hra_GS = HRAConfig(r=8, init_weights=True, apply_GS=True) + model = get_peft_model(model, config_hra_GS) + + logits_hra_GS = model(random_input).logits + + assert not torch.allclose(logits_hra, logits_hra_GS) + + @require_torch_gpu + 
@pytest.mark.single_gpu_tests + def test_apply_GS_hra_conv2d_inference(self): + # check for different result with and without apply_GS + model_id = "microsoft/resnet-18" + image_processor = AutoImageProcessor.from_pretrained(model_id) + dataset = load_dataset("huggingface/cats-image", trust_remote_code=True) + image = dataset["test"]["image"][0] + data = image_processor(image, return_tensors="pt") + + model = AutoModelForImageClassification.from_pretrained(model_id).eval() + torch.manual_seed(0) + config_hra = HRAConfig(r=8, init_weights=True, target_modules=["convolution"], apply_GS=False) + model = get_peft_model(model, config_hra).eval() + + logits_hra = model(**data).logits + + model = AutoModelForImageClassification.from_pretrained(model_id).eval() + torch.manual_seed(0) + config_hra_GS = HRAConfig(r=8, init_weights=True, target_modules=["convolution"], apply_GS=True) + model = get_peft_model(model, config_hra_GS) + + logits_hra_GS = model(**data).logits + + assert not torch.allclose(logits_hra, logits_hra_GS) + + @require_torch_gpu + @pytest.mark.single_gpu_tests + def test_r_odd_hra_inference(self): + # check that an untrained HRA adapter can't be initialized as an identity transformation + # when r is an odd number + model = AutoModelForCausalLM.from_pretrained( + "facebook/opt-125m", + torch_dtype=torch.float32, + ).eval() + + random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device) + + torch.manual_seed(0) + logits = model(random_input).logits + + config_hra = HRAConfig(r=7, init_weights=True, apply_GS=False) + model = get_peft_model(model, config_hra).eval() + logits_hra = model(random_input).logits + + assert not torch.allclose(logits, logits_hra) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a CUDA GPU") @pytest.mark.single_gpu_tests @@ -1488,3 +1567,21 @@ def test_vera_add_new_adapter_does_not_change_device(self, mlp): assert model.lin0.base_layer.weight.device.type == "cuda" assert
model.lin0.vera_A.other.device.type == "cuda" assert model.lin0.vera_lambda_d.other.device.type == "cuda" + + def test_hra_add_new_adapter_does_not_change_device(self, mlp): + # same as first test, but using HRA + config = HRAConfig(target_modules=["lin0"]) + model = get_peft_model(mlp, config) + model = model.cuda() + model.lin0.hra_u.cpu() + + # check that the adapter is indeed on CPU and the base model on GPU + assert model.lin0.hra_u.default.device.type == "cpu" + assert model.lin0.base_layer.weight.device.type == "cuda" + + model.add_adapter("other", config) + # check that after adding a new adapter, the old adapter is still on CPU + assert model.lin0.hra_u.default.device.type == "cpu" + # the rest should be on GPU + assert model.lin0.base_layer.weight.device.type == "cuda" + assert model.lin0.hra_u.other.device.type == "cuda" diff --git a/tests/test_config.py b/tests/test_config.py index b8c5e03608..9dabade766 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -27,6 +27,7 @@ AdaptionPromptConfig, BOFTConfig, FourierFTConfig, + HRAConfig, IA3Config, LoHaConfig, LoraConfig, @@ -60,6 +61,7 @@ BOFTConfig, VeraConfig, FourierFTConfig, + HRAConfig, ) @@ -234,7 +236,7 @@ def test_prompt_encoder_warning_num_layers(self): expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." 
assert str(record.list[0].message) == expected_msg - @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig, BOFTConfig]) + @parameterized.expand([LoHaConfig, LoraConfig, IA3Config, OFTConfig, BOFTConfig, HRAConfig]) def test_save_pretrained_with_target_modules(self, config_class): # See #1041, #1045 config = config_class(target_modules=["a", "list"]) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index a6ccc903ca..d1339eab35 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -36,6 +36,7 @@ AdaLoraConfig, BOFTConfig, FourierFTConfig, + HRAConfig, IA3Config, LNTuningConfig, LoHaConfig, @@ -246,6 +247,14 @@ ("Conv2d 3 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True}), ("Conv2d 4 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "block_share": True}), ("Conv2d 5 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True, "block_share": True}), + ######## + # HRA # + ######## + ("Vanilla MLP 1 HRA", "MLP", HRAConfig, {"target_modules": "lin0"}), + ("Vanilla MLP 2 HRA", "MLP", HRAConfig, {"target_modules": ["lin0"]}), + ("Vanilla MLP 3 HRA", "MLP", HRAConfig, {"target_modules": ["lin0", "lin1"]}), + ("Vanilla MLP 5 HRA", "MLP", HRAConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}), + ("Conv2d 1 HRA", "Conv2d", HRAConfig, {"target_modules": ["conv2d"]}), ############# # LN Tuning # ############# @@ -455,6 +464,20 @@ {"n_frequency": 10, "target_modules": ["lin0"]}, {"n_frequency": 10, "target_modules": ["lin1"]}, ), + ( + "HRA Same", + "hra", + HRAConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin0"], "init_weights": False}, + ), + ( + "HRA Different", + "hra", + HRAConfig, + {"target_modules": ["lin0"], "init_weights": False}, + {"target_modules": ["lin1"], "init_weights": False}, + ), ] PREFIXES = { IA3Config: "ia3_", @@ -466,6 +489,7 @@ LNTuningConfig: "ln_tuning_", VeraConfig: "vera_lambda_", FourierFTConfig: 
"fourierft_", + HRAConfig: "hra_", } @@ -1203,7 +1227,7 @@ def test_multiple_adapters_automatic_modules_to_save(self): assert "default" in model.base_model.classifier.modules_to_save assert "other" in model.base_model.classifier.modules_to_save - @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig]) + @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, HRAConfig]) def test_multiple_adapters_mixed_modules_to_save(self, config_cls): # See issue 1574 # Check that we can have a model where one adapter has modules_to_save and the other doesn't. It should be @@ -1228,7 +1252,7 @@ def test_multiple_adapters_mixed_modules_to_save(self, config_cls): model.set_adapter("other") model(**inputs) - @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig]) + @parameterized.expand([IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, HRAConfig]) def test_multiple_adapters_mixed_modules_to_save_order_switched(self, config_cls): # See issue 1574 # Same test as test_multiple_adapters_mixed_modules_to_save, but this time the 2nd adapter has modules_to_save. 
@@ -1433,6 +1457,7 @@ def test_load_resized_embedding_ignore_mismatched_sizes(self): IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False), OFTConfig(target_modules=["lin0"], init_weights=False), BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2), + HRAConfig(target_modules=["lin0"], init_weights=False), ] ) def test_adapter_name_makes_no_difference(self, config0): @@ -2607,6 +2632,83 @@ def test_requires_grad_oft_same_targets(self): "base_model.model.lin0.oft_r.adapter1", ) + def test_requires_grad_hra_different_targets(self): + # test two different HRA adapters that target different modules + config0 = HRAConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = HRAConfig(target_modules=["lin1"], inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model, + "base_model.model.lin0.hra_u.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.hra_u.default", + ) + + # change active adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin1.hra_u.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + self.check_requires_grad( + peft_model, + "base_model.model.lin1.hra_u.adapter1", + ) + + def test_requires_grad_hra_same_targets(self): + # same as previous test, except that HRA adapters target the same layer + config0 = HRAConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = HRAConfig(target_modules=["lin0"], inference_mode=True) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.check_requires_grad( + peft_model,
+ "base_model.model.lin0.hra_u.default", + ) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.hra_u.default", + ) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.hra_u.adapter1", + ) + + # disable all adapters + with peft_model.disable_adapter(): + self.check_requires_grad(peft_model) + + # after context is exited, return to the previous state + peft_model.set_adapter("adapter1") + self.check_requires_grad( + peft_model, + "base_model.model.lin0.hra_u.adapter1", + ) + def test_requires_grad_boft_different_targets(self): # test two different OFT adapters that target different modules config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2) diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 5df37e9c75..601863e0f4 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -19,7 +19,7 @@ from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer -from peft import AdaLoraConfig, BOFTConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model +from peft import AdaLoraConfig, BOFTConfig, HRAConfig, LoraConfig, PromptTuningConfig, PromptTuningInit, get_peft_model from .testing_common import PeftCommonTester, PeftTestConfigManager @@ -45,15 +45,22 @@ def skip_adalora_and_gpt2(test_list): return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == AdaLoraConfig))] -def skip_boft_and_gpt2(test_list): - return [test for test in test_list if not (("GPT2LMHeadModel" in test[1]) and (test[2] == BOFTConfig))] +def skip_boft_or_hra_and_gpt2(test_list): + return [ + test + for test in test_list + if not (("GPT2LMHeadModel" in test[1]) and ((test[2] == BOFTConfig) or (test[2] == HRAConfig))) + ] -def 
skip_adalora_or_boft_and_gpt2(test_list): +def skip_adalora_or_boft_or_hra_and_gpt2(test_list): return [ test for test in test_list - if not (("GPT2LMHeadModel" in test[1]) and ((test[2] == AdaLoraConfig) or (test[2] == BOFTConfig))) + if not ( + ("GPT2LMHeadModel" in test[1]) + and ((test[2] == AdaLoraConfig) or (test[2] == BOFTConfig) or (test[2] == HRAConfig)) + ) ] @@ -78,15 +85,21 @@ def prepare_inputs_for_testing(self): return input_dict - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_model_attr(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): self._test_adapter_name(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): self._test_prepare_for_training(model_id, config_cls, config_kwargs) @@ -144,23 +157,33 @@ def test_prompt_tuning_config_invalid_args(self): tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, ) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, 
filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) @@ -174,6 +197,7 @@ def 
test_from_pretrained_config_construction(self, test_name, model_id, config_c "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, ) @@ -190,9 +214,10 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_boft_and_gpt2, + filter_params_func=skip_boft_or_hra_and_gpt2, ) ) def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): @@ -224,11 +249,15 @@ def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_generate(self, test_name, model_id, config_cls, config_kwargs): self._test_generate(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs): # positional args are supported for PeftModelForCausalLM self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) @@ -245,7 +274,9 @@ def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs def 
test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): self._test_training(model_id, config_cls, config_kwargs) @@ -253,11 +284,15 @@ def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs) def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): self._test_training_layer_indexing(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): self._test_inference_safetensors(model_id, config_cls, config_kwargs) @@ -265,15 +300,21 @@ def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwa def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): self._test_peft_model_device_map(model_id, config_cls, config_kwargs) - 
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs): self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) @@ -287,9 +328,10 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_adalora_or_boft_and_gpt2, + filter_params_func=skip_adalora_or_boft_or_hra_and_gpt2, ) ) def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): @@ -323,9 +365,10 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": 
{"init_weights": [False]}, "task_type": "CAUSAL_LM", }, - filter_params_func=skip_boft_and_gpt2, + filter_params_func=skip_boft_or_hra_and_gpt2, ) ) def test_disable_adapter(self, test_name, model_id, config_cls, config_kwargs): @@ -341,7 +384,9 @@ def test_generate_adalora_no_dropout(self): } self._test_generate(model_id, AdaLoraConfig, config_kwargs) - @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_and_gpt2)) + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters(FULL_GRID, filter_params_func=skip_boft_or_hra_and_gpt2) + ) def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) diff --git a/tests/test_encoder_decoder_models.py b/tests/test_encoder_decoder_models.py index f4fe0b7934..bf4e7fe91c 100644 --- a/tests/test_encoder_decoder_models.py +++ b/tests/test_encoder_decoder_models.py @@ -94,6 +94,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "adalora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -173,6 +174,7 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) @@ -206,6 +208,7 @@ def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, c "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "SEQ_2_SEQ_LM", }, ) diff --git a/tests/test_feature_extraction_models.py 
b/tests/test_feature_extraction_models.py index e684d68c37..5521c1125d 100644 --- a/tests/test_feature_extraction_models.py +++ b/tests/test_feature_extraction_models.py @@ -47,7 +47,7 @@ def skip_deberta_lora_tests(test_list): Skip tests that are checkpointing with lora/ia3/boft/vera/fourierft for Deberta models (couldn't find much info on the error) """ - to_skip = ["lora", "ia3", "boft", "vera", "fourierft"] + to_skip = ["lora", "ia3", "boft", "vera", "fourierft", "hra"] return [test for test in test_list if not (any(k in test[0] for k in to_skip) and "Deberta" in test[0])] @@ -112,6 +112,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -164,6 +165,7 @@ def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_k "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, "vera_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) @@ -178,6 +180,7 @@ def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): "lora_kwargs": {"init_lora_weights": [False]}, "ia3_kwargs": {"init_ia3_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, "task_type": "FEATURE_EXTRACTION", }, ) diff --git a/tests/test_stablediffusion.py b/tests/test_stablediffusion.py index b8cc2e203a..b0d4d7773e 100644 --- a/tests/test_stablediffusion.py +++ b/tests/test_stablediffusion.py @@ -19,7 +19,7 @@ from diffusers import StableDiffusionPipeline from parameterized import parameterized -from peft import BOFTConfig, LoHaConfig, LoraConfig, OFTConfig, get_peft_model +from peft import BOFTConfig, HRAConfig, LoHaConfig, LoraConfig, OFTConfig, get_peft_model from 
.testing_common import ClassInstantier, PeftCommonTester from .testing_utils import temp_seed @@ -85,6 +85,16 @@ "boft_dropout": 0.0, }, }, + { + "text_encoder": { + "r": 8, + "target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"], + }, + "unet": { + "r": 8, + "target_modules": ["proj_in", "proj_out", "to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"], + }, + }, ) CLASSES_MAPPING = { "lora": (LoraConfig, CONFIG_TESTING_KWARGS[0]), @@ -92,6 +102,7 @@ "lokr": (LoHaConfig, CONFIG_TESTING_KWARGS[1]), "oft": (OFTConfig, CONFIG_TESTING_KWARGS[2]), "boft": (BOFTConfig, CONFIG_TESTING_KWARGS[3]), + "hra": (HRAConfig, CONFIG_TESTING_KWARGS[4]), } @@ -145,6 +156,7 @@ def prepare_inputs_for_testing(self): "loha_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, }, ) ) @@ -158,7 +170,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): peft_output = np.array(model(**dummy_input).images[0]).astype(np.float32) # Merge adapter and model - if config_cls not in [LoHaConfig, OFTConfig]: + if config_cls not in [LoHaConfig, OFTConfig, HRAConfig]: # TODO: Merging the text_encoder is leading to issues on CPU with PyTorch 2.1 model.text_encoder = model.text_encoder.merge_and_unload() model.unet = model.unet.merge_and_unload() @@ -178,6 +190,7 @@ def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): "loha_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, }, ) ) @@ -191,7 +204,7 @@ def test_merge_layers_safe_merge(self, test_name, model_id, config_cls, config_k peft_output = np.array(model(**dummy_input).images[0]).astype(np.float32) # Merge adapter and model - if config_cls not in [LoHaConfig, OFTConfig]: + if config_cls not in [LoHaConfig, OFTConfig, HRAConfig]: # TODO: Merging the 
text_encoder is leading to issues on CPU with PyTorch 2.1 model.text_encoder = model.text_encoder.merge_and_unload(safe_merge=True) model.unet = model.unet.merge_and_unload(safe_merge=True) @@ -209,7 +222,9 @@ def test_merge_layers_safe_merge(self, test_name, model_id, config_cls, config_k "model_ids": PEFT_DIFFUSERS_SD_MODELS_TO_TEST, "lora_kwargs": {"init_lora_weights": [False]}, }, - filter_params_func=lambda tests: [x for x in tests if all(s not in x[0] for s in ["loha", "lokr", "oft"])], + filter_params_func=lambda tests: [ + x for x in tests if all(s not in x[0] for s in ["loha", "lokr", "oft", "hra"]) + ], ) ) def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_cls, config_kwargs): @@ -239,6 +254,7 @@ def test_add_weighted_adapter_base_unchanged(self, test_name, model_id, config_c "lokr_kwargs": {"init_weights": [False]}, "oft_kwargs": {"init_weights": [False]}, "boft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, }, ) ) diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index 818dcc1c43..ea987036cb 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -37,6 +37,7 @@ from peft import ( AdaLoraConfig, BOFTConfig, + HRAConfig, IA3Config, LNTuningConfig, LoHaConfig, @@ -77,6 +78,7 @@ "lora-with-modules-to-save": (LoraConfig(task_type=TaskType.CAUSAL_LM, modules_to_save=["embed_tokens"]), {}), "oft": (OFTConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]), {}), "vera": (VeraConfig(task_type=TaskType.CAUSAL_LM), {}), + "hra": (HRAConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]), {}), } diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py index 8cb913707d..e706846f7b 100644 --- a/tests/test_vision_models.py +++ b/tests/test_vision_models.py @@ -22,7 +22,7 @@ from safetensors.torch import load_file from transformers import AutoImageProcessor, AutoModelForImageClassification -from peft import 
LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model +from peft import HRAConfig, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model CONFIGS = { @@ -30,6 +30,7 @@ "loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), "lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), "oft": OFTConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), + "hra": HRAConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]), # TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no # common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel: # > Error in forward_fast_block_diag_cuda_kernel: an illegal memory access was encountered diff --git a/tests/testing_common.py b/tests/testing_common.py index db238905a5..83566e3aa3 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -30,6 +30,7 @@ AdaLoraConfig, BOFTConfig, FourierFTConfig, + HRAConfig, IA3Config, LNTuningConfig, LoHaConfig, @@ -102,6 +103,10 @@ "n_frequency": 10, "target_modules": None, }, + # HRA + { + "target_modules": None, + }, ) CLASSES_MAPPING = { @@ -112,8 +117,9 @@ "prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[4]), "adalora": (AdaLoraConfig, CONFIG_TESTING_KWARGS[5]), "boft": (BOFTConfig, CONFIG_TESTING_KWARGS[6]), - "vera": (VeraConfig, CONFIG_TESTING_KWARGS[6]), + "vera": (VeraConfig, CONFIG_TESTING_KWARGS[7]), "fourierft": (FourierFTConfig, CONFIG_TESTING_KWARGS[8]), + "hra": (HRAConfig, CONFIG_TESTING_KWARGS[9]), } @@ -626,7 +632,15 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): assert torch.allclose(logits_merged, logits_merged_from_pretrained, atol=atol, rtol=rtol) def _test_merge_layers_multi(self, model_id, config_cls, config_kwargs): - supported_peft_types 
= [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3, PeftType.OFT, PeftType.BOFT] + supported_peft_types = [ + PeftType.LORA, + PeftType.LOHA, + PeftType.LOKR, + PeftType.IA3, + PeftType.OFT, + PeftType.BOFT, + PeftType.HRA, + ] if ("gpt2" in model_id.lower()) and (config_cls == IA3Config): self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)") @@ -1080,6 +1094,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): PeftType.BOFT, PeftType.VERA, PeftType.FOURIERFT, + PeftType.HRA, ] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters @@ -1126,6 +1141,7 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): PeftType.OFT, PeftType.BOFT, PeftType.FOURIERFT, + PeftType.HRA, ] # IA3 does not support deleting adapters yet, but it just needs to be added # AdaLora does not support multiple adapters @@ -1171,7 +1187,7 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): model = get_peft_model(model, config) model = model.to(self.torch_device) - if config.peft_type not in ("LORA", "ADALORA", "IA3", "BOFT", "VERA", "FOURIERFT"): + if config.peft_type not in ("LORA", "ADALORA", "IA3", "BOFT", "VERA", "FOURIERFT", "HRA"): with pytest.raises(AttributeError): model = model.unload() else: