Merge pull request #287 from EliSchwartz/main
Fixed issue with granite-vision QLORA training and made it the default
stevhliu authored Feb 18, 2025
2 parents fa40159 + fb4aa65 commit 9349385
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions notebooks/en/fine_tuning_granite_vision_sft_trl.ipynb
@@ -78,7 +78,7 @@
 "!pip install -q flash-attn --no-build-isolation\n",
 "\n",
 "try:\n",
-"    from flash_attn.flash_attention import FlashAttention\n",
+"    import flash_attn\n",
 "    print(\"FlashAttention is installed\")\n",
 "    USE_FLASH_ATTENTION = True\n",
 "except ImportError:\n",
@@ -639,16 +639,18 @@
 "source": [
 "from transformers import BitsAndBytesConfig\n",
 "\n",
-"USE_QLORA = False\n",
-"USE_LORA = False\n",
+"USE_QLORA = True\n",
+"USE_LORA = True\n",
 "\n",
 "if USE_QLORA:\n",
 "    # BitsAndBytesConfig int-4 config\n",
 "    bnb_config = BitsAndBytesConfig(\n",
 "        load_in_4bit=True,\n",
 "        bnb_4bit_use_double_quant=True,\n",
 "        bnb_4bit_quant_type=\"nf4\",\n",
-"        bnb_4bit_compute_dtype=torch.bfloat16\n",
+"        bnb_4bit_compute_dtype=torch.bfloat16,\n",
+"        llm_int8_skip_modules=[\"vision_tower\", \"lm_head\"],  # Skip problematic modules\n",
+"        llm_int8_enable_fp32_cpu_offload=True\n",
 "    )\n",
 "else:\n",
 "    bnb_config = None\n",
@@ -693,7 +695,6 @@
 "    r=8,\n",
 "    lora_alpha=8,\n",
 "    lora_dropout=0.1,\n",
-"    # target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],\n",
 "    target_modules=[name for name, _ in model.named_modules() if 'language_model' in name and '_proj' in name],\n",
 "    use_dora=True,\n",
 "    init_lora_weights=\"gaussian\"\n",
@@ -1052,7 +1053,8 @@
 "outputs": [],
 "source": [
 "if USE_LORA:\n",
-"    model = model.merge_and_unload().to(torch.bfloat16)"
+"    from peft import PeftModel\n",
+"    model = PeftModel.from_pretrained(model, training_args.output_dir)"
 ]
 },
 {
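The last hunk stops merging the adapter in place (merge_and_unload on a 4-bit quantized base is problematic, which is presumably part of the QLoRA issue being fixed) and instead reloads the trained adapter from disk. A sketch of the post-training flow, assuming the trainer has already saved the adapter to training_args.output_dir:

from peft import PeftModel

if USE_LORA:
    # Attach the trained LoRA adapter saved by the trainer to the base model.
    model = PeftModel.from_pretrained(model, training_args.output_dir)
    # To export a standalone checkpoint later, merge on an unquantized base:
    # model = model.merge_and_unload()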
