'TinyLlavaForConditionalGeneration' object has no attribute 'peft_config' #142

Open · pspdada opened this issue Dec 4, 2024 · 1 comment · May be fixed by #143

pspdada commented Dec 4, 2024

I followed the instructions in CUSTOM_FINETUNE.md to run LoRA fine-tuning on the Pokémon dataset and encountered an issue when running bash scripts/train/custom_finetune.sh. The error message is as follows:

Traceback (most recent call last):
  File "/root/llm-project/TinyLLaVA_Factory/tinyllava/train/custom_finetune.py", line 52, in <module>
    train()
  File "/root/llm-project/TinyLLaVA_Factory/tinyllava/train/custom_finetune.py", line 34, in train
    model = training_recipe(model)
  File "/root/llm-project/TinyLLaVA_Factory/tinyllava/training_recipe/base.py", line 14, in __call__
    model = self.training_model_converse(model)
  File "/root/llm-project/TinyLLaVA_Factory/tinyllava/training_recipe/lora_recipe.py", line 46, in training_model_converse
    if model.peft_config is None:
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'TinyLlavaForConditionalGeneration' object has no attribute 'peft_config'
[2024-12-04 12:09:01,833] [INFO] [launch.py:316:sigkill_handler] Killing subprocess 73410

The context of this function is as follows:

@register_training_recipe('lora')
class LoRATrainingRecipe(BaseTrainingRecipe):
    def __init__(self, training_arguments):
        super().__init__(training_arguments)
        self.training_arguments = training_arguments
        self.lora_skip_module = ['connector', 'vision_tower', 'language_model']
        
    def training_model_converse(self, model):
        if self.training_arguments.tune_type_connector == 'lora':
            self.lora_skip_module.remove('connector')
        if self.training_arguments.tune_type_llm == 'lora':
            self.lora_skip_module.remove('language_model')
        if self.training_arguments.tune_type_vision_tower == 'lora':
            self.lora_skip_module.remove('vision_tower')
        lora_config = LoraConfig(
            r=self.training_arguments.lora_r,
            lora_alpha=self.training_arguments.lora_alpha,
            target_modules=find_all_linear_names(model, self.lora_skip_module),
            lora_dropout=self.training_arguments.lora_dropout,
            bias=self.training_arguments.lora_bias,
            task_type="CAUSAL_LM",
        )
        if self.training_arguments.bits == 16:
            if self.training_arguments.bf16:
                model.to(torch.bfloat16)
            if self.training_arguments.fp16:
                model.to(torch.float16)
        if model.peft_config is None:  # Error raised here!
            log("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)  
        return model
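
As far as I can tell, the model is still a plain nn.Module at this point, and nn.Module.__getattr__ only resolves registered parameters, buffers, and submodules; any other missing attribute raises AttributeError instead of returning None. The peft_config attribute is only attached once the model has been wrapped by get_peft_model, which is exactly what this branch is about to do. A minimal sketch of the behavior (ToyModel is just a stand-in for TinyLlavaForConditionalGeneration):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class ToyModel(nn.Module):  # stand-in for the unwrapped multimodal model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = ToyModel()
print(hasattr(model, 'peft_config'))   # False: nn.Module.__getattr__ raises AttributeError
# print(model.peft_config)             # would raise the same AttributeError as in the traceback

model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, target_modules=['proj']))
print(hasattr(model, 'peft_config'))   # True: attached by the PEFT wrapper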

My bash script is:

DATA_PATH="mydata/text_files/pokemon_blip_captions.json"
IMAGE_PATH="/root/llm-project/TinyLLaVA_Factory/mydata"
MODEL_MAX_LENGTH=3072
OUTPUT_DIR="/root/llm-project/TinyLLaVA_Factory/output/custom-finetune-TinyLLaVA-Phi-2-SigLIP-3.1B-lora"

deepspeed --include localhost:0 --master_port 29501 tinyllava/train/custom_finetune.py \
    --deepspeed ./scripts/zero2.json \
    --data_path  $DATA_PATH \
    --image_folder $IMAGE_PATH \
    --is_multimodal True \
    --conv_version phi \
    --mm_vision_select_layer -2 \
    --image_aspect_ratio square \
    --fp16 True \
    --training_recipe lora \
    --tune_type_llm lora \
    --tune_type_vision_tower frozen \
    --tune_vision_tower_from_layer 0 \
    --tune_type_connector full \
    --lora_r 128 \
    --lora_alpha 256 \
    --group_by_modality_length False \
    --pretrained_model_path "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B" \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 1e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 False \
    --model_max_length $MODEL_MAX_LENGTH \
    --gradient_checkpointing True \
    --dataloader_num_workers 8 \
    --lazy_preprocess True \
    --report_to tensorboard \
    --tokenizer_use_fast False \
    --run_name custom-finetune-TinyLLaVA-Phi-2-SigLIP-3.1B-lora

My environment details are:

  • transformers version: 4.40.1
  • Platform: Linux-5.15.0-120-generic-x86_64-with-glibc2.35
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.3
  • Safetensors version: 0.4.5
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: 1
  • Using distributed or parallel set-up in script?: None

pspdada commented Dec 4, 2024

Simply changing if model.peft_config is None: to if not hasattr(model, 'peft_config') or model.peft_config is None: resolves this issue, and no anomalies were observed during the subsequent fine-tuning.
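
For reference, the end of training_model_converse with that guard in place would look like the sketch below (only the condition changes; everything else is unchanged from the snippet above):

        if self.training_arguments.bits == 16:
            if self.training_arguments.bf16:
                model.to(torch.bfloat16)
            if self.training_arguments.fp16:
                model.to(torch.float16)
        # The unwrapped model has no peft_config attribute until get_peft_model wraps it,
        # so check for its existence before reading it.
        if not hasattr(model, 'peft_config') or model.peft_config is None:
            log("Adding LoRA adapters...")
            model = get_peft_model(model, lora_config)
        return model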

pspdada added a commit to pspdada/TinyLLaVA_Factory that referenced this issue Dec 4, 2024
pspdada linked a pull request Dec 4, 2024 that will close this issue