QLoRA ์ค์ต with MLLMs(InternVL)
Step 1. ํ์ Library import:
import os import torch import torch.nn as nn import bitsandbytes as bnb import transformers from peft import ( LoraConfig, PeftConfig, PeftModel, get_peft_model, ) from transformers import ( AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed, pipeline, TrainingArguments, )โ
Step 2. ๋ชจ๋ธ ๋ถ๋ฌ์จ ํ prepare_model_for_kbit_training(model) ์งํ
devices = [0]#[0, 3] max_memory = {i: '49140MiB' for i in devices} model_name = 'OpenGVLab/InternVL2-8B' model = AutoModelForCausalLM.from_pretrained( model_name, cache_dir='/data/huggingface_models', trust_remote_code=True, device_map="auto", max_memory=max_memory, quantization_config=BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' ), ) # ๋ชจ๋ธ ๊ตฌ์กฐ ์ถ๋ ฅ print(model) # get_input_embeddings ๋ฉ์๋๋ฅผ ๋ชจ๋ธ์ ์ถ๊ฐ def get_input_embeddings(self): if hasattr(self, 'embed_tokens'): return self.embed_tokens elif hasattr(self, 'language_model') and hasattr(self.language_model.model, 'tok_embeddings'): return self.language_model.model.tok_embeddings else: raise NotImplementedError("The model does not have an attribute 'embed_tokens' or 'language_model.model.tok_embeddings'.") model.get_input_embeddings = get_input_embeddings.__get__(model, type(model)) # prepare_model_for_kbit_training ํจ์๋ฅผ ์ง์ ๊ตฌํ def prepare_model_for_kbit_training(model): for param in model.parameters(): param.requires_grad = False # ๋ชจ๋ ํ๋ผ๋ฏธํฐ์ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ์ ๋นํ์ฑํ if hasattr(model, 'model') and hasattr(model.model, 'tok_embeddings'): for param in model.model.tok_embeddings.parameters(): param.requires_grad = True # ์๋ฒ ๋ฉ ๋ ์ด์ด๋ง ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ํ์ฑํ elif hasattr(model, 'embed_tokens'): for param in model.embed_tokens.parameters(): param.requires_grad = True # ์๋ฒ ๋ฉ ๋ ์ด์ด๋ง ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ํ์ฑํ # ํ์ํ ๊ฒฝ์ฐ ๋ค๋ฅธ ํน์ ๋ ์ด์ด๋ค๋ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ์ ํ์ฑํํ ์ ์์ # ์์: # if hasattr(model, 'some_other_layer'): # for param in model.some_other_layer.parameters(): # param.requires_grad = True return model model = prepare_model_for_kbit_training(model)โ
Step 3. QLoRA๋ฅผ ๋ถ์ผ layer ์ ํ:
def find_all_linear_names(model, train_mode): assert train_mode in ['lora', 'qlora'] cls = bnb.nn.Linear4bit if train_mode == 'qlora' else nn.Linear lora_module_names = set() for name, module in model.named_modules(): if isinstance(module, cls): names = name.split('.') lora_module_names.add(names[0] if len(names) == 1 else names[-1]) if 'lm_head' in lora_module_names: # LLM์ Head๋ถ๋ถ์ ์ํ๋ ์ ๋ค pass lora_module_names.remove('lm_head') return list(lora_module_names) print(sorted(config.target_modules)) # ['1','output', 'w1', 'w2', 'w3', 'wo', 'wqkv'] config.target_modules.remove('1') # LLM์ Head๋ถ๋ถ์ ์ํ๋ ์ ๋ค ์ ๊ฑฐ config = LoraConfig( r=16, lora_alpha=16, target_modules=find_all_linear_names(model, 'qlora'), lora_dropout=0.05, bias="none", task_type="QUESTION_ANS" #CAUSAL_LM, FEATURE_EXTRACTION, QUESTION_ANS, SEQ_2_SEQ_LM, SEQ_CLS, TOKEN_CLS. ) model = get_peft_model(model, config)
์ดํ trainer๋ก train์งํ.
QLoRA ๋ถ์ธ ๊ฒฐ๊ณผ:
trainer ์ข ๋ฅ? Trainer vs SFTTrainer
Trainer v.s. SFTTrainer
โ Trainer v.s. SFTTrainer
- ์ผ๋ฐ ๋ชฉ์ ์ ํ๋ จ: ํ ์คํธ ๋ถ๋ฅ, ์ง์์๋ต, ์์ฝ ๋ฑ์ ์ง๋ ํ์ต ์์ ์์ ๋ชจ๋ธ์ ์ฒ์๋ถํฐ ํ๋ จ์ํค๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
- ๋์ ์ปค์คํฐ๋ง์ด์ง ๊ฐ๋ฅ์ฑ: hyperparameter, optimizer, scheduler, logging, metric ๋ฑ์ ๋ฏธ์ธ ์กฐ์ ํ ์ ์๋ ๋ค์ํ ๊ตฌ์ฑ ์ต์ ์ ์ ๊ณตํฉ๋๋ค.
- ๋ณต์กํ ํ๋ จ ์ํฌํ๋ก์ฐ ์ฒ๋ฆฌ: ๊ทธ๋๋์ธํธ ์ถ์ , ์กฐ๊ธฐ ์ข ๋ฃ, ์ฒดํฌํฌ์ธํธ ์ ์ฅ, ๋ถ์ฐ ํ๋ จ ๋ฑ์ ๊ธฐ๋ฅ์ ์ง์ํฉ๋๋ค.
- ๋ ๋ง์ ๋ฐ์ดํฐ ์๊ตฌ: ํจ๊ณผ์ ์ธ ํ๋ จ์ ์ํด ์ผ๋ฐ์ ์ผ๋ก ๋ ํฐ ๋ฐ์ดํฐ์ ์ด ํ์ํฉ๋๋ค.
โ SFTTrainer
- ์ง๋ ํ์ต ๋ฏธ์ธ ์กฐ์ (SFT): ์์ ๋ฐ์ดํฐ์ ์ผ๋ก PLMs Fine-Tuning์ ์ต์ ํ.
- ๊ฐ๋จํ ์ธํฐํ์ด์ค: ๋ ์ ์ configuration์ผ๋ก ๊ฐ์ํ๋ workflow๋ฅผ ์ ๊ณต.
- ํจ์จ์ ์ธ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ: PEFT์ ํจํน ์ต์ ํ์ ๊ฐ์ ๊ธฐ์ ์ ์ฌ์ฉํ์ฌ ํ๋ จ ์ค ๋ฉ๋ชจ๋ฆฌ ์๋น๋ฅผ ์ค์ ๋๋ค.
- ๋น ๋ฅธ ํ๋ จ: ์์ ๋ฐ์ดํฐ์ ๊ณผ ์งง์ ํ๋ จ ์๊ฐ์ผ๋ก๋ ์ ์ฌํ๊ฑฐ๋ ๋ ๋์ ์ ํ๋๋ฅผ ๋ฌ์ฑํฉ๋๋ค.
โ Trainer์ SFTTrainer ์ ํ ๊ธฐ์ค:
- Trainer ์ฌ์ฉ:
ํฐ ๋ฐ์ดํฐ์ ์ด ์๊ณ , ํ๋ จ ๋ฃจํ ๋๋ ๋ณต์กํ ํ๋ จ ์ํฌํ๋ก์ฐ์ ๋ํ ๊ด๋ฒ์ํ ์ปค์คํฐ๋ง์ด์ง์ด ํ์ํ ๊ฒฝ์ฐ.
Data preprocessing, Datacollator๋ ์ฌ์ฉ์๊ฐ ์ง์ ์ค์ ํด์ผ ํ๋ฉฐ, ์ผ๋ฐ์ ์ธ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ๋ฐฉ๋ฒ์ ์ฌ์ฉ
- SFTTrainer ์ฌ์ฉ:
PLMS์ ์๋์ ์ผ๋ก ์์ ๋ฐ์ดํฐ์ ์ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ, ํจ์จ์ ์ธ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๊ณผ ํจ๊ป ๋ ๊ฐ๋จํ๊ณ ๋น ๋ฅธ ๋ฏธ์ธ ์กฐ์ ๊ฒฝํ์ ์ํ ๊ฒฝ์ฐ.
PEFT๋ฅผ ๊ธฐ๋ณธ์ ์ผ๋ก ์ง์, `peft_config`์ ๊ฐ์ ์ค์ ์ ํตํด ํจ์จ์ ์ธ ํ์ธ ํ๋์ ์ฝ๊ฒ ์ค์ ํ ์ ์๋ค.
Data preprocessing, Datacollator๋ ํจ์จ์ ์ธ FT๋ฅผ ์ํด ์ต์ ํ๋์ด ์์.
`dataset_text_field`์ ๊ฐ์ ํ๋๋ฅผ ํตํด ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์ฝ๊ฒ ์ฒ๋ฆฌํ ์ ์์.
Feature Trainer SFTTrainer ๋ชฉ์ Gerneral Purpose training Supervised Fine-Tuning of PLMs ์ปค์คํ ์ฉ๋ Highly Customizable Simpler interface with fewer options Training workflow Handles complex workflows Streamlined workflow ํ์ Data Large Datsets Smaller Datasets Memory ์ฌ์ฉ๋ Higher Lower with PEFT & packing optimization Training speed Slower Faster with smaller datasets
'HuggingFace๐ค' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
HuggingFace(๐ค)-Tutorials (1) | 2024.07.31 |
---|---|
[Data Preprocessing] - Data Collator (1) | 2024.07.14 |
[QLoRA] & [PEFT] & deepspeed, DDP (0) | 2024.07.09 |