Collate: to gather and combine.

์ด์—์„œ ์œ ์ถ”๊ฐ€๋Šฅํ•˜๋“ฏ, Data Collator๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์—ญํ• ์„ ์ˆ˜ํ–‰ํ•œ๋‹ค.

 

 

Data Collator

Data Collator

์ผ๋ จ์˜ sample list๋ฅผ "single training mini-batch"์˜ Tensorํ˜•ํƒœ๋กœ ๋ฌถ์–ด์คŒ
Default Data Collator
As shown below, the data_collator bundles train_dataset samples into mini-batches that are fed into the model for training:
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
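
For reference, here is a minimal sketch of what a collator produces, using transformers' built-in default_data_collator (the feature values below are made up for illustration):

from transformers import default_data_collator

# Two samples that already have the same length (the default collator does not pad):
features = [
    {"input_ids": [101, 7592, 102], "label": 0},
    {"input_ids": [101, 2088, 102], "label": 1},
]

batch = default_data_collator(features)
print(batch["input_ids"].shape)
# Output: torch.Size([2, 3])
print(batch["labels"])
# Output: tensor([0, 1])  <-- the "label" key is renamed to "labels"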





batch["input_ids"] , batch["labels"] ?

However, unlike the above, most Data Collator functions look like the code below, where two somewhat unfamiliar names appear: input_ids and labels:
class MyDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image, question, answer = example
            messages = [{"role": "user", "content": question},
                        {"role": "assistant", "content": answer}]  # <-- verified that the data arrives intact up to here.
            text = self.processor.tokenizer.apply_chat_template(messages, add_generation_prompt=False)
            texts.append(text)
            images.append(image)

        # Pass the accumulated lists, not the loop variables.
        batch = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch

data_collator = MyDataCollator(processor)
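
The key trick in this collator is the -100 masking. A toy illustration (the ID values are hypothetical, with pad_token_id assumed to be 0):

import torch

input_ids = torch.tensor([[101, 7592, 102, 0, 0]])  # assume pad_token_id == 0

labels = input_ids.clone()
labels[labels == 0] = -100  # PyTorch's CrossEntropyLoss ignores targets equal to -100
print(labels)
# Output: tensor([[ 101, 7592,  102, -100, -100]])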

So what exactly are batch["input_ids"] and batch["labels"]?

์ „์ˆ ํ–ˆ๋˜ data_collator๋Š” ์•„๋ž˜์™€ ๊ฐ™์€ ํ˜•์‹์„ ๋ ๋Š”๋ฐ, ์—ฌ๊ธฐ์„œ๋„ ๋ณด๋ฉด inputs์™€ labels๊ฐ€ ์žˆ๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค.

๋ชจ๋“  ๋ชจ๋ธ์€ ๋‹ค๋ฅด์ง€๋งŒ, ๋‹ค๋ฅธ๋ชจ๋ธ๊ณผ ์œ ์‚ฌํ•œ์ ์„ ๊ณต์œ ํ•œ๋‹ค
= ๋Œ€๋ถ€๋ถ„์˜ ๋ชจ๋ธ์€ ๋™์ผํ•œ ์ž…๋ ฅ์„ ์‚ฌ์šฉํ•œ๋‹ค!

โˆ™Input IDs

Input ID๋Š” ๋ชจ๋ธ์— ์ž…๋ ฅ์œผ๋กœ ์ „๋‹ฌ๋˜๋Š” "์œ ์ผํ•œ ํ•„์ˆ˜ ๋งค๊ฐœ๋ณ€์ˆ˜"์ธ ๊ฒฝ์šฐ๊ฐ€ ๋งŽ๋‹ค.
Input ID๋Š” token_index๋กœ, ์‚ฌ์šฉํ•  sequence(๋ฌธ์žฅ)๋ฅผ ๊ตฌ์„ฑํ•˜๋Š” token์˜ ์ˆซ์žํ‘œํ˜„์ด๋‹ค.
๊ฐ tokenizer๋Š” ๋‹ค๋ฅด๊ฒŒ ์ž‘๋™ํ•˜์ง€๋งŒ "๊ธฐ๋ณธ ๋ฉ”์ปค๋‹ˆ์ฆ˜์€ ๋™์ผ"ํ•˜๋‹ค.

ex)

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

sequence = "A Titan RTX has 24GB of VRAM"


The tokenizer splits the sequence (sentence) into tokens present in the tokenizer's vocab:

tokenized_sequence = tokenizer.tokenize(sequence)


token์€ word๋‚˜ subword ๋‘˜์ค‘ ํ•˜๋‚˜์ด๋‹ค:

print(tokenized_sequence)
# Output: ['A', 'Titan', 'R', '##T', '##X', 'has', '24', '##GB', 'of', 'V', '##RA', '##M']
# For example, "VRAM" is not in the model vocabulary, so it is split into "V", "RA", and "M".
# How do we indicate that these tokens are parts of the same word rather than separate words?
# --> A double-hash (##) prefix is added in front of "RA" and "M".


inputs = tokenizer(sequence)


This converts the tokens into IDs the model can understand.
The tokenizer returns a dictionary in which the "input_ids" key maps to these ID values; this is the form the model works with internally:

encoded_sequence = inputs["input_ids"]
print(encoded_sequence)
# Output: [101, 138, 18696, 155, 1942, 3190, 1144, 1572, 13745, 1104, 159, 9664, 2107, 102]

๋˜ํ•œ, ๋ชจ๋ธ์— ๋”ฐ๋ผ์„œ ์ž๋™์œผ๋กœ "special token"์„ ์ถ”๊ฐ€ํ•˜๋Š”๋ฐ, 
์—ฌ๊ธฐ์—๋Š” ๋ชจ๋ธ์ด ๊ฐ€๋” ์‚ฌ์šฉํ•˜๋Š” "special IDs"๊ฐ€ ์ถ”๊ฐ€๋œ๋‹ค.

decoded_sequence = tokenizer.decode(encoded_sequence)
print(decoded_sequence)
# Output: [CLS] A Titan RTX has 24GB of VRAM [SEP]
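
As a quick check (not part of the original snippet), passing add_special_tokens=False suppresses this automatic insertion:

inputs_no_special = tokenizer(sequence, add_special_tokens=False)
print(tokenizer.decode(inputs_no_special["input_ids"]))
# Output: A Titan RTX has 24GB of VRAM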





โˆ™Attention Mask

The Attention Mask is an optional argument used when batching sequences together;
it indicates which tokens the model should attend to and which it should ignore.

ex)
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

sequence_a = "This is a short sequence."
sequence_b = "This is a rather long sequence. It is at least longer than the sequence A."

encoded_sequence_a = tokenizer(sequence_a)["input_ids"]
encoded_sequence_b = tokenizer(sequence_b)["input_ids"]

len(encoded_sequence_a), len(encoded_sequence_b)
# Output: (8, 19)
Looking at the above, the encoded lengths differ, so the two sequences cannot be stacked into a single Tensor as-is.
--> padding or truncation is needed.
padded_sequences = tokenizer([sequence_a, sequence_b], padding=True)

padded_sequences["input_ids"]
# Output: [[101, 1188, 1110, 170, 1603, 4954, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1188, 1110, 170, 1897, 1263, 4954, 119, 1135, 1110, 1120, 1655, 2039, 1190, 1103, 4954, 138, 119, 102]]

padded_sequences["attention_mask"]
# Output: [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
The attention_mask lives under the "attention_mask" key of the dictionary returned by the tokenizer.
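
As a minimal sketch of how the mask is consumed (reusing the tokenizer above, with return_tensors="pt" added so the model receives tensors):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased")
inputs = tokenizer([sequence_a, sequence_b], padding=True, return_tensors="pt")

# The mask tells the model to ignore the pad positions of the shorter sequence.
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"])
print(outputs.last_hidden_state.shape)
# Output: torch.Size([2, 19, 768])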


โˆ™Token Type IDs

์–ด๋–ค ๋ชจ๋ธ์˜ ๋ชฉ์ ์€ classification์ด๋‚˜ QA์ด๋‹ค.
์ด๋Ÿฐ ๋ชจ๋ธ์€ 2๊ฐœ์˜ "๋‹ค๋ฅธ ๋ชฉ์ ์„ ๋‹จ์ผ input_ids"ํ•ญ๋ชฉ์œผ๋กœ ๊ฒฐํ•ฉํ•ด์•ผํ•œ๋‹ค.
= [CLS], [SEP] ๋“ฑ์˜ ํŠน์ˆ˜ํ† ํฐ์„ ์ด์šฉํ•ด ์ˆ˜ํ–‰๋จ.

ex)
# [CLS] SEQUENCE_A [SEP] SEQUENCE_B [SEP]

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
sequence_a = "HuggingFace is based in NYC"
sequence_b = "Where is HuggingFace based?"

encoded_dict = tokenizer(sequence_a, sequence_b)
decoded = tokenizer.decode(encoded_dict["input_ids"])

print(decoded)
# Output: [CLS] HuggingFace is based in NYC [SEP] Where is HuggingFace based? [SEP]
์œ„์˜ ์˜ˆ์ œ์—์„œ tokenizer๋ฅผ ์ด์šฉํ•ด 2๊ฐœ์˜ sequence๊ฐ€ 2๊ฐœ์˜ ์ธ์ˆ˜๋กœ ์ „๋‹ฌ๋˜์–ด ์ž๋™์œผ๋กœ ์œ„์™€๊ฐ™์€ ๋ฌธ์žฅ์„ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค.
์ด๋Š” seq์ดํ›„์— ๋‚˜์˜ค๋Š” seq์˜ ์‹œ์ž‘์œ„์น˜๋ฅผ ์•Œ๊ธฐ์—๋Š” ์ข‹๋‹ค.

Some models, however, also use token_type_ids and return this mask under the token_type_ids key.
encoded_dict['token_type_ids']
# Output: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

You can see that the first sequence, the context used for the question, is marked entirely with 0s,
while the second sequence, corresponding to the question itself, is marked entirely with 1s.


โˆ™Position IDs

RNN: ๊ฐ ํ† ํฐ์˜ ์œ„์น˜๊ฐ€ ๋‚ด์žฅ. 
Transformer: ๊ฐ ํ† ํฐ์˜ ์œ„์น˜๋ฅผ ์ธ์‹ โŒ


∴ position_ids๋Š” ๋ชจ๋ธ์ด ๊ฐ ํ† ํฐ์˜ ์œ„์น˜๋ฅผ list์—์„œ ์‹๋ณ„ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” optional ๋งค๊ฐœ๋ณ€์ˆ˜.

๋ชจ๋ธ์— position_ids๊ฐ€ ์ „๋‹ฌ๋˜์ง€ ์•Š์œผ๋ฉด, ID๋Š” ์ž๋™์œผ๋กœ Absolute positional embeddings์œผ๋กœ ์ƒ์„ฑ:

Absolute positional embeddings์€ [0, config.max_position_embeddings - 1] ๋ฒ”์œ„์—์„œ ์„ ํƒ.

์ผ๋ถ€ ๋ชจ๋ธ์€ sinusoidal position embeddings์ด๋‚˜ relative position embeddings๊ณผ ๊ฐ™์€ ๋‹ค๋ฅธ ์œ ํ˜•์˜ positional embedding์„ ์‚ฌ์šฉ.




โˆ™Labels 

Labels๋Š” ๋ชจ๋ธ์ด ์ž์ฒด์ ์œผ๋กœ ์†์‹ค์„ ๊ณ„์‚ฐํ•˜๋„๋ก ์ „๋‹ฌ๋  ์ˆ˜ ์žˆ๋Š” Optional์ธ์ˆ˜์ด๋‹ค.
์ฆ‰, Labels๋Š” ๋ชจ๋ธ์˜ ์˜ˆ์ƒ ์˜ˆ์ธก๊ฐ’์ด์–ด์•ผ ํ•œ๋‹ค: ํ‘œ์ค€ ์†์‹ค์„ ์‚ฌ์šฉํ•˜์—ฌ ์˜ˆ์ธก๊ฐ’๊ณผ ์˜ˆ์ƒ๊ฐ’(๋ ˆ์ด๋ธ”) ๊ฐ„์˜ ์†์‹ค์„ ๊ณ„์‚ฐ.


The required labels depend on the model head:

  • AutoModelForSequenceClassification: the model expects a tensor of dimension (batch_size), where each value of the batch corresponds to the expected label of the entire sequence.

  • AutoModelForTokenClassification: the model expects a tensor of dimension (batch_size, seq_length), where each value corresponds to the expected label of each individual token.

  • AutoModelForMaskedLM: the model expects a tensor of dimension (batch_size, seq_length), where each value corresponds to the expected label of each individual token: the labels are the token IDs of the masked tokens, with values to be ignored everywhere else (usually -100).

  • AutoModelForConditionalGeneration: the model expects a tensor of dimension (batch_size, tgt_seq_length), where each value represents the target sequence associated with each input sequence. During training, BART and T5 build the appropriate decoder input IDs and decoder attention masks internally, so these usually do not need to be supplied. This does not apply to models leveraging the Encoder-Decoder framework. See each model's documentation for details on its specific labels.

๊ธฐ๋ณธ ๋ชจ๋ธ(BertModel ๋“ฑ)์€ Labels๋ฅผ ๋ฐ›์•„๋“ค์ด์ง€ ๋ชปํ•˜๋Š”๋ฐ, ์ด๋Ÿฌํ•œ ๋ชจ๋ธ์€ ๊ธฐ๋ณธ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ๋กœ์„œ ๋‹จ์ˆœํžˆ ํŠน์ง•๋“ค๋งŒ ์ถœ๋ ฅํ•œ๋‹ค.




โˆ™ Decoder input IDs

์ด ์ž…๋ ฅ์€ ์ธ์ฝ”๋”-๋””์ฝ”๋” ๋ชจ๋ธ์— ํŠนํ™”๋˜์–ด ์žˆ์œผ๋ฉฐ, ๋””์ฝ”๋”์— ์ž…๋ ฅ๋  ์ž…๋ ฅ ID๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
์ด๋Ÿฌํ•œ ์ž…๋ ฅ์€ ๋ฒˆ์—ญ ๋˜๋Š” ์š”์•ฝ๊ณผ ๊ฐ™์€ ์‹œํ€€์Šค-ํˆฌ-์‹œํ€€์Šค ์ž‘์—…์— ์‚ฌ์šฉ๋˜๋ฉฐ, ๋ณดํ†ต ๊ฐ ๋ชจ๋ธ์— ํŠน์ •ํ•œ ๋ฐฉ์‹์œผ๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค.

๋Œ€๋ถ€๋ถ„์˜ ์ธ์ฝ”๋”-๋””์ฝ”๋” ๋ชจ๋ธ(BART, T5)์€ ๋ ˆ์ด๋ธ”์—์„œ ๋””์ฝ”๋” ์ž…๋ ฅ ID๋ฅผ ์ž์ฒด์ ์œผ๋กœ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
์ด๋Ÿฌํ•œ ๋ชจ๋ธ์—์„œ๋Š” ๋ ˆ์ด๋ธ”์„ ์ „๋‹ฌํ•˜๋Š” ๊ฒƒ์ด ํ›ˆ๋ จ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ์„ ํ˜ธ ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.

์‹œํ€€์Šค-ํˆฌ-์‹œํ€€์Šค ํ›ˆ๋ จ์„ ์œ„ํ•œ ์ด๋Ÿฌํ•œ ์ž…๋ ฅ ID๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ํ™•์ธํ•˜๋ ค๋ฉด ๊ฐ ๋ชจ๋ธ์˜ ๋ฌธ์„œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.
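
A minimal sketch with T5 (model name and strings chosen for illustration): passing labels lets the model build the decoder inputs internally, and prepare_decoder_input_ids_from_labels exposes what it builds:

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: Hello", return_tensors="pt").input_ids
labels = tokenizer("Hallo", return_tensors="pt").input_ids

outputs = model(input_ids=input_ids, labels=labels)  # decoder_input_ids built internally

# The decoder inputs are the labels shifted right, starting from the decoder start token.
print(model.prepare_decoder_input_ids_from_labels(labels))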



โˆ™Feed Forward Chunking

ํŠธ๋žœ์Šคํฌ๋จธ์˜ ๊ฐ ์ž”์ฐจ ์–ดํ…์…˜ ๋ธ”๋ก์—์„œ ์…€ํ”„ ์–ดํ…์…˜ ๋ ˆ์ด์–ด๋Š” ๋ณดํ†ต 2๊ฐœ์˜ ํ”ผ๋“œ ํฌ์›Œ๋“œ ๋ ˆ์ด์–ด ๋‹ค์Œ์— ์œ„์น˜ํ•ฉ๋‹ˆ๋‹ค.
ํ”ผ๋“œ ํฌ์›Œ๋“œ ๋ ˆ์ด์–ด์˜ ์ค‘๊ฐ„ ์ž„๋ฒ ๋”ฉ ํฌ๊ธฐ๋Š” ์ข…์ข… ๋ชจ๋ธ์˜ ์ˆจ๊ฒจ์ง„ ํฌ๊ธฐ๋ณด๋‹ค ํฝ๋‹ˆ๋‹ค(์˜ˆ: bert-base-uncased).

ํฌ๊ธฐ [batch_size, sequence_length]์˜ ์ž…๋ ฅ์— ๋Œ€ํ•ด ์ค‘๊ฐ„ ํ”ผ๋“œ ํฌ์›Œ๋“œ ์ž„๋ฒ ๋”ฉ์„ ์ €์žฅํ•˜๋Š” ๋ฐ ํ•„์š”ํ•œ ๋ฉ”๋ชจ๋ฆฌ [batch_size, sequence_length, config.intermediate_size]๋Š” ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์˜ ํฐ ๋ถ€๋ถ„์„ ์ฐจ์ง€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Reformer: The Efficient Transformer์˜ ์ €์ž๋“ค์€ ๊ณ„์‚ฐ์ด sequence_length ์ฐจ์›๊ณผ ๋…๋ฆฝ์ ์ด๋ฏ€๋กœ ๋‘ ํ”ผ๋“œ ํฌ์›Œ๋“œ ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ ์ž„๋ฒ ๋”ฉ [batch_size, config.hidden_size]_0, ..., [batch_size, config.hidden_size]_n์„ ๊ฐœ๋ณ„์ ์œผ๋กœ ๊ณ„์‚ฐํ•˜๊ณ  n = sequence_length์™€ ํ•จ๊ป˜ [batch_size, sequence_length, config.hidden_size]๋กœ ๊ฒฐํ•ฉํ•˜๋Š” ๊ฒƒ์ด ์ˆ˜ํ•™์ ์œผ๋กœ ๋™์ผํ•˜๋‹ค๋Š” ๊ฒƒ์„ ๋ฐœ๊ฒฌํ–ˆ์Šต๋‹ˆ๋‹ค.

์ด๋Š” ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ด๋Š” ๋Œ€์‹  ๊ณ„์‚ฐ ์‹œ๊ฐ„์„ ์ฆ๊ฐ€์‹œํ‚ค๋Š” ๊ฑฐ๋ž˜๋ฅผ ํ•˜์ง€๋งŒ, ์ˆ˜ํ•™์ ์œผ๋กœ ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์–ป์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

apply_chunking_to_forward() ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ, chunk_size๋Š” ๋ณ‘๋ ฌ๋กœ ๊ณ„์‚ฐ๋˜๋Š” ์ถœ๋ ฅ ์ž„๋ฒ ๋”ฉ์˜ ์ˆ˜๋ฅผ ์ •์˜ํ•˜๋ฉฐ, ๋ฉ”๋ชจ๋ฆฌ์™€ ์‹œ๊ฐ„ ๋ณต์žก์„ฑ ๊ฐ„์˜ ๊ฑฐ๋ž˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. chunk_size๊ฐ€ 0์œผ๋กœ ์„ค์ •๋˜๋ฉด ํ”ผ๋“œ ํฌ์›Œ๋“œ ์ฒญํ‚น์€ ์ˆ˜ํ–‰๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
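
A minimal sketch of the helper (in recent transformers versions it lives in transformers.pytorch_utils; the linear layer here is a stand-in for a real model's feed-forward layer):

import torch
from transformers.pytorch_utils import apply_chunking_to_forward

hidden_states = torch.randn(2, 16, 32)  # [batch_size, sequence_length, hidden_size]
dense = torch.nn.Linear(32, 32)

def feed_forward_chunk(chunk):
    # Applied independently to each slice along the sequence dimension.
    return dense(chunk)

# chunk_size=4, chunk_dim=1: process the 16 positions in 4 slices of 4 each.
output = apply_chunking_to_forward(feed_forward_chunk, 4, 1, hidden_states)
print(output.shape)
# Output: torch.Size([2, 16, 32])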
