📌 Table of Contents

1. Neural Machine Translation (NMT)
2. seq2seq
3. Attention
4. Input Feeding
5. Auto-Regressive & Teacher Forcing
6. Searching Algorithm (Inference) & Beam Search
7. Performance Metric [PPL / BLEU / METEOR / ROUGE]

😚 Closing remarks...

1. Neural Machine Translation (NMT)

1.1 The Goal of Translation
As end-to-end learning, NMT took over the mantle of rule-based machine translation (RBMT) and statistical machine translation (SMT) and has achieved the greatest success so far.
cf) end-to-end learning: predicting and learning the desired output directly from the input data, i.e., modeling the relationship between input and output without intermediate stages or separate feature-extraction steps.


๋ฒˆ์—ญ์˜ ๊ถ๊ทน์ ์ธ ๋ชฉํ‘œ: ์–ด๋–ค ์–ธ์–ด f์˜ ๋ฌธ์žฅ์ด ์ฃผ์–ด์งˆ ๋•Œ, ๊ฐ€๋Šฅํ•œ e ์–ธ์–ด์˜ ๋ฒˆ์—ญ๋ฌธ์žฅ ์ค‘, ์ตœ๋Œ€ํ™•๋ฅ ์„ ๊ฐ–๋Š” ๊ฐ’์„ ์ฐพ๋Š”๊ฒƒ

 

1.2 A Brief History of Machine Translation
∙ Rule-based Machine Translation [RBMT]
Analyzes the structure of a given sentence, builds rules from that analysis, divides cases into classes, and translates according to the fixed rules.
Since humans must hand-craft this entire process, it is very unfavorable in terms of cost.

∙ Statistical Machine Translation [SMT]
Builds a translation system from statistics gathered over a large bilingual (parallel) corpus.
Thanks to its algorithms and system design, it is far more favorable than RBMT when extending to new language pairs.

∙ Neural Machine Translation [NMT]
 - Before DNNs: encoder-decoder style structures
 - After DNNs: end-to-end models became very powerful, thanks to NNLM-based formulations, strong sentence embeddings, and more.


2. seq2seq

2.1 Architecture
seq2seq must find the model parameters θ that maximize the posterior probability P(Y | X; θ), and then find the Y that maximizes this posterior; it consists of three submodules [Encoder / Decoder / Generator].

The innovation of seq2seq (Sequence to Sequence [Sutskever et al., 2014; https://arxiv.org/abs/1409.3215]) is that it can transform a variable-length sentence into another variable-length sentence.
ex) When translating Korean → English, the two sentence lengths differ, so a seq2seq model is needed.
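Concretely, the decoder factorizes this posterior with the chain rule, and training maximizes the log-likelihood over the corpus; a standard way to write this is:

$$P(Y \mid X; \theta) = \prod_{t=1}^{m} P(y_t \mid X, y_{<t}; \theta), \qquad \hat{\theta} = \underset{\theta}{\operatorname{argmax}} \sum_{(X,Y)} \log P(Y \mid X; \theta)$$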

- During training, both the input part and the output part of the Decoder operate.
That is, teacher forcing is used: the ground-truth output is fed to the decoder.
Teacher forcing is explained in Section 5 below.

- At inference time the ground truth is unknown, so the model runs auto-regressively, excluding the input part (grayed out in the original figure).
In auto-regressive mode, when <SOS> is fed in, the model outputs the first word, 'That'; seeing 'That', it then outputs the second word, 'can't'.
That is, it repeatedly outputs the current word based on previously output words, and stops when <EOS>, marking the end of the sentence, is produced.


- Limitation: the biggest problem is that only the encoder's last hidden state is passed to the decoder.
In the original figure, only h5 is passed to the decoder.
The encoder therefore bears the burden of compressing all information into its last hidden state.

Encoder
It takes the word vectors of a given sentence as input and produces a sentence embedding vector that condenses the sentence.
That is, it models P(z | X): it reduces the dimensionality of the given sentence along the manifold and projects it onto a single point in the latent space of that domain.

Note that conventional text classification does not need all of the information (features),
so the vector does not need to carry much information.
For NMT, however, the sentence embedding vector must carry as much information as possible.

Additionally, the seq2seq encoder processes all time-steps in one parallel pass, concatenating hidden states between hidden layers (e.g., the two directions of a bidirectional RNN).

 

Decoder
It can be seen as a kind of conditional neural network language model (CNNLM) with X added to the conditioning variables.
That is, it generates the current time-step's word based on the encoder's output, the sentence embedding vector, and the words generated up to the previous time-step.

A distinctive point is that the BOS token is fed as the initial input to the Decoder.
❗️BOS (Beginning of Sentence): a special token or symbol marking the start of a sentence, used mainly in models like seq2seq to mark the start of the input sequence.

 

Generator
At each time-step it receives the decoder's output vector h and computes a softmax to return a probability for each word of the target vocabulary.
The generator's output is a discrete probability distribution over the words.

Note that if the sentence length is |Y| = m, the last returned word y_m is the EOS token.
This EOS marks the end of the decoder's computation.
❗️EOS (End of Sentence): a special token or symbol marking the end of a sentence, used mainly in models like seq2seq to mark the end of the output sequence.

 

2.2 Applications of seq2seq
Machine Translation (MT)
ChatBot
Summarization
Speech Recognition
Image Captioning

 

2.3 Limitation [Bottleneck Problem]
The biggest problem is that only the encoder's last hidden state is passed to the decoder.
This is called the bottleneck problem: the encoder bears the burden of compressing all information into a single fixed-size vector, its last hidden state.
As a result, information is lost and gradients vanish.

The attention mechanism is what resolves this limitation.
Attention raises the weight on relevant words, preventing the model from concentrating only on the tail of the sequence as before.
That is, to "focus on specific parts", each decoder step is directly connected to the encoder.

 


PyTorch Example

Encoder class
The Encoder is almost identical to RNN-based text classification.
Therefore a bi-directional LSTM is used here as well.
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack


class Encoder(nn.Module):

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Encoder, self).__init__()

        # Be aware of value of 'batch_first' parameter.
        # Also, its hidden_size is half of original hidden_size,
        # because it is bidirectional.
        self.rnn = nn.LSTM(
            word_vec_size,
            int(hidden_size / 2),
            num_layers=n_layers,
            dropout=dropout_p,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, emb):
        # |emb| = (batch_size, length, word_vec_size)

        if isinstance(emb, tuple):
            x, lengths = emb
            x = pack(x, lengths.tolist(), batch_first=True)

            # Below is how pack_padded_sequence works.
            # As you can see,
            # PackedSequence object has information about mini-batch-wise information,
            # not time-step-wise information.
            # 
            # a = [torch.tensor([1,2,3]), torch.tensor([3,4])]
            # b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
            # >>>>
            # tensor([[ 1,  2,  3],
            #     [ 3,  4,  0]])
            # torch.nn.utils.rnn.pack_padded_sequence(b, batch_first=True, lengths=[3,2])
            # >>>>PackedSequence(data=tensor([ 1,  3,  2,  4,  3]), batch_sizes=tensor([ 2,  2,  1]))
        else:
            x = emb

        y, h = self.rnn(x)
        # |y| = (batch_size, length, hidden_size)
        # |h[0]| = (num_layers * 2, batch_size, hidden_size / 2)

        if isinstance(emb, tuple):
            y, _ = unpack(y, batch_first=True)

        return y, h
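A minimal usage sketch for the class above (the sizes here are arbitrary assumptions, not values from the original post):

import torch

encoder = Encoder(word_vec_size=512, hidden_size=768, n_layers=4)

emb = torch.randn(16, 20, 512)       # (batch_size, length, word_vec_size)
lengths = torch.tensor([20] * 16)    # every sample has length 20 in this toy batch

y, h = encoder((emb, lengths))
print(y.shape)     # torch.Size([16, 20, 768])
print(h[0].shape)  # torch.Size([8, 16, 384]) = (n_layers * 2, batch_size, hidden_size / 2)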
Decoder class
With the Attention concept (introduced below) added, it looks like this:

class Decoder(nn.Module):

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Decoder, self).__init__()

        # Be aware of value of 'batch_first' parameter and 'bidirectional' parameter.
        self.rnn = nn.LSTM(
            word_vec_size + hidden_size,
            hidden_size,
            num_layers=n_layers,
            dropout=dropout_p,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, emb_t, h_t_1_tilde, h_t_1):
        # |emb_t| = (batch_size, 1, word_vec_size)
        # |h_t_1_tilde| = (batch_size, 1, hidden_size)
        # |h_t_1[0]| = (n_layers, batch_size, hidden_size)
        batch_size = emb_t.size(0)
        hidden_size = h_t_1[0].size(-1)

        if h_t_1_tilde is None:
            # If this is the first time-step,
            h_t_1_tilde = emb_t.new(batch_size, 1, hidden_size).zero_()

        # Input feeding trick.
        x = torch.cat([emb_t, h_t_1_tilde], dim=-1)

        # Unlike the encoder, the decoder must take its inputs sequentially, one time-step at a time.
        y, h = self.rnn(x, h_t_1)

        return y, h
Generator class
class Generator(nn.Module):

    def __init__(self, hidden_size, output_size):
        super(Generator, self).__init__()

        self.output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # |x| = (batch_size, length, hidden_size)

        y = self.softmax(self.output(x))
        # |y| = (batch_size, length, output_size)

        # Return log-probability instead of just probability.
        return y
Loss Function
Instead of 'softmax + CE', 'log-softmax + NLL (negative log-likelihood)' is used.
# The default weight for each class in the loss equals 1.
# But we don't need the loss for the PAD token,
# so set the weight for PAD to 0.

loss_weight = torch.ones(output_size)
loss_weight[data_loader.PAD] = 0.

# Instead of using Cross-Entropy,
# We can use NLL(Negative-Log-Likelihood) loss with log-probability
crit = nn.NLLLoss(weight=loss_weight, reduction='sum', )


๋”ฐ๋ผ์„œ softmax์‚ฌ์šฉ๋Œ€์‹  logsoftmaxํ•จ์ˆ˜๋กœ ๋กœ๊ทธํ™•๋ฅ ์„ ๊ตฌํ•œ๋‹ค.

def _get_loss(self, y_hat, y, crit=None):
    # |y_hat| = (batch_size, length, output_size)
    # |y| = (batch_size, length)

    crit = self.crit if crit is None else crit
    loss = crit(y_hat.contiguous().view(-1, y_hat.size(-1)),
                y.contiguous().view(-1))
    return loss
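As a quick sanity check (a standalone sketch, separate from the model code above), log-softmax followed by NLL gives the same value as cross entropy on raw logits:

import torch
import torch.nn as nn

logits = torch.randn(4, 10)            # (batch_size, output_size)
target = torch.randint(0, 10, (4,))    # (batch_size,)

ce = nn.CrossEntropyLoss()(logits, target)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), target)
print(torch.allclose(ce, nll))         # True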


3. Attention

3.1 The Goal of Attention
The process of finding the Key similar to a Query and retrieving the corresponding value (that value is called the Value).

 

3.2 A Key-Value Function
∙ Python's dictionary: a data type made of Key-Value pairs
Dic = {'A.I':9 , 'computer':5, 'NLP':4}


Like this, we put in Keys and their corresponding Values, and can access a Value through its Key.
That is, given a Query, we can access a Value according to the Key.
def KV(Q):
    weights = []

    for K in dic.keys():
        weights += [is_same(K, Q)]

    weight_sum = sum(weights)

    for i, w in enumerate(weights):
        weights[i] = weights[i] / weight_sum

    ans = 0

    for weight, V in zip(weights, dic.values()):
        ans += weight * V

    return ans

def is_same(K, Q):
    if K == Q:
        return 1.
    else:
        return 0.

 

3.3 A Continuous Key-Value Vector Function
What if the Values in Dic held 100-dimensional vectors?
Or what if both the Query and the Keys were vectors?
❓ That is, what if Q and K are word embedding vectors?
❓ Or what if the K and V of Dic are identical?

def KV(Q):
    weights = []

    for K in dic.keys():
        weights += [how_similar(K, Q)]  # fill with, e.g., cosine-similarity values

    weights = softmax(weights)  # softmax over all weights (fixes the sum of the weights to 1)
    ans = 0

    for w, V in zip(weights, dic.values()):
        ans += w * V

    return ans

 

์œ„์˜ ์ฝ”๋“œ์—์„œ ans์—๋Š” ์–ด๋– ํ•œ ๋ฒกํ„ฐ๊ฐ’์ด ๋“ค์–ด๊ฐ„๋‹ค.
ans๋‚ด๋ถ€์˜ ๋ฒกํ„ฐ๋“ค์˜ ์ฝ”์‚ฌ์ธ ์œ ์‚ฌ๋„์— ๋”ฐ๋ผ ๋ฒกํ„ฐ๊ฐ’์ด ์ •ํ•ด์ง„๋‹ค.

์ฆ‰, ์œ„์˜ ํ•จ์ˆ˜๋Š” Q์™€ ๋น„์Šทํ•œ K๊ฐ’์„ ์ฐพ์•„ ์œ ์‚ฌ๋„์— ๋”ฐ๋ผ Weight๋ฅผ ์ •ํ•˜๊ณ , ๊ฐ K์™€ V๊ฐ’์„ W๊ฐ’๋งŒํผ ๊ฐ€์ ธ์™€ ๋ชจ๋‘ ๋”ํ•˜๋Š” ๊ฒƒ์œผ๋กœ
์ด๊ฒƒ์ด ๋ฐ”๋กœ Attention Mechanism์˜ ํ•ต์‹ฌ ์•„์ด๋””์–ด์ด๋‹ค.

 

3.4 Attention in NMT
Then how does the attention mechanism operate in MT?
  ∙ K, V: the encoder's output at each time-step
  ∙ Q: the decoder's output at the current time-step

Seq2Seq with Attention
The desired information is obtained from the encoder via attention,
that information is concatenated with the decoder's output and passed through tanh,
and a softmax computation then yields y_hat, which becomes the next time-step's input.
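In equations (Luong-style attention; W_a and W_c are the learned linear transformations discussed next, and W_y stands for the generator's linear layer):

$$w = \mathrm{softmax}\big(h_t^{dec} W_a {H^{enc}}^{\top}\big), \qquad c_t = w\, H^{enc}$$

$$\tilde{h}_t = \tanh\big(W_c [h_t^{dec};\, c_t]\big), \qquad \hat{y}_t = \mathrm{softmax}\big(W_y \tilde{h}_t\big)$$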

 

Linear Transformation
Each dimension inside a neural network is a latent feature, so it cannot be pinned down precisely.
What is certain, however, is that the source language and the target language are fundamentally different.
Therefore, rather than a plain dot product, a bridge between source and target is needed.

So we assume the latent spaces in which the two languages are embedded are linearly related,
and apply a linear transformation before the dot product (the weight W of this transformation is learned through the forward and backward passes).


❓ On the question "Why is attention needed?": it is no overstatement to say that learning this linear transformation is itself attention.
(∵ the linear transformation produces the Q that the decoder's current state needs, compares it with the encoder's K values, and takes a weighted sum.)
In other words, through attention the decoder passes Q to the encoder, and since a good query leads to good results, it is crucial that the decoder judges for itself what information its current state needs and builds Q through the linear transformation.
Moreover, since the weights of the linear transformation themselves are limited, the RNN will shape the decoder state itself so that its linear transformation yields a good Q.

For detailed examples, formulas, and derivations, see: https://chan4im.tistory.com/161
 


PyTorch Example

∙ Attention class: the weight parameter for the linear transformation is replaced by a linear layer without bias.
A more detailed explanation continues in the next section.
class Attention(nn.Module):

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, h_src, h_t_target, mask=None):
        # |h_src| = (batch_size, length, hidden_size)
        # |h_t_target| = (batch_size, 1, hidden_size)
        # |mask| = (batch_size, length)

        query = self.linear(h_t_target)
        # |query| = (batch_size, 1, hidden_size)

        weight = torch.bmm(query, h_src.transpose(1, 2))
        # |weight| = (batch_size, 1, length)

        if mask is not None:
            # Set the weight to -inf wherever the mask value equals 1 (padding).
            # Since softmax maps -inf to 0,
            # masked weights become 0 after the softmax operation.
            # Thus, if a sample is shorter than the others in the mini-batch,
            # the weights for its empty time-steps are set to 0.
            weight.masked_fill_(mask.unsqueeze(1), -float('inf'))

        weight = self.softmax(weight)

        context_vector = torch.bmm(weight, h_src)
        # |context_vector| = (batch_size, 1, hidden_size)

        return context_vector


cf) bmm (batch matrix multiplication)
Consider torch.bmm(x, y):

∙ |x| = (batch_size, n, m)
∙ |y| = (batch_size, m, k)
∙ |torch.bmm(x, y)| = (batch_size, n, k)
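A quick shape check (a throwaway sketch):

import torch

x = torch.randn(32, 5, 7)     # (batch_size, n, m)
y = torch.randn(32, 7, 3)     # (batch_size, m, k)
print(torch.bmm(x, y).shape)  # torch.Size([32, 5, 3]) = (batch_size, n, k)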


4. Input Feeding

๊ฐ time-step์˜ decoder์ถœ๋ ฅ๊ฐ’๊ณผ attention๊ฒฐ๊ณผ๊ฐ’์„ ์ด์–ด๋ถ™์ธ ํ›„ Generator Module์—์„œ softmax๋ฅผ ์ทจํ•ด ํ™•๋ฅ ๋ถ„ํฌ๋ฅผ ๊ตฌํ•œ๋‹ค.
์ดํ›„ ํ•ด๋‹น ํ™•๋ฅ ๋ถ„ํฌ์—์„œ argmax๋ฅผ ์ˆ˜ํ–‰ํ•ด y_hat์„ samplingํ•œ๋‹ค.
๋‹ค๋งŒ, ๋ถ„ํฌ์—์„œ samplingํ•˜๋Š” ๊ณผ์ •์—์„œ๋ณด๋‹ค ๋” ๋งŽ์€ ์ •๋ณด๊ฐ€ ์†์‹ค๋œ๋‹ค.

๋”ฐ๋ผ์„œ softmax์ด์ „๊ฐ’๋„ ๊ฐ™์ด ๋„ฃ์–ด์ฃผ๋Š” ํŽธ์ด ์ •๋ณด์†์‹ค์—†์ด ๋” ์ข‹์€ ํšจ๊ณผ๋ฅผ ์–ป๋Š”๋‹ค.,
y์™€ ๋‹ฌ๋ฆฌ concat์ธต์˜ ์ถœ๋ ฅ์€ y๊ฐ€ embedding์ธต์—์„œ dense๋ฒกํ„ฐ๋กœ ๋ณ€ํ™˜๋˜๊ณ  ๋‚œ ํ›„ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์™€ ์ด์–ด๋ถ™์—ฌ Decoder RNN์— ์ž…๋ ฅ์œผ๋กœ ์ฃผ์–ด์ง€๋Š” ๊ณผ์ •์„ input feeding์ด๋ผ ํ•œ๋‹ค.

 

4.1 Drawback
This approach slows down training: without input feeding, the decoder, like the encoder, can be computed for all time-steps at once during training.
With input feeding, however, the decoder RNN's input depends on the previous time-step's result, so computation must proceed sequentially, time-step by time-step.

That said, at the inference stage the decoder must compute sequentially anyway, with or without input feeding, so input feeding causes almost no slowdown at inference; hence this drawback is not considered significant.

 

PyTorch Example
https://github.com/V2LLAIN/NLP/blob/main/5.%20RNN__Seq2Seq/seq2seq(with_attention).py


5. Auto-Regressive & Teacher Forcing

A question arises here...

Does the previous time-step's output really go into the decoder as input during training??

In fact, seq2seq's basic training method differs from the way it is run at inference.



5.1 The AR (Auto-Regressive) Property
The fundamental difference between training and inference in seq2seq stems from the AR property.

Auto-regression (AR): inferring the current value by referring to one's own past values.
This can also be seen in the formula below.
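In standard notation, the AR decomposition and the resulting inference rule are:

$$P(Y \mid X; \theta) = \prod_{t=1}^{|Y|} P(y_t \mid X, y_{<t}; \theta), \qquad \hat{y}_t = \underset{y}{\operatorname{argmax}}\; P(y \mid X, \hat{y}_{<t}; \theta)$$

At inference, each prediction is conditioned on the model's own previous predictions ŷ_{<t}.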
Note that, depending on past results, not only the composition of the sentence or sequence but also the length of the predicted sequence can change,
and a wrong prediction in the past can trigger even larger wrong predictions later.

During training we already know the ground truth, and the model learns from the difference between its current prediction and that ground truth, so it cannot be trained while preserving the auto-regressive (AR) property.

Therefore, training proceeds with a method called Teacher Forcing.

 

5.2 Training with Teacher Forcing
The point of Teacher Forcing is that during training, the decoder's input at each time-step is the ground truth Y, not the previous time-step's decoder output.
At inference time, however, the ground truth Y is unknown, so the y_hat computed at the previous time-step is used as the decoder's input.
This way of feeding inputs during training is called teacher forcing.

Advantages:
It stabilizes the early stage of training and gives the model a strong signal about what it should generate initially.
Since the model learns by looking at the ground-truth labels, it mitigates the vanishing-gradient problem and lets gradients flow better.
 

cf) Teacher Forcing in LMs
[Training]
When training the model, at each time-step the actual ground-truth label, not the model's output from the previous time-step, is provided as input.
That is, the ground-truth data, playing the role of the "teacher", is used instead of the previous time-step's output.

[Testing (Inference)]
After training, the model generates sequences by using its previous outputs as inputs.
Here the previous time-step's output is used "autoregressively".
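A schematic sketch of the two regimes (decoder_step is a hypothetical callable standing in for a single decoder time-step; this is an illustration under that assumption, not the model code above):

import torch
import torch.nn.functional as F

def train_step(decoder_step, state, tgt):
    # Teacher forcing: the decoder always sees the GOLD previous token.
    # tgt: 1-D LongTensor of gold token ids, starting with BOS.
    loss = 0.
    for t in range(tgt.size(0) - 1):
        logits, state = decoder_step(tgt[t:t + 1], state)       # feed gold y_t
        loss = loss + F.cross_entropy(logits, tgt[t + 1:t + 2])
    return loss

@torch.no_grad()
def generate(decoder_step, state, bos, eos, max_len=255):
    # Inference: autoregressive, the decoder sees its OWN previous output.
    y, out = bos, []
    for _ in range(max_len):
        logits, state = decoder_step(y, state)                  # feed y_hat_t
        y = logits.argmax(dim=-1)
        if y.item() == eos:
            break
        out.append(y.item())
    return out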


6. Searching Algorithm(Inference). &. Beam Search

Let us discuss how to infer Y_hat when X is given.

This process is called inference, or search, because it is based on search algorithms.

That is, it is the process of finding the path with the highest probability through the words we want.

 

6.1 Sampling
The method most faithful to the model is, when picking y_hat at each time-step, to sample according to the probability distribution of the final softmax layer.
Then, at the next time-step, the next y_hat is sampled again based on that choice (y_hat), repeating until EOS is finally produced.

However, this approach can produce a different output for the same input every time, so it is usually avoided.

6.2 Using a Greedy Search Algorithm
Countless search techniques exist (DFS, BFS, DP, ...), but let us implement search based on the greedy algorithm.
That is, at every time-step, take the index with the highest value among the softmax probabilities and use it as that time-step's y_hat.

In other words, choosing the most likely word at each step of the output prediction enables very fast search.

Drawback 1.) Decisions cannot be undone, and
Drawback 2.) the final output can end up far from the optimal result.
             (∵ decoding proceeds until the <END> token is generated)
PyTorch Example
def search(self, src, is_greedy=True, max_length=255):
        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
        else:
            x, x_length = src, None
            mask = None
        batch_size = x.size(0)

        # Same procedure as teacher forcing.
        emb_src = self.emb_src(x)
        h_src, h_0_tgt = self.encoder((emb_src, x_length))
        decoder_hidden = self.fast_merge_encoder_hiddens(h_0_tgt)

        # Fill a vector, which has 'batch_size' dimension, with BOS value.
        y = x.new(batch_size, 1).zero_() + data_loader.BOS

        is_decoding = x.new_ones(batch_size, 1).bool()
        h_t_tilde, y_hats, indice = None, [], []
        
        # Repeat the loop while the sum of the 'is_decoding' flags is bigger than 0
        # and the current time-step is smaller than the maximum length.
        while is_decoding.sum() > 0 and len(indice) < max_length:
            # Unlike training procedure,
            # take the last time-step's output during the inference.
            emb_t = self.emb_dec(y)
            # |emb_t| = (batch_size, 1, word_vec_size)

            decoder_output, decoder_hidden = self.decoder(emb_t,
                                                          h_t_tilde,
                                                          decoder_hidden)
            context_vector = self.attn(h_src, decoder_output, mask)
            h_t_tilde = self.tanh(self.concat(torch.cat([decoder_output,
                                                         context_vector
                                                         ], dim=-1)))
            y_hat = self.generator(h_t_tilde)
            # |y_hat| = (batch_size, 1, output_size)
            y_hats += [y_hat]

            if is_greedy:
                y = y_hat.argmax(dim=-1)
                # |y| = (batch_size, 1)
            else:
                # Take a random sampling based on the multinoulli distribution.
                y = torch.multinomial(y_hat.exp().view(batch_size, -1), 1)
                # |y| = (batch_size, 1)

            # Put PAD if the sample is done.
            y = y.masked_fill_(~is_decoding, data_loader.PAD)
            # Update is_decoding if there is EOS token.
            is_decoding = is_decoding * torch.ne(y, data_loader.EOS)
            # |is_decoding| = (batch_size, 1)
            indice += [y]

        y_hats = torch.cat(y_hats, dim=1)
        indice = torch.cat(indice, dim=1)
        # |y_hat| = (batch_size, length, output_size)
        # |indice| = (batch_size, length)

        return y_hats, indice

6.3 Beam Search
The greedy algorithm is very easy and simple, but it does not guarantee the optimal solution.
So a small trick is added: keep tracking k candidates.
Here, k is called the beam size.

With beam size k, k hypotheses are maintained among the k branches as the steps proceed, and from the final candidate pool the one with the highest probability is chosen.
The <END> token can be generated at different time-steps;
when one hypothesis produces <END>, it terminates, and the other hypothesis branches continue to be searched.
Since words are generated sequentially, the joint probability is considered: each generated word adds a log-probability, and since these are negative values, longer hypotheses accumulate ever smaller scores, so an additional normalization step is typically applied.
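One standard way to write this length normalization (the length_penalty argument in the code below plays the role of α):

$$\mathrm{score}(Y) = \frac{1}{|Y|^{\alpha}} \sum_{t=1}^{|Y|} \log P(y_t \mid X, y_{<t})$$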

∙ Small k
  - Nearly the same as greedy (= ungrammatical, unnatural, nonsensical, incorrect).

∙ Large k
  - As k grows, greedy's problems shrink, but the computational cost grows.
  - The BLEU score can drop (∵ too-short translations).

Therefore a beam size of 10 or less is usually used.



PyTorch Example
#@profile
    def batch_beam_search(
        self,
        src,
        beam_size=5,
        max_length=255,
        n_best=1,
        length_penalty=.2
    ):
        mask, x_length = None, None

        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
            # |mask| = (batch_size, length)
        else:
            x = src
        batch_size = x.size(0)

        emb_src = self.emb_src(x)
        h_src, h_0_tgt = self.encoder((emb_src, x_length))
        # |h_src| = (batch_size, length, hidden_size)
        h_0_tgt = self.fast_merge_encoder_hiddens(h_0_tgt)

        # initialize 'SingleBeamSearchBoard' as many as batch_size
        boards = [SingleBeamSearchBoard(
            h_src.device,
            {
                'hidden_state': {
                    'init_status': h_0_tgt[0][:, i, :].unsqueeze(1),
                    'batch_dim_index': 1,
                }, # |hidden_state| = (n_layers, batch_size, hidden_size)
                'cell_state': {
                    'init_status': h_0_tgt[1][:, i, :].unsqueeze(1),
                    'batch_dim_index': 1,
                }, # |cell_state| = (n_layers, batch_size, hidden_size)
                'h_t_1_tilde': {
                    'init_status': None,
                    'batch_dim_index': 0,
                }, # |h_t_1_tilde| = (batch_size, 1, hidden_size)
            },
            beam_size=beam_size,
            max_length=max_length,
        ) for i in range(batch_size)]
        is_done = [board.is_done() for board in boards]

        length = 0
        # Run the loop while the sum of 'is_done' is smaller than batch_size
        # and the length is still smaller than max_length.
        while sum(is_done) < batch_size and length <= max_length:
            # current_batch_size = sum(is_done) * beam_size

            # Initialize fabricated variables.
            # As far as batch-beam-search is running, 
            # temporary batch-size for fabricated mini-batch is 
            # 'beam_size'-times bigger than original batch_size.
            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
            fab_h_src, fab_mask = [], []
            
            # Build fabricated mini-batch in non-parallel way.
            # This may cause a bottle-neck.
            for i, board in enumerate(boards):
                # Batchify if the inference for the sample is still not finished.
                if board.is_done() == 0:
                    y_hat_i, prev_status = board.get_batch()
                    hidden_i    = prev_status['hidden_state']
                    cell_i      = prev_status['cell_state']
                    h_t_tilde_i = prev_status['h_t_1_tilde']

                    fab_input  += [y_hat_i]
                    fab_hidden += [hidden_i]
                    fab_cell   += [cell_i]
                    fab_h_src  += [h_src[i, :, :]] * beam_size
                    fab_mask   += [mask[i, :]] * beam_size
                    if h_t_tilde_i is not None:
                        fab_h_t_tilde += [h_t_tilde_i]
                    else:
                        fab_h_t_tilde = None

            # Now, concatenate list of tensors.
            fab_input  = torch.cat(fab_input,  dim=0)
            fab_hidden = torch.cat(fab_hidden, dim=1)
            fab_cell   = torch.cat(fab_cell,   dim=1)
            fab_h_src  = torch.stack(fab_h_src)
            fab_mask   = torch.stack(fab_mask)
            if fab_h_t_tilde is not None:
                fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0)
            # |fab_input|     = (current_batch_size, 1)
            # |fab_hidden|    = (n_layers, current_batch_size, hidden_size)
            # |fab_cell|      = (n_layers, current_batch_size, hidden_size)
            # |fab_h_src|     = (current_batch_size, length, hidden_size)
            # |fab_mask|      = (current_batch_size, length)
            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)

            emb_t = self.emb_dec(fab_input)
            # |emb_t| = (current_batch_size, 1, word_vec_size)

            fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(emb_t,
                                                                      fab_h_t_tilde,
                                                                      (fab_hidden, fab_cell))
            # |fab_decoder_output| = (current_batch_size, 1, hidden_size)
            context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask)
            # |context_vector| = (current_batch_size, 1, hidden_size)
            fab_h_t_tilde = self.tanh(self.concat(torch.cat([fab_decoder_output,
                                                             context_vector
                                                             ], dim=-1)))
            # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size)
            y_hat = self.generator(fab_h_t_tilde)
            # |y_hat| = (current_batch_size, 1, output_size)

            # separate the result for each sample.
            # fab_hidden[:, begin:end, :] = (n_layers, beam_size, hidden_size)
            # fab_cell[:, begin:end, :]   = (n_layers, beam_size, hidden_size)
            # fab_h_t_tilde[begin:end]    = (beam_size, 1, hidden_size)
            cnt = 0
            for board in boards:
                if board.is_done() == 0:
                    # Decide a range of each sample.
                    begin = cnt * beam_size
                    end = begin + beam_size

                    # pick k-best results for each sample.
                    board.collect_result(
                        y_hat[begin:end],
                        {
                            'hidden_state': fab_hidden[:, begin:end, :],
                            'cell_state'  : fab_cell[:, begin:end, :],
                            'h_t_1_tilde' : fab_h_t_tilde[begin:end],
                        },
                    )
                    cnt += 1

            is_done = [board.is_done() for board in boards]
            length += 1

        # pick n-best hypothesis.
        batch_sentences, batch_probs = [], []

        # Collect the results.
        for i, board in enumerate(boards):
            sentences, probs = board.get_n_best(n_best, length_penalty=length_penalty)

            batch_sentences += [sentences]
            batch_probs     += [probs]

        return batch_sentences, batch_probs


7. Performance Metric [PPL / BLEU / METEOR / ROUGE]

7.1 Qualitative Evaluation (Intrinsic Evaluation)
Usually takes the form of a human grading the translated sentences.
Since human preconceptions can creep in, methods such as blind tests are generally adhered to.
It can be the most accurate method, but it has the drawback of consuming a lot of resources and time.

Later, when we look at Google's translation system, we will see the scores Google obtained through this kind of qualitative evaluation.

 

7.2 Quantitative Evaluation (Extrinsic Evaluation)
Because of the drawbacks of qualitative evaluation mentioned above, automated quantitative evaluation is usually used.
The evaluation should be as consistent as possible, and it is all the better if the method reflects linguistic characteristics.

PPL (Perplexity)
Since NMT is fundamentally a task of selecting (classifying) the word with the highest probability at every time-step, cross entropy is used as its basic loss function.
And since NMT is also a conditional language model, its performance can be measured with PPL.

Consequently, PPL, obtained by exponentiating the cross entropy, can be used as an evaluation metric.
(The lower the PPL, and the larger the N of the N-gram, the better the model; https://chan4im.tistory.com/200#n3)
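That is, for a test sequence of N tokens:

$$\mathrm{PPL} = \exp\Big(-\frac{1}{N}\sum_{t=1}^{N}\log P(y_t \mid y_{<t})\Big) = \exp(\mathcal{L}_{CE})$$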
It has the advantage of being directly tied to CE and thus convenient, but it cannot be said to be perfectly proportional to actual translator performance.

That is because at each time-step only the probability of the single ground-truth word is scored.
But in language, sentences with the same meaning can have different word orders,
and words can be replaced with near-synonyms, so a perfectly good translation may receive a high loss, and conversely a poor one a low loss.

Hence there is a gap between the actual quality of the translation and CE.
(Especially since the teacher forcing scheme is added on top, the gap grows even larger.)
 



BLEU (Bi-Lingual Evaluation Understudy)
Several methods were proposed to reduce the gap between CE and translation quality discussed above for PPL; let us look at the most representative one, BLEU.

The BLEU score is based on the geometric mean of the ratios of N-grams that match between the reference and the predicted sentence.
That is, it expresses the average of the per-N-gram precisions as a percentage,
and the brevity penalty for short sentences prevents the score from improving when the predicted translation is shorter than the reference sentence.
A higher BLEU score means a better model.
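The standard corpus-level BLEU formulation (with p_n the modified n-gram precision, typically N = 4 and w_n = 1/4, c the candidate length, and r the reference length) is:

$$\mathrm{BLEU} = \mathrm{BP} \cdot \exp\Big(\sum_{n=1}^{N} w_n \log p_n\Big), \qquad \mathrm{BP} = \min\big(1,\; e^{1 - r/c}\big)$$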

Also, when measuring performance in practice, rather than implementing BLEU directly, multi-bleu.perl, included in the SMT framework MOSES, is mainly used.

 


METEOR (Metric for Evaluation of Translation with Explicit ORdering)
METEOR is one of the automatic evaluation metrics used in NMT and NLP generation tasks.
It measures translation quality and judges the similarity between a reference translation and the generated translation.
METEOR has a purpose similar to BLEU's, but with several important differences.

Main points)
A method derived from BLEU, designed to compensate for BLEU's shortcomings and to be a better metric.
The biggest difference from BLEU, which considers only precision, is that
 - recall is also taken into account.
 - Additionally, a harmonic mean of the two with different weights is used to compute the score,
 - a separate penalty is imposed on errors,
 - and several words or phrases (e.g., synonyms) can be treated as correct, all to complement BLEU.


METEOR's main features and operation are as follows:

∙ Unigram precision: METEOR measures word- and phrase-level matches.
It counts the number of common tokens (words or phrases) between the candidate translation and the reference translation.
This is known as "precision".
The formula is given below.


∙ Harmonic-mean F score: to balance precision and recall, METEOR computes an F score, the harmonic mean of precision and recall (weighted toward recall). This considers both the candidate's precision and its similarity to the reference.
The formula is given below;
the penalty is computed as follows,
and the final METEOR score follows from them.
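In the standard METEOR formulation (Banerjee & Lavie, 2005), with m the number of matched unigrams, w_t the number of unigrams in the candidate, w_r the number in the reference, and #chunks the number of contiguous matched chunks:

$$P = \frac{m}{w_t}, \qquad R = \frac{m}{w_r}, \qquad F_{mean} = \frac{10PR}{R + 9P}$$

$$\mathrm{Penalty} = 0.5\left(\frac{\#\mathrm{chunks}}{m}\right)^{3}, \qquad \mathrm{METEOR} = F_{mean} \times (1 - \mathrm{Penalty})$$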


∙ Lexical and structural variation between the candidate and the reference: METEOR considers various lexical and grammatical changes, including word order and structural changes. It places more emphasis on word order, comparing the common vocabulary and structure between the candidate and the reference.

∙ Alignment and synonym handling: METEOR performs synonym matching on the vocabulary and uses phrase alignment when computing similarity.

METEOR scores lie between 0 and 1. A higher METEOR score indicates that the translation is more similar to the reference, i.e., of higher quality. METEOR is mainly used to evaluate translation quality, but it can also be applied to various other NLP tasks.

ROUGE (Recall-Oriented Understudy for Gisting Evaluation)
ROUGE is one of the automatic evaluation metrics used in NLP, mainly for text generation and automatic summarization tasks.
It is used to measure and evaluate the similarity between the generated text or summary and a reference text.
ROUGE has five variants:
- ROUGE-N
- ROUGE-L
- ROUGE-W
- ROUGE-S
- ROUGE-SU


∙ ROUGE-N (ROUGE-Ngram)
The ROUGE-N metric measures N-gram (n consecutive words or characters) matches.
The n-gram length is usually specified, as in ROUGE-1, ROUGE-2, ROUGE-3, and so on.


That is, ROUGE-N is the recall of N-grams between the predicted summary and the reference summary:
 ‣ ROUGE-1: counts 1-gram (unigram) = word-level matches
 ‣ ROUGE-2: counts 2-gram (bigram)-level matches
 ‣ ROUGE-3: counts 3-gram (trigram)-level matches
ROUGE-N measures similarity and does not take word order into account.
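In recall form, the standard definition is:

$$\mathrm{ROUGE}\text{-}N = \frac{\sum_{gram_n \in \mathrm{Reference}} \mathrm{Count}_{match}(gram_n)}{\sum_{gram_n \in \mathrm{Reference}} \mathrm{Count}(gram_n)}$$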
ex) ROUGE-1
∙ Reference summary uni-grams: 'Korea', 'won', 'the', 'world', 'cup'
∙ Predicted summary uni-grams: 'Korea', 'won', 'the', 'soccer', 'world', 'cup', 'final'
The number of uni-grams shared by the predicted and reference summaries is 5,
and the reference summary has 5 uni-grams, so
∙ ROUGE-1 = 5/5 = 1

ex) ROUGE-2
∙ Reference summary bi-grams: 'Korea won', 'won the', 'the world', 'world cup'
∙ Predicted summary bi-grams: 'Korea won', 'won the', 'the soccer', 'soccer world', 'world cup', 'cup final'
The number of bi-grams shared by the predicted and reference summaries is 3,
and the reference summary has 4 bi-grams, so
∙ ROUGE-2 = 3/4 = 0.75
