QLoRA ์‹ค์Šต with MLLMs(InternVL)

Step 1. Import the required libraries:

import os

import torch
import torch.nn as nn
import bitsandbytes as bnb
import transformers

from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel, 
    get_peft_model,
)
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed,
    pipeline,
    TrainingArguments,
)


Step 2. Load the model, then run prepare_model_for_kbit_training(model)

devices = [0]  # e.g. [0, 3] to spread the model across multiple GPUs
max_memory = {i: '49140MiB' for i in devices}  # per-GPU memory budget passed along with device_map="auto"

model_name = 'OpenGVLab/InternVL2-8B'


model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    cache_dir='/data/huggingface_models',
    trust_remote_code=True,
    device_map="auto",
    max_memory=max_memory,
    quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,                       # load weights in 4-bit (QLoRA)
            bnb_4bit_compute_dtype=torch.bfloat16,   # compute in bf16
            bnb_4bit_use_double_quant=True,          # double quantization of the quantization constants
            bnb_4bit_quant_type='nf4'                # NormalFloat4 quantization
        ),
)
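
AutoTokenizer is imported in Step 1 but not used in the snippet above; when running the full pipeline, the matching tokenizer is typically loaded from the same repository. A minimal sketch, reusing the same cache directory as above:

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir='/data/huggingface_models',
    trust_remote_code=True,   # InternVL2 ships custom tokenizer code, like the model
)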

# ๋ชจ๋ธ ๊ตฌ์กฐ ์ถœ๋ ฅ
print(model)

# Add a get_input_embeddings method to the model
def get_input_embeddings(self):
    if hasattr(self, 'embed_tokens'):
        return self.embed_tokens
    elif hasattr(self, 'language_model') and hasattr(self.language_model.model, 'tok_embeddings'):
        return self.language_model.model.tok_embeddings
    else:
        raise NotImplementedError("The model does not have an attribute 'embed_tokens' or 'language_model.model.tok_embeddings'.")

model.get_input_embeddings = get_input_embeddings.__get__(model, type(model))
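
A quick sanity check (not part of the original snippet) can confirm that the patched method resolves to the language model's token embedding layer:

emb = model.get_input_embeddings()
print(type(emb), tuple(emb.weight.shape))  # expect an nn.Embedding-like module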

# Implement prepare_model_for_kbit_training manually
def prepare_model_for_kbit_training(model):
    for param in model.parameters():
        param.requires_grad = False  # disable gradient computation for all parameters

    if hasattr(model, 'model') and hasattr(model.model, 'tok_embeddings'):
        for param in model.model.tok_embeddings.parameters():
            param.requires_grad = True  # enable gradients only for the embedding layer
    elif hasattr(model, 'embed_tokens'):
        for param in model.embed_tokens.parameters():
            param.requires_grad = True  # enable gradients only for the embedding layer

    # Gradients can also be enabled for other specific layers if needed, e.g.:
    # if hasattr(model, 'some_other_layer'):
    #     for param in model.some_other_layer.parameters():
    #         param.requires_grad = True

    return model

model = prepare_model_for_kbit_training(model)
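
For reference, peft ships a helper with the same name; on models that expose the standard embedding accessors (which is what the get_input_embeddings patch above provides), the library version can be used instead of the manual one. A minimal sketch, assuming the installed peft version supports it:

from peft import prepare_model_for_kbit_training as peft_prepare_kbit

# Library variant: freezes the base weights and sets up the input-gradient /
# gradient-checkpointing hooks needed for k-bit training.
# model = peft_prepare_kbit(model, use_gradient_checkpointing=True)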


Step 3. Select the layers to attach QLoRA to:

def find_all_linear_names(model, train_mode):
    assert train_mode in ['lora', 'qlora']
    cls = bnb.nn.Linear4bit if train_mode == 'qlora' else nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # skip modules belonging to the LM head
        lora_module_names.remove('lm_head')
    
    return list(lora_module_names)


config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=find_all_linear_names(model, 'qlora'),
    lora_dropout=0.05,
    bias="none",
    task_type="QUESTION_ANS"  # CAUSAL_LM, FEATURE_EXTRACTION, QUESTION_ANS, SEQ_2_SEQ_LM, SEQ_CLS, TOKEN_CLS
)

print(sorted(config.target_modules))  # ['1', 'output', 'w1', 'w2', 'w3', 'wo', 'wqkv']
config.target_modules.remove('1')  # remove modules that belong to the LM head part


model = get_peft_model(model, config)

์ดํ›„ trainer๋กœ train์ง„ํ–‰.

QLoRA ๋ถ™์ธ ๊ฒฐ๊ณผ:

trainer ์ข…๋ฅ˜? Trainer vs SFTTrainer

Trainer  v.s. SFTTrainer

โˆ™ Trainer  v.s. SFTTrainer

 - ์ผ๋ฐ˜ ๋ชฉ์ ์˜ ํ›ˆ๋ จ: ํ…์ŠคํŠธ ๋ถ„๋ฅ˜, ์งˆ์˜์‘๋‹ต, ์š”์•ฝ ๋“ฑ์˜ ์ง€๋„ ํ•™์Šต ์ž‘์—…์—์„œ ๋ชจ๋ธ์„ ์ฒ˜์Œ๋ถ€ํ„ฐ ํ›ˆ๋ จ์‹œํ‚ค๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
 - ๋†’์€ ์ปค์Šคํ„ฐ๋งˆ์ด์ง• ๊ฐ€๋Šฅ์„ฑ: hyperparameter, optimizer, scheduler, logging, metric ๋“ฑ์„ ๋ฏธ์„ธ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ๋Š” ๋‹ค์–‘ํ•œ ๊ตฌ์„ฑ ์˜ต์…˜์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
 - ๋ณต์žกํ•œ ํ›ˆ๋ จ ์›Œํฌํ”Œ๋กœ์šฐ ์ฒ˜๋ฆฌ: ๊ทธ๋ž˜๋””์–ธํŠธ ์ถ•์ , ์กฐ๊ธฐ ์ข…๋ฃŒ, ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ, ๋ถ„์‚ฐ ํ›ˆ๋ จ ๋“ฑ์˜ ๊ธฐ๋Šฅ์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.
 - ๋” ๋งŽ์€ ๋ฐ์ดํ„ฐ ์š”๊ตฌ: ํšจ๊ณผ์ ์ธ ํ›ˆ๋ จ์„ ์œ„ํ•ด ์ผ๋ฐ˜์ ์œผ๋กœ ๋” ํฐ ๋ฐ์ดํ„ฐ์…‹์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.



โˆ™ SFTTrainer

 - ์ง€๋„ ํ•™์Šต ๋ฏธ์„ธ ์กฐ์ • (SFT): ์ž‘์€ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ PLMs Fine-Tuning์— ์ตœ์ ํ™”.
 - ๊ฐ„๋‹จํ•œ ์ธํ„ฐํŽ˜์ด์Šค: ๋” ์ ์€ configuration์œผ๋กœ ๊ฐ„์†Œํ™”๋œ workflow๋ฅผ ์ œ๊ณต.
 - ํšจ์œจ์ ์ธ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ: PEFT์™€ ํŒจํ‚น ์ตœ์ ํ™”์™€ ๊ฐ™์€ ๊ธฐ์ˆ ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ›ˆ๋ จ ์ค‘ ๋ฉ”๋ชจ๋ฆฌ ์†Œ๋น„๋ฅผ ์ค„์ž…๋‹ˆ๋‹ค.
 - ๋น ๋ฅธ ํ›ˆ๋ จ: ์ž‘์€ ๋ฐ์ดํ„ฐ์…‹๊ณผ ์งง์€ ํ›ˆ๋ จ ์‹œ๊ฐ„์œผ๋กœ๋„ ์œ ์‚ฌํ•˜๊ฑฐ๋‚˜ ๋” ๋‚˜์€ ์ •ํ™•๋„๋ฅผ ๋‹ฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.



โˆ™ Trainer์™€ SFTTrainer ์„ ํƒ ๊ธฐ์ค€:

 - Trainer ์‚ฌ์šฉ:
ํฐ ๋ฐ์ดํ„ฐ์…‹์ด ์žˆ๊ณ , ํ›ˆ๋ จ ๋ฃจํ”„ ๋˜๋Š” ๋ณต์žกํ•œ ํ›ˆ๋ จ ์›Œํฌํ”Œ๋กœ์šฐ์— ๋Œ€ํ•œ ๊ด‘๋ฒ”์œ„ํ•œ ์ปค์Šคํ„ฐ๋งˆ์ด์ง•์ด ํ•„์š”ํ•œ ๊ฒฝ์šฐ.
Data preprocessing, Datacollator๋Š” ์‚ฌ์šฉ์ž๊ฐ€ ์ง์ ‘ ์„ค์ •ํ•ด์•ผ ํ•˜๋ฉฐ, ์ผ๋ฐ˜์ ์ธ ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉ

 - SFTTrainer ์‚ฌ์šฉ:
PLMS์™€ ์ƒ๋Œ€์ ์œผ๋กœ ์ž‘์€ ๋ฐ์ดํ„ฐ์…‹์„ ๊ฐ€์ง€๊ณ  ์žˆ์œผ๋ฉฐ, ํšจ์œจ์ ์ธ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๊ณผ ํ•จ๊ป˜ ๋” ๊ฐ„๋‹จํ•˜๊ณ  ๋น ๋ฅธ ๋ฏธ์„ธ ์กฐ์ • ๊ฒฝํ—˜์„ ์›ํ•  ๊ฒฝ์šฐ.
PEFT๋ฅผ ๊ธฐ๋ณธ์ ์œผ๋กœ ์ง€์›, `peft_config`์™€ ๊ฐ™์€ ์„ค์ •์„ ํ†ตํ•ด ํšจ์œจ์ ์ธ ํŒŒ์ธ ํŠœ๋‹์„ ์‰ฝ๊ฒŒ ์„ค์ •ํ•  ์ˆ˜ ์žˆ๋‹ค.
Data preprocessing, Datacollator๋„ ํšจ์œจ์ ์ธ FT๋ฅผ ์œ„ํ•ด ์ตœ์ ํ™”๋˜์–ด ์žˆ์Œ.
`dataset_text_field`์™€ ๊ฐ™์€ ํ•„๋“œ๋ฅผ ํ†ตํ•ด ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ์‰ฝ๊ฒŒ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ์Œ.



Feature              Trainer                       SFTTrainer
Purpose              General-purpose training      Supervised fine-tuning of PLMs
Customization        Highly customizable           Simpler interface with fewer options
Training workflow    Handles complex workflows     Streamlined workflow
Required data        Large datasets                Smaller datasets
Memory usage         Higher                        Lower with PEFT & packing optimization
Training speed       Slower                        Faster with smaller datasets
