๐ ๋ชฉ์ฐจ
1. Neural Machine Translation (NMT)
2. seq2seq
3. Attention
4. Input Feeding
5. AutoRegressive. &. Teacher Forcing
6. Searching Algorithm(Inference) & Beam Search
7. Performance Metric [PPL / BLEU / METEOR / ROUGE]
๐ ๊ธ์ ๋ง์น๋ฉฐ...
1. Neural Machine Translation (NMT)
1.1 ๋ฒ์ญ์ ๋ชฉํ
NMT๋ end-to-endํ์ต์ผ๋ก์จ, ๊ท์น๊ธฐ๋ฐ๊ธฐ๊ณ๋ฒ์ญ(RBMT)๊ณผ ํต๊ณ๊ธฐ๋ฐ๊ธฐ๊ณ๋ฒ์ญ(SMT)์ ๋ช ๋ชฉ์ ์ด์ด๋ฐ์์ ๊ฐ์ฅ ํฐ ์ฑ์ทจ๋ฅผ ์ด๋ฃฉํด๋๋ค.
cf) end-to-endํ์ต: ์ ๋ ฅ ๋ฐ์ดํฐ์์๋ถํฐ ์ํ๋ ์ถ๋ ฅ์ ์ง์ ์์ธกํ๊ณ ํ์ตํ๋ ๋ฐฉ์, ์ค๊ฐ ๋จ๊ณ๋ ํน์ง ์ถ์ถ ๋จ๊ณ ์์ด ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ๊ฐ์ ๊ด๊ณ๋ฅผ ๋ชจ๋ธ๋งํ๋ ค๋ ๊ฒ์ ์๋ฏธ
๋ฒ์ญ์ ๊ถ๊ทน์ ์ธ ๋ชฉํ: ์ด๋ค ์ธ์ด f์ ๋ฌธ์ฅ์ด ์ฃผ์ด์ง ๋, ๊ฐ๋ฅํ e ์ธ์ด์ ๋ฒ์ญ๋ฌธ์ฅ ์ค, ์ต๋ํ๋ฅ ์ ๊ฐ๋ ๊ฐ์ ์ฐพ๋๊ฒ
1.2 ๊ธฐ๊ณ๋ฒ์ญ์ ์ญ์ฌ
โ ๊ท์น๊ธฐ๋ฐ ๊ธฐ๊ณ๋ฒ์ญ[RBMT]
์ฃผ์ด์ง ๋ฌธ์ฅ์ ๊ตฌ์กฐ๋ฅผ ๋ถ์, ๊ทธ ๋ถ์์ ๋ฐ๋ผ ๊ท์น์ ์ธ์ฐ๊ณ , ๋ถ๋ฅ๋ฅผ ๋๋ ์ ํด์ง ๊ท์น์ ๋ฐ๋ผ ๋ฒ์ญ
์ด ๊ณผ์ ์ ์ฌ๋์ด ๋ชจ๋ ๊ฐ์ ํด์ผํ๊ธฐ์ ๋น์ฉ์ ์ธก๋ฉด์์ ๋งค์ฐ ๋ถ๋ฆฌํ๋ค.
โ ํต๊ณ๊ธฐ๋ฐ ๊ธฐ๊ณ๋ฒ์ญ[SMT]
๋๋์ Bi-Direction corpus์์ ํต๊ณ๋ฅผ ์ป์ด ๋ฒ์ญ์์คํ ์ ๊ตฌ์ฑํ๋ ๊ฒ์ผ๋ก
์๊ณ ๋ฆฌ์ฆ์ด๋ ์์คํ ์ผ๋ก ์ธํด ์ธ์ด์์ ํ์ฅํ ๋, RBMT์ ๋นํด ํจ์ฌ ์ ๋ฆฌํ๋ค.
โ ์ ๊ฒฝ๋ง ๊ธฐ๊ณ๋ฒ์ญ[NMT]
- DNN ์ด์ : Encoder-Decoderํํ์ ๊ตฌ์กฐ
- DNN ์ดํ: end-to-end ๋ชจ๋ธ, NNLM๊ธฐ๋ฐ, ํ๋ฅญํ ๋ฌธ์ฅ์๋ฒ ๋ฉ ๋ฑ์ ์ฅ์ ์ผ๋ก ๋งค์ฐ powerfulํด์ก๋ค.
2. seq2seq
2.1 Architecture
seq2seq๋ ์ฌํํ๋ฅ P(Y | X;θ)๋ฅผ ์ต๋๋กํ๋ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฐพ์์ผ ํ๋ฉฐ, ์ด ์ฌํํ๋ฅ ์ ์ต๋๋ก ํ๋ Y๋ฅผ ์ฐพ์์ผ ํ๊ธฐ์ ํฌ๊ฒ 3๊ฐ์ง์ ์๋ธ๋ชจ๋[ Encoder / Decoder ]๋ก ๊ตฌ์ฑ๋๋ค.
seq2seq : Sequence to Sequence[Sutskever2014; https://arxiv.org/abs/1409.3215]์ ํ์ ์ฑ์ "๊ฐ๋ณ๊ธธ์ด์ ๋ฌธ์ฅ์ ๊ฐ๋ณ๊ธธ์ด์ ๋ฌธ์ฅ์ผ๋ก ๋ณํ"ํ ์ ์๋ค๋ ๊ฒ์ด๋ค.
ex) ํ๊ตญ์ด→์์ด๋ก ๋ฒ์ญ ์, ๋์ ๋ฌธ์ฅ๊ธธ์ด๊ฐ ๋ฌ๋ผ seq2seq model์ด ํ์
- ํ์ต ์ Decoder์ input๋ถ๋ถ๊ณผ output๋ถ๋ถ์ด ๋ชจ๋ ๋์ํ๋ค.
์ฆ, ์ ๋ต์ ํด๋นํ๋ ์ถ๋ ฅ์ ์๋ ค์ฃผ๋ ๊ต์ฌ๊ฐ์(teacher forcing)๋ฐฉ๋ฒ์ ์ฌ์ฉ.
teacher forcing์ ๋ํด์๋ ์๋ 5๋ฒํญ๋ชฉ์์ ์ค๋ช ํ๊ฒ ๋ค.
- ์์ธก ์ ์ ๋ต์ ๋ชจ๋ฅด๊ธฐ ๋๋ฌธ์ ํ์ํ์ํ input ๋ถ๋ถ์ ์ ์ธํ๊ณ ์๊ธฐํ๊ท(auto-regressive) ๋ฐฉ์์ผ๋ก ๋์ํ๋ค.
์๊ธฐํ๊ท์์ <SOS>๊ฐ ์ ๋ ฅ๋๋ฉด ์ฒซ ๋จ์ด 'That'์ ์ถ๋ ฅํ๊ณ 'That'์ ๋ณด๊ณ ๊ทธ ๋ค์ ๋์งธ ๋จ์ด 'can't'๋ฅผ ์ถ๋ ฅํ๋ค.
์ฆ, ์ด์ ์ ์ถ๋ ฅ๋ ๋จ์ด๋ฅผ ๋ณด๊ณ ํ์ฌ๋จ์ด๋ฅผ ์ถ๋ ฅํ๋ ์ผ์ ๋ฐ๋ณตํ๋ฉฐ, ๋ฌธ์ฅ๋์ ๋ํ๋ด๋ <EOS>๊ฐ ๋ฐ์ํ๋ฉด ๋ฉ์ถ๋ค.
- ํ๊ณ : ๊ฐ์ฅ ํฐ ๋ฌธ์ ๋ encoder์ ๋ง์ง๋ง hidden state๋ง decoder์ ์ ๋ฌํ๋ค๋ ์ ์ด๋ค.
์๋ ๊ทธ๋ฆผ์์ ๋ณด๋ฉด h5๋ง decoder๋ก ์ ๋ฌ๋๋ค.
๋ฐ๋ผ์ encoder๋ ๋ง์ง๋ง hidden state์ ๋ชจ๋ ์ ๋ณด๋ฅผ ์์ถํด์ผํ๋ ๋ถ๋ด์ด ์กด์ฌํ๋ค.
Encoder
์ฃผ์ด์ง ๋ฌธ์ฅ์ธ ์ฌ๋ฌ ๊ฐ์ ๋ฒกํฐ๋ฅผ ์ ๋ ฅ์ผ๋ก ๋ฐ์ ๋ฌธ์ฅ์ ํจ์ถํ๋ ๋ฌธ์ฅ์๋ฒ ๋ฉ๋ฒกํฐ๋ก ๋ง๋ ๋ค.
์ฆ, P(z | X)๋ฅผ ๋ชจ๋ธ๋ง ํ, ์ฃผ์ด์ง ๋ฌธ์ฅ์ manifold๋ฅผ ๋ฐ๋ผ ์ฐจ์์ถ์, ํด๋น ๋๋ฉ์ธ์ latent space์ ์ด๋ค ํ๋์ ์ ์ ํฌ์ํ๋ ์์ ์ด๋ค.
๋ค๋ง ๊ธฐ์กด์ text classification์์๋ ๋ชจ๋ ์ ๋ณด(feature)๊ฐ ํ์ํ์ง ์๋ค.
๋ฐ๋ผ์ ๋ฒกํฐ์์ฑ ์ ๋ง์ ์ ๋ณด๋ฅผ ๊ฐ์ง ํ์๊ฐ ์๋ค.
ํ์ง๋ง NMT๋ฅผ ์ํ sentence embedding vector ์์ฑ ์, ์ต๋ํ ๋ง์ ์ ๋ณด๋ฅผ ๊ฐ์ ธ์ผํ๋ค.
์ถ๊ฐ์ ์ผ๋ก seq2seq๋ชจ๋ธ์์ hidden layer๊ฐ์ concatenate ์์ ์ผ๋ก ์ ์ฒด time-step์ ํ๋ฒ์ ๋ณ๋ ฌ๋ก ์ฒ๋ฆฌํ๋ค.
Decoder
์ผ์ข ์ ์กฐ๊ฑด๋ถ ์ ๊ฒฝ๋ง ์ธ์ด๋ชจ๋ธ[CNNLM]์ ์กฐ๊ฑด๋ถ ํ๋ฅ ๋ณ์ ๋ถ๋ถ์ X๊ฐ ์ถ๊ฐ๋ ํํ๋ผ ํ ์ ์๋ค.
์ฆ, encoder์ ๊ฒฐ๊ณผ์ธ sentence embedding vector์ ์ด์ time-step๊น์ง ๋ฒ์ญํด ์์ฑํ ๋จ์ด๋ค์ ๊ธฐ๋ฐํด ํ์ฌ time-step์ ๋จ์ด๋ฅผ ์์ฑํ๋ค.
ํน์ดํ ์ ์ Decoder์ ์ ๋ ฅ์ ์ด๊ธฐ๊ฐ์ผ๋ก์จ BOS token์ ์ ๋ ฅ์ผ๋ก ์ค๋ค๋ ์ ์ด๋ค.
โ๏ธBOS (Beginning of Sentence): BOS๋ ๋ฌธ์ฅ์ ์์์ ๋ํ๋ด๋ ํน๋ณํ ํ ํฐ ๋๋ ์ฌ๋ณผ๋ก ์ฃผ๋ก Seq2Seq๋ชจ๋ธ๊ณผ ๊ฐ์ ๋ชจ๋ธ์์ ์ ๋ ฅ ์ํ์ค์ ์์์ ํ์ํ๋ ๋ฐ ์ฌ์ฉ๋๋ค.
Generator
Decoder์์ ๊ฐ time-step๋ณ๋ก ๊ฒฐ๊ณผ๋ฒกํฐ h๋ฅผ ๋ฐ์ softmax๋ฅผ ๊ณ์ฐํด ๊ฐ target์ธ์ด์ ๋จ์ด์ดํ๋ณ ํ๋ฅ ๊ฐ์ ๋ฐํํ๋ค.
์์ฑ์์ ๊ฒฐ๊ณผ๊ฐ์ ๊ฐ ๋จ์ด๊ฐ ๋ํ๋ ํ๋ฅ ์ธ ์ด์ฐํ๋ฅ ๋ถํฌ๊ฐ ๋๋ค.
์ด๋, ์ฃผ์ํ ์ ์ ๋ฌธ์ฅ์ ๊ธธ์ด๊ฐ |Y|=m์ด๋ผ๋ฉด, ๋ง์ง๋ง ๋ฐํ๋จ์ด ym์ EOS token์ด ๋๋ค๋ ์ ์ด๋ค.
์ด EOS๋ก Decoder ๊ณ์ฐ์ ์ข ๋ฃ๋ฅผ ๋ํ๋ธ๋ค.
โ๏ธEOS (End of Sentence): EOS๋ ๋ฌธ์ฅ์ ๋์ ๋ํ๋ด๋ ํน๋ณํ ํ ํฐ ๋๋ ์ฌ๋ณผ๋ก ์ฃผ๋ก Seq2Seq ๋ชจ๋ธ๊ณผ ๊ฐ์ ๋ชจ๋ธ์์ ์ถ๋ ฅ ์ํ์ค์ ๋์ ๋ํ๋ด๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
2.2 seq2seq ํ์ฉ๋ถ์ผ
๊ธฐ๊ณ๋ฒ์ญ(MT)
ChatBot
Summarization
Speech Recognition
Image Captioning
2.3 ํ๊ณ [Bottleneck Problem]
๊ฐ์ฅ ํฐ ๋ฌธ์ ๋ encoder์ ๋ง์ง๋ง hidden state๋ง decoder์ ์ ๋ฌํ๋ค๋ ์ ์ด๋ค.
์ด๋ฅผ Bottleneck Problem์ด๋ผ ํ๋๋ฐ, Encoder๋ ๋ง์ง๋ง hidden state์ ํ๋์ ๊ณ ์ ๋ ํฌ๊ธฐ์ single vector๋ก ๋ชจ๋ ์ ๋ณด๋ฅผ ์์ถํด์ผํ๋ ๋ถ๋ด์ด ์กด์ฌํ๊ฒ ๋๋ค.
์ฆ, ์ ๋ณด์์ค์ด ๋ฐ์ ๋ฐ ๊ธฐ์ธ๊ธฐ ์์ค์ด ๋์ด๋ฒ๋ฆฐ๋ค.
์ด๋ฐ ํ๊ณ๋ฅผ ํด๊ฒฐํ๋ ๊ฒ์ด ๋ฐ๋ก Attention ๋ฉ์ปค๋์ฆ์ ์ด์ฉํ๋ ๋ฐฉ๋ฒ์ด๋ค.
attention ๋ฉ์ปค๋์ฆ์ ๊ด๋ จ์๋ ๋จ์ด์์ attention์ ๋์ฌ ๊ธฐ์กด์ฒ๋ผ ๋ค์ ์ง์ค๋๋ ํ์์ ๋ฐฉ์งํ๋ค.
์ฆ, "ํน์ ๋ถ๋ถ์ ์ง์ค"ํ๊ธฐ ์ํด Decoder์ ๊ฐ ๋จ๊ณ์์ encoder์ ์ง์ ์ ์ธ ์ฐ๊ฒฐ์ ํ๊ฒ ํ๋ค.
Pytorch ์์
Encoder ํด๋์ค
Encoder๋ RNN์ ์ฌ์ฉํ text classification๊ณผ ๊ฑฐ์ ์ ์ฌํ๋ค.
๋ฐ๋ผ์ Bi-Directional LSTM์ ์ฌ์ฉํ๋ค.
class Encoder(nn.Module): def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2): super(Encoder, self).__init__() # Be aware of value of 'batch_first' parameter. # Also, its hidden_size is half of original hidden_size, # because it is bidirectional. self.rnn = nn.LSTM( word_vec_size, int(hidden_size / 2), num_layers=n_layers, dropout=dropout_p, bidirectional=True, batch_first=True, ) def forward(self, emb): # |emb| = (batch_size, length, word_vec_size) if isinstance(emb, tuple): x, lengths = emb x = pack(x, lengths.tolist(), batch_first=True) # Below is how pack_padded_sequence works. # As you can see, # PackedSequence object has information about mini-batch-wise information, # not time-step-wise information. # # a = [torch.tensor([1,2,3]), torch.tensor([3,4])] # b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True) # >>>> # tensor([[ 1, 2, 3], # [ 3, 4, 0]]) # torch.nn.utils.rnn.pack_padded_sequence(b, batch_first=True, lengths=[3,2] # >>>>PackedSequence(data=tensor([ 1, 3, 2, 4, 3]), batch_sizes=tensor([ 2, 2, 1])) else: x = emb y, h = self.rnn(x) # |y| = (batch_size, length, hidden_size) # |h[0]| = (num_layers * 2, batch_size, hidden_size / 2) if isinstance(emb, tuple): y, _ = unpack(y, batch_first=True) return y, h
Decoder ํด๋์ค
์ดํ ๋์ฌ Attention๊ฐ๋ ์ ์ถ๊ฐํ๋ฉด, ์๋์ ๊ฐ๋ค.
class Decoder(nn.Module): def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2): super(Decoder, self).__init__() # Be aware of value of 'batch_first' parameter and 'bidirectional' parameter. self.rnn = nn.LSTM( word_vec_size + hidden_size, hidden_size, num_layers=n_layers, dropout=dropout_p, bidirectional=False, batch_first=True, ) def forward(self, emb_t, h_t_1_tilde, h_t_1): # |emb_t| = (batch_size, 1, word_vec_size) # |h_t_1_tilde| = (batch_size, 1, hidden_size) # |h_t_1[0]| = (n_layers, batch_size, hidden_size) batch_size = emb_t.size(0) hidden_size = h_t_1[0].size(-1) if h_t_1_tilde is None: # If this is the first time-step, h_t_1_tilde = emb_t.new(batch_size, 1, hidden_size).zero_() # Input feeding trick. x = torch.cat([emb_t, h_t_1_tilde], dim=-1) # Unlike encoder, decoder must take an input for sequentially. y, h = self.rnn(x, h_t_1) return y, h
Generator ํด๋์ค
class Generator(nn.Module): def __init__(self, hidden_size, output_size): super(Generator, self).__init__() self.output = nn.Linear(hidden_size, output_size) self.softmax = nn.LogSoftmax(dim=-1) def forward(self, x): # |x| = (batch_size, length, hidden_size) y = self.softmax(self.output(x)) # |y| = (batch_size, length, output_size) # Return log-probability instead of just probability. return y
Loss Function
'softmax+CE'๋ณด๋ค๋ 'logsoftmax + NLL(์์ ๋ก๊ทธ๊ฐ๋ฅ๋)'๋ฅผ ์ฌ์ฉํ๋ค.
# Default weight for loss equals to 1 # But, we don't need to get loss for PAD Token # Thus, set a weight for PAD to 0 loss_weight = torch.ones(output_size) loss_weight[data_loader.PAD] = 0. # Instead of using Cross-Entropy, # We can use NLL(Negative-Log-Likelihood) loss with log-probability crit = nn.NLLLoss(weight=loss_weight, reduction='sum', )
๋ฐ๋ผ์ softmax์ฌ์ฉ๋์ logsoftmaxํจ์๋ก ๋ก๊ทธํ๋ฅ ์ ๊ตฌํ๋ค.def _get_loss(self, y_hat, y, crit=None): # |y_hat| = (batch_size, length, output_size) # |y| = (batch_size, length) crit = self.crit if crit is None else crit loss = crit(y_hat.contiguous().view(-1, y_hat.size(-1)), y_contiguous().view(-1) ) return loss
3. Attention
3.1 Attention์ ๋ชฉํ
Query์ ๋น์ทํ ๊ฐ์ ๊ฐ๋ Key๋ฅผ ์ฐพ์ ๊ทธ ๊ฐ์ ์ป๋ ๊ณผ์ . (์ด๋, ๊ทธ ๊ฐ์ Value๋ผ ํ๋ค.)
3.2 Key-Value ํจ์
โ Python์ Dictionary: Key-Value์ ์์ผ๋ก ์ด๋ฃจ์ด์ง ์๋ฃํ
Dic = {'A.I':9 , 'computer':5, 'NLP':4}
์ด์ ๊ฐ์ด Key์ Value์ ํด๋นํ๋ ๊ฐ๋ค์ ๋ฃ๊ณ , Key๋ฅผ ํตํด Value๊ฐ์ ์ ๊ทผ๊ฐ๋ฅํ๋ค.
์ฆ, Query๊ฐ ์ฃผ์ด์ง ๋, Key๊ฐ์๋ฐ๋ผ Value๊ฐ์ ์ ๊ทผํ ์ ์๋ค.
def KV(Q): weights = [] for K in dic.keys(): weights += [is_same(K, Q)] weight_sum = sum(weights) for i, w in enumerate(weights): weights[i] = weights[i] / weight_sum ans = 0 for weight, V in zip(weights, dic.values()): ans += weight*V return ans
def is same(K, Q): if K == Q: return 1. else: reutnr .0
3.3 ์ฐ์์ ์ธ Key-Value ๋ฒกํฐ ํจ์
๋ง์ฝ Dic์ Value์ 100์ฐจ์ ๋ฒกํฐ๊ฐ ๋ค์ด๊ฐ์๋ค๋ฉด??
ํน์ Query์ Key๊ฐ ๋ชจ๋ ๋ฒกํฐ๋ฅผ ๋ค๋ค์ผ ํ๋ค๋ฉด??
โ ์ฆ, Q, K๊ฐ word embedding vector๋ผ๋ฉด??
โ ๋๋, Dic์ K, V๊ฐ์ด ์๋ก ๊ฐ๋ค๋ฉด??
def KV(Q): weights = [] for K in dic.keys(): weights += [how_similar(K, Q)] # cosine similarity๊ฐ์ ์ฑ์ด๋ค. weights = softmax(weights) # ๋ชจ๋ ๊ฐ์ค์น๋ฅผ ๊ตฌํ ํ softmax๋ฅผ ๊ณ์ฐ(๋ชจ๋ wํฉํฌ๊ธฐ๋ฅผ 1๋ก ๊ณ ์ ) ans = 0 for w, V in zip(weights, dic.values()): ans += w*V return ans
์์ ์ฝ๋์์ ans์๋ ์ด๋ ํ ๋ฒกํฐ๊ฐ์ด ๋ค์ด๊ฐ๋ค.
ans๋ด๋ถ์ ๋ฒกํฐ๋ค์ ์ฝ์ฌ์ธ ์ ์ฌ๋์ ๋ฐ๋ผ ๋ฒกํฐ๊ฐ์ด ์ ํด์ง๋ค.
์ฆ, ์์ ํจ์๋ Q์ ๋น์ทํ K๊ฐ์ ์ฐพ์ ์ ์ฌ๋์ ๋ฐ๋ผ Weight๋ฅผ ์ ํ๊ณ , ๊ฐ K์ V๊ฐ์ W๊ฐ๋งํผ ๊ฐ์ ธ์ ๋ชจ๋ ๋ํ๋ ๊ฒ์ผ๋ก
์ด๊ฒ์ด ๋ฐ๋ก Attention Mechanism์ ํต์ฌ ์์ด๋์ด์ด๋ค.
3.4 NMT์์์ Attention
๊ทธ๋ ๋ค๋ฉด, MT์์ Attention Mechanism์ ์ด๋ป๊ฒ ์๋๋ ๊น?
โ K,V: Encoder์ ๊ฐ time-step๋ณ ์ถ๋ ฅ
โ Q: ํ์ฌ time-step์ Decoder์ถ๋ ฅ
Seq2Seq with Attention ์ํ๋ ์ ๋ณด๋ฅผ Attention์ ํตํด Encoder์์ ์ป๊ณ ,
ํด๋น ์ ๋ณด๋ฅผ Decoder์ ์ถ๋ ฅ๊ณผ ์ด์ด๋ถ์ฌ tanh๋ฅผ ์ทจํ ํ
softmax๊ณ์ฐ์ ํตํด ๋ค์ time-step์ ์ ๋ ฅ์ด ๋๋ y_hat์ ๊ตฌํ๋ค.
Linear Transformation
์ ๊ฒฝ๋ง ๋ด๋ถ์ ๊ฐ ์ฐจ์๋ค์ latent feature๊ฐ์ด๊ธฐ์ ์ ํํ ์ ๋ฆฌํ ์ ์๋ค.
ํ์ง๋ง, ํ์คํ ์ ์ source์ธ์ด์ ๋์์ธ์ด๊ฐ ์ ์ด์ ๋ค๋ฅด๋ค๋ ๊ฒ์ด๋ค.
๋ฐ๋ผ์, ๋จ์ํ ๋ฒกํฐ๋ด์ ์ ํ๊ธฐ๋ณด๋จ ์์ค์ ๋์๊ฐ์ ์ฐ๊ฒฐ๊ณ ๋ฆฌ๊ฐ ํ์ํ๋ค.
๋ฐ๋ผ์, ๋ ์ธ์ด๊ฐ ๊ฐ๊ฐ ์๋ฒ ๋ฉ๋ latent space์ด ์ ํ๊ด๊ณ์ ์๋ค ๊ฐ์ ํ๊ณ ,
๋ด์ ์ฐ์ฐ์ํ์ ์ํด ์ ํ๋ณํ์ ํด์ค๋ค. (์ ํ๋ณํ์ ์ํ W๊ฐ์ ๊ฐ์ค์น๋ก FF, BP๋ก ํ์ต๋๋ค.)
โ์ Attention์ด ํ์ํ ๊น?์ ๋ํ ์ง๋ฌธ์์, ์ด ์ ํ๋ณํ์ ๋ฐฐ์ฐ๋ ๊ฒ ์์ฒด๊ฐ Attention์ด๋ผ ํํํ๋ ๊ฒ์ ๊ณผํ์ง ์๋ค.
(โต ์ ํ๋ณํ๊ณผ์ ์ผ๋ก Decoder์ ํ์ฌ์ํ์ ํ์ํ Q๋ฅผ ์์ฑ, Encoder์ K๊ฐ๋ค๊ณผ ๋น๊ต, ๊ฐ์คํฉ์ ํ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ)
์ฆ, Attention์ ํตํด Decoder๋ Encoder์ Q๋ฅผ ์ ๋ฌํ๋ฉฐ, ์ด๋ ์ข์ ๊ฐ์ ์ ๋ฌํ๋ ๊ฒ์ ์ข์ ๊ฒฐ๊ณผ๋ก ์ด์ด์ง๊ธฐ ๋๋ฌธ์ ํ์ฌ Decoder์ํ์ ํ์ํ ์ ๋ณด๊ฐ ๋ฌด์์ธ์ง ์ค์ค๋ก ํ๋จํด ์ ํ๋ณํ์ ํตํด Q๋ฅผ ๋ง๋๋ ๊ฒ์ด ๋งค์ฐ ์ค์ํ ๊ฒ์ด๋ค.
๋ํ ์ ํ๋ณํ์ ์ํ ๊ฐ์ค์น ์์ฒด๋ ํ๊ณ๊ฐ ์๊ธฐ์ Decoder์ ์ํ ์์ฒด๊ฐ ์ ํ๋ณํ์ด ๋์ด Q๊ฐ ์ข์ ํํ๊ฐ ๋๋๋ก RNN์ด ๋์ํ ๊ฒ์ด๋ค.
์์ ๋ฑ์ ์์ธํ ์์ ๋ฐ ๊ณผ์ ๋ค ์ฐธ๊ณ : https://chan4im.tistory.com/161
Pytorch ์์
โ Attention ํด๋์ค: ์ ํ๋ณํ์ ์ํ ๊ฐ์ค์นํ๋ผ๋ฏธํฐ๋ฅผ bias๊ฐ ์๋ ์ ํ์ธต์ผ๋ก ๋์ฒดํ์๋ค.
์กฐ๊ธ ๋ ์์ธํ ์ค๋ช ์ ๋ค์ Section์์ ๊ณ์ ์งํํ๋ค.
class Attention(nn.Module): def __init__(self, hidden_size): super(Attention, self).__init__() self.linear = nn.Linear(hidden_size, hidden_size, bias=False) self.softmax = nn.Softmax(dim = -1) def forward(self, h_src, h_t_target, mask=None): # |h_src| = (batch_size, length, hidden_size) # |h_t_target| = (batch_size, 1, hidden_size) # |mask| = (batch_size, length) Q = self.linear(h_t_target.squeeze(1)).unsqueeze(-1) # |Q| = (batch_size, hidden_size, 1) weight = torch.bmm(h_src, Q).squeeze(-1) # |weight| = (batch_size, length) if mask is not None: # Set each weight as -inf, if the mask value equals to 1. # Since the softmax operation makes -inf to 0, # masked weights would be set to 0 after softmax operation. # Thus, if the sample is shorter than other samples in mini-batch, # the weight for empty time-step would be set to 0. weight.masked_fill_(mask.unsqueeze(1), -float('inf')) weight = self.softmax(weight) context_vector = torch.bmm(weight, h_src) # |context_vector| = (batch_size, 1, hidden_size) return context_vector
cf) bmm (batch matrix multiplication; ๋ฐฐ์น ํ๋ ฌ๊ณฑ)
torch.bmm(x, y)์ ๋ํด ์ค๋ช ํด๋ณด์.
โ |x| = (batch_size, m, k)
โ |y| = (batch_size, h, m)
| torch.bmm(x, y) | = (batch_size, n, m)
4. Input Feeding
๊ฐ time-step์ decoder์ถ๋ ฅ๊ฐ๊ณผ attention๊ฒฐ๊ณผ๊ฐ์ ์ด์ด๋ถ์ธ ํ Generator Module์์ softmax๋ฅผ ์ทจํด ํ๋ฅ ๋ถํฌ๋ฅผ ๊ตฌํ๋ค.
์ดํ ํด๋น ํ๋ฅ ๋ถํฌ์์ argmax๋ฅผ ์ํํด y_hat์ samplingํ๋ค.
๋ค๋ง, ๋ถํฌ์์ samplingํ๋ ๊ณผ์ ์์๋ณด๋ค ๋ ๋ง์ ์ ๋ณด๊ฐ ์์ค๋๋ค.
๋ฐ๋ผ์ softmax์ด์ ๊ฐ๋ ๊ฐ์ด ๋ฃ์ด์ฃผ๋ ํธ์ด ์ ๋ณด์์ค์์ด ๋ ์ข์ ํจ๊ณผ๋ฅผ ์ป๋๋ค.,
y์ ๋ฌ๋ฆฌ concat์ธต์ ์ถ๋ ฅ์ y๊ฐ embedding์ธต์์ dense๋ฒกํฐ๋ก ๋ณํ๋๊ณ ๋ ํ ์๋ฒ ๋ฉ ๋ฒกํฐ์ ์ด์ด๋ถ์ฌ Decoder RNN์ ์ ๋ ฅ์ผ๋ก ์ฃผ์ด์ง๋ ๊ณผ์ ์ input feeding์ด๋ผ ํ๋ค.
4.1 ๋จ์
์ด ๋ฐฉ์์ train์๋์ ํ์ ๋จ์ ์ด ์กด์ฌํ๋๋ฐ, input feeding์ด์ ๋ฐฉ์์์๋ ํ๋ จ ์ decoder ๋ํ encoder์ฒ๋ผ ๋ชจ๋ time-step์ ํ๋ฒ์ ๊ณ์ฐํ๋ค.
ํ์ง๋ง input feeding์ผ๋ก์ธํด decoder RNN์ ๋ ฅ์ผ๋ก ์ด์ time-step์ ๊ฒฐ๊ณผ๊ฐ ํ์ํ๊ฒ ๋์ด ์์ฐจ์ ์ผ๋ก time-step๋ณ๋ก ๊ณ์ฐํด์ผํ๋ค.
๋ค๋ง, ์ด ๋จ์ ์ ์ถ๋ก ๋จ๊ณ์์ ์ด์ฐจํผ decoder๋ input feeding์ด ์๋๋๋ผ๋ time-step๋ณ ๋ณ๋ ฌ์ฒ๋ฆฌ๊ฐ ์๋ ์์ฐจ์ ๊ณ์ฐ์ด ํ์ํ๊ธฐ์ ์ถ๋ก ์ input feeding์ผ๋ก ์ธํ ์๋ ์ ํ๋ ๊ฑฐ์ ์๋ค; ๋ฐ๋ผ์ ์ด ๋จ์ ์ด ํฌ๊ฒ ๋ถ๊ฐ๋์ง๋ ์๋๋ค.
Pytorch ์์
https://github.com/V2LLAIN/NLP/blob/main/5.%20RNN__Seq2Seq/seq2seq(with_attention).py
5. Auto Regressive .&. Teacher Forcing
์ฌ๊ธฐ์ ์๋ฌธ์ ...?
train ์ Decoder์ ์ ๋ ฅ์ผ๋ก time-step์ ์ถ๋ ฅ์ด ๋ค์ด๊ฐ๋๊ฑธ๊น??
์ฌ์ค, seq2seq์ ๊ธฐ๋ณธ์ ํ๋ จ๋ฐฉ์์ ์ถ๋ก ํ ๋์ ๋ฐฉ์๊ณผ ์์ดํ๋ค.
5.1 AR(Auto Regressive) ์์ฑ
seq2seq์ train๊ณผ inference์ ๊ทผ๋ณธ์ ์ธ ์ฐจ์ด๋ AR์์ฑ์ผ๋ก ๋ฐ์ํ๋ค.
์๊ธฐํ๊ท(AR): ๊ณผ๊ฑฐ ์์ ์ ๊ฐ์ ์ฐธ์กฐํด ํ์ฌ์ ๊ฐ์ ์ถ๋ก ํ๋ ํน์ง
์ด๋ฅผ ์๋ ์์์์๋ ํ์ธํ ์ ์๋ค.
๋ค๋ง, ๊ณผ๊ฑฐ์ ๊ฒฐ๊ณผ๊ฐ์๋ฐ๋ผ ๋ฌธ์ฅ์ด๋ ์ํ์ค์ ๊ตฌ์ฑ์ด ๋ฐ๋๋ฟ๋ง์๋๋ผ ์์ธก๋ฌธ์ฅ์ํ์ค์ ๊ธธ์ด๋ ๋ฐ๋ ์ ์๊ณ ,
๊ณผ๊ฑฐ์ ์๋ชป๋ ์์ธก์ ํ์ ๋, ๋ ํฐ ์๋ชป๋ ์์ธก์ ํ ๊ฐ๋ฅ์ฑ์ ์ผ๊ธฐํ๊ธฐ๋ ํ๋ค.
ํ์ต๊ณผ์ ์์๋ ์ด๋ฏธ ์ ๋ต์ ์๊ณ ์๊ณ ํ์ฌ๋ชจ๋ธ์ ์์ธก๊ฐ๊ณผ ์ ๋ต๊ณผ์ ์ฐจ์ด๋ฅผ ํตํด ํ์ตํ๊ธฐ์ ์๊ธฐํ๊ท(AR)์์ฑ์ ์ ์งํ ์ฑ ํ๋ จํ ์๋ ์๋ค.
๋ฐ๋ผ์ Teacher Forcing์ด๋ผ ๋ถ๋ฆฌ๋ ๋ฐฉ๋ฒ์ ํตํด ํ๋ จ์ ์งํํ๋ค.
5.2 Teacher Forcing ํ๋ จ๋ฐฉ๋ฒ
Teacher Forcing์ ํ๋ จ ์ decoder์ ์ ๋ ฅ์ผ๋ก ์ด์ time-step์ decoder ์ถ๋ ฅ๊ฐ์ด ์๋, ์ ๋ต Y๊ฐ ๋ค์ด๊ฐ๋ค๋ ์ ์ด๋ค.
ํ์ง๋ง ์ถ๋ก ์, ์ ๋ต Y๋ฅผ ๋ชจ๋ฅด๊ธฐ์ ์ด์ time-step์์ ๊ณ์ฐ๋์ด ๋์จ y_hat๊ฐ์ decoder์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๋ค.
์ด๋ ๊ฒ ์ ๋ ฅ์ ๋ฃ์ด์ฃผ๋ ํ๋ จ๋ฐฉ๋ฒ์ teacher forcing์ด๋ผ ํ๋ค.
์ด์ :
์ด๊ธฐ ํ๋ จ ๋จ๊ณ์์ ์์ ์ ์ธ ํ์ต์ ๋๊ณ , ๋ชจ๋ธ์ด ์ด๊ธฐ์ ์ด๋ค ๊ฒ์ ์์ฑํด์ผ ํ๋์ง์ ๋ํ ๊ฐ๋ ฅํ ์ ํธ๋ฅผ ์ ๊ณตํ๋ค.
๋ชจ๋ธ์ด ์ ๋ต ๋ ์ด๋ธ์ ๋ณด๊ณ ํ์ตํ๋ฏ๋ก ๊ธฐ์ธ๊ธฐ์์ค ๋ฌธ์ ๋ฅผ ์ํํ์ฌ ๊ทธ๋๋์ธํธ๊ฐ ๋ ์ ํ๋ฅผ ์ ์์ต๋๋ค.
cf) LM์ Teacher Forcing
[ํ๋ จ ๊ณผ์ ]
๋ชจ๋ธ์ ํ๋ จํ ๋, ๊ฐ ์์ ์์ ์ด์ ์์ ์ ๋ชจ๋ธ ์ถ๋ ฅ์ด ์๋ ์ค์ ์ ๋ต ๋ ์ด๋ธ์ ์ ๋ ฅ์ผ๋ก ์ ๊ณตํ๋ค.
์ฆ, ์ด์ ์์ ์ ์ถ๋ ฅ์ด ์๋ "์ ์๋(teacher)" ์ญํ ์ ์ ๋ต ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ๋ค.
[ํ ์คํธ(์ถ๋ก ) ๊ณผ์ ]
๋ชจ๋ธ์ด ํ๋ จ์ ๋ง์น ํ์๋ ์ด์ ์ถ๋ ฅ์ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ์ฌ ์ํ์ค๋ฅผ ์์ฑํ๋ค.
์ด๋ ์ด์ ์์ ์ ์ถ๋ ฅ์ "์๊ธฐ ํ๊ท์ ์ผ๋ก(autoregressively)" ์ฌ์ฉํ๋ค.
6. Searching Algorithm(Inference). &. Beam Search
X๊ฐ ์ฃผ์ด์ก์ ๋, Y_hat์ ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ด์ผ๊ธฐ ํด๋ณด์.
์ด๋ฐ ๊ณผ์ ์ ์ถ๋ก ๋๋ ํ์(search)์ด๋ผ ๋ถ๋ฅด๋๋ฐ, ํ์์๊ณ ๋ฆฌ์ฆ์ ๊ธฐ๋ฐํ๊ธฐ ๋๋ฌธ์ด๋ค.
์ฆ, ์ฐ๋ฆฌ๊ฐ ์ํ๋ ๋จ์ด๋ค ์ฌ์ด ์ต๊ณ ์ ํ๋ฅ ์ ๊ฐ๋ ๊ฒฝ๋ก(path)๋ฅผ ์ฐพ๋ ๊ณผ์ ์ด๋ค.
6.1 sampling
๊ฐ์ฅ ์ ํํ ๋ฐฉ๋ฒ์ time-step๋ณ y_hat์ ๊ณ ๋ฅผ ๋, ๋ง์ง๋ง softmax์ธต์์์ ํ๋ฅ ๋ถํฌ๋๋ก samplingํ๋ ๊ฒ์ด๋ค.
๊ทธ ํ time-step์์ ๊ทธ ์ ํ(y_hat)์ ๊ธฐ๋ฐ์ผ๋ก ๊ทธ ๋ค์ y_hat์ ๋ ๋ค์ samplingํด ์ต์ข ์ ์ผ๋ก EOS๊ฐ ๋์ฌ๋ ๊น์ง samplingํ๋ ๊ฒ์ด๋ค.
๋ค๋ง, ์ด๋ฐ ๋ฐฉ์์ ๊ฐ์ ์ ๋ ฅ์ ๋ํด ๋งค๋ฒ ๋ค๋ฅธ ์ถ๋ ฅ๊ฒฐ๊ณผ๋ฌผ์ด ๋์ฌ ์ ์์ด ์ง์ํ๋ํธ์ด๋ค.
6.2 Greedy Search Algorithm ํ์ฉ
DFS, BFS, DP ๋ฑ ์๋ง์ ํ์๊ธฐ๋ฒ์ด ์กด์ฌํ์ง๋ง Greedy Search Algorithm์ ๊ธฐ๋ฐ์ผ๋ก ํ์์ ๊ตฌํํด๋ณด์.
์ฆ, ๋ชจ๋ time-step์ ๋ํ softmaxํ๋ฅ ๊ฐ๋ค ์ค ๊ฐ์ฅ ํ๋ฅ ๊ฐ์ด ํฐ ์ธ๋ฑ์ค๋ฅผ ๋ฝ์ ๊ทธ time-step์ y_hat์ ์ฌ์ฉํ๋ ๊ฒ์ด๋ค.
์ฆ, ๊ฐ ์ถ๋ ฅ ์์ธก ์ ๊ฐ step์์ ๊ฐ์ฅ ๊ฐ๋ฅ์ฑ ๋์ ๋จ์ด๋ฅผ ์ ํํด ๋งค์ฐ ๋น ๋ฅธ ํ์์ด ๊ฐ๋ฅํ๋ค.
๋จ์ 1.) Decision์ ๋๋๋ฆด ์ ์๊ฒ ๋ ์ ์๊ณ
๋จ์ 2.) ์ต์ข ์ถ๋ ฅ์ด ์ต์ ํ๋ ๊ฒฐ๊ณผ์์ ๋ฉ์ด์ง ์ ์๋ค.
(โต <END>token ์์ฑ์ ๊น์ง decoding์ ์งํํ๊ธฐ ๋๋ฌธ)
Pytorch ์์
def search(self, src, is_greedy=True, max_length=255): if isinstance(src, tuple): x, x_length = src mask = self.generate_mask(x, x_length) else: x, x_length = src, None mask = None batch_size = x.size(0) # Same procedure as teacher forcing. emb_src = self.emb_src(x) h_src, h_0_tgt = self.encoder((emb_src, x_length)) decoder_hidden = self.fast_merge_encoder_hiddens(h_0_tgt) # Fill a vector, which has 'batch_size' dimension, with BOS value. y = x.new(batch_size, 1).zero_() + data_loader.BOS is_decoding = x.new_ones(batch_size, 1).bool() h_t_tilde, y_hats, indice = None, [], [] # Repeat a loop while sum of 'is_decoding' flag is bigger than 0, # or current time-step is smaller than maximum length. while is_decoding.sum() > 0 and len(indice) < max_length: # Unlike training procedure, # take the last time-step's output during the inference. emb_t = self.emb_dec(y) # |emb_t| = (batch_size, 1, word_vec_size) decoder_output, decoder_hidden = self.decoder(emb_t, h_t_tilde, decoder_hidden) context_vector = self.attn(h_src, decoder_output, mask) h_t_tilde = self.tanh(self.concat(torch.cat([decoder_output, context_vector ], dim=-1))) y_hat = self.generator(h_t_tilde) # |y_hat| = (batch_size, 1, output_size) y_hats += [y_hat] if is_greedy: y = y_hat.argmax(dim=-1) # |y| = (batch_size, 1) else: # Take a random sampling based on the multinoulli distribution. y = torch.multinomial(y_hat.exp().view(batch_size, -1), 1) # |y| = (batch_size, 1) # Put PAD if the sample is done. y = y.masked_fill_(~is_decoding, data_loader.PAD) # Update is_decoding if there is EOS token. is_decoding = is_decoding * torch.ne(y, data_loader.EOS) # |is_decoding| = (batch_size, 1) indice += [y] y_hats = torch.cat(y_hats, dim=1) indice = torch.cat(indice, dim=1) # |y_hat| = (batch_size, length, output_size) # |indice| = (batch_size, length) return y_hats, indiceโ
6.3 Beam Search
Greedy Algorithm์ ๋งค์ฐ ์ฝ๊ณ ๊ฐ๋จํ์ง๋ง, ์ต์ (optimal)ํด๋ ๋ณด์ฅํ์ง ์๋๋ค.
๋ฐ๋ผ์ ์ฝ๊ฐ์ trick์ ๊ฐ๋ฏธํ๋๋ฐ, k๊ฐ์ ํ๋ณด๋ฅผ ๋ ์ถ์ ํ๋ ๊ฒ์ด๋ค.
์ด๋, k๋ฅผ beam_size๋ผ ํ๋ค.
Beam_size k์ ๋ํด step์ด ์งํ๋๋ฉด์ k๊ฐ์ ๊ฐ์ง์์ ๋ํด k๋ฅผ ์ ์ง, ์ต์ข ํ๋ณด๊ตฐ์์ ํ๋ฅ ์ด ๊ฐ์ฅ ๋์ ๊ฒ์ ์ ํํ๋ค.
๋ค๋ฅธ time-step์์ <END>token ์์ฑ์ด ๊ฐ๋ฅํ๋ฉฐ
ํ๋์ ๊ฐ์ค(hypothesis)์์ <END>๊ฐ ๋์ค๋ฉด ์ข ๋ฃํ๊ณ , ๋ค๋ฅธ ๊ฐ์ค๋ถ๊ธฐ๋ฅผ ๊ณ์ ํ์ํ๋ค.
์ฆ, ๋จ์ด๊ฐ ์์ฐจ์ ์์ฑ๋์ด ๋์์ฌ๊ฑดํ๋ฅ ๊ณ ๋ ค ๋ฐ ์์ฑํ ๋๋ง๋ค log๊ฐ์ด ๋ํด์ ธ ๋ํด์ง๋ ์์๊ฐ์ด ๋ง์์ ธ ์์๊ฐ์ด ๋๋, ์ผ์ข ์ Normalizeํ๋ ๊ณผ์ ์ ํ๋ฒ ๋ ๊ฑฐ์น ์ ์๊ฒ ๋๋ค.
โ small k
- greedy์ ๊ฑฐ์ ๋น์ทํ๋ค.(= ungrammatic, unnatural, nonsensical, incorrect)
โ Large k
- k๊ฐ ์ปค์ง์๋ก greedy๋ฌธ์ ๋ ์ค์ง๋ง ๊ณ์ฐ๋น์ฉ์ด ์ปค์ง๋ค.
- BLEU_Score๊ฐ ๋จ์ด์ง๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ค. (โต too-short translation)
๋ฐ๋ผ์ ๋ณดํต Beam_size๋ฅผ 10 ์ดํ๋ก ์ฌ์ฉํ๋ค.
Pytorch ์์
#@profile def batch_beam_search( self, src, beam_size=5, max_length=255, n_best=1, length_penalty=.2 ): mask, x_length = None, None if isinstance(src, tuple): x, x_length = src mask = self.generate_mask(x, x_length) # |mask| = (batch_size, length) else: x = src batch_size = x.size(0) emb_src = self.emb_src(x) h_src, h_0_tgt = self.encoder((emb_src, x_length)) # |h_src| = (batch_size, length, hidden_size) h_0_tgt = self.fast_merge_encoder_hiddens(h_0_tgt) # initialize 'SingleBeamSearchBoard' as many as batch_size boards = [SingleBeamSearchBoard( h_src.device, { 'hidden_state': { 'init_status': h_0_tgt[0][:, i, :].unsqueeze(1), 'batch_dim_index': 1, }, # |hidden_state| = (n_layers, batch_size, hidden_size) 'cell_state': { 'init_status': h_0_tgt[1][:, i, :].unsqueeze(1), 'batch_dim_index': 1, }, # |cell_state| = (n_layers, batch_size, hidden_size) 'h_t_1_tilde': { 'init_status': None, 'batch_dim_index': 0, }, # |h_t_1_tilde| = (batch_size, 1, hidden_size) }, beam_size=beam_size, max_length=max_length, ) for i in range(batch_size)] is_done = [board.is_done() for board in boards] length = 0 # Run loop while sum of 'is_done' is smaller than batch_size, # or length is still smaller than max_length. while sum(is_done) < batch_size and length <= max_length: # current_batch_size = sum(is_done) * beam_size # Initialize fabricated variables. # As far as batch-beam-search is running, # temporary batch-size for fabricated mini-batch is # 'beam_size'-times bigger than original batch_size. fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], [] fab_h_src, fab_mask = [], [] # Build fabricated mini-batch in non-parallel way. # This may cause a bottle-neck. for i, board in enumerate(boards): # Batchify if the inference for the sample is still not finished. if board.is_done() == 0: y_hat_i, prev_status = board.get_batch() hidden_i = prev_status['hidden_state'] cell_i = prev_status['cell_state'] h_t_tilde_i = prev_status['h_t_1_tilde'] fab_input += [y_hat_i] fab_hidden += [hidden_i] fab_cell += [cell_i] fab_h_src += [h_src[i, :, :]] * beam_size fab_mask += [mask[i, :]] * beam_size if h_t_tilde_i is not None: fab_h_t_tilde += [h_t_tilde_i] else: fab_h_t_tilde = None # Now, concatenate list of tensors. fab_input = torch.cat(fab_input, dim=0) fab_hidden = torch.cat(fab_hidden, dim=1) fab_cell = torch.cat(fab_cell, dim=1) fab_h_src = torch.stack(fab_h_src) fab_mask = torch.stack(fab_mask) if fab_h_t_tilde is not None: fab_h_t_tilde = torch.cat(fab_h_t_tilde, dim=0) # |fab_input| = (current_batch_size, 1) # |fab_hidden| = (n_layers, current_batch_size, hidden_size) # |fab_cell| = (n_layers, current_batch_size, hidden_size) # |fab_h_src| = (current_batch_size, length, hidden_size) # |fab_mask| = (current_batch_size, length) # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size) emb_t = self.emb_dec(fab_input) # |emb_t| = (current_batch_size, 1, word_vec_size) fab_decoder_output, (fab_hidden, fab_cell) = self.decoder(emb_t, fab_h_t_tilde, (fab_hidden, fab_cell)) # |fab_decoder_output| = (current_batch_size, 1, hidden_size) context_vector = self.attn(fab_h_src, fab_decoder_output, fab_mask) # |context_vector| = (current_batch_size, 1, hidden_size) fab_h_t_tilde = self.tanh(self.concat(torch.cat([fab_decoder_output, context_vector ], dim=-1))) # |fab_h_t_tilde| = (current_batch_size, 1, hidden_size) y_hat = self.generator(fab_h_t_tilde) # |y_hat| = (current_batch_size, 1, output_size) # separate the result for each sample. # fab_hidden[:, begin:end, :] = (n_layers, beam_size, hidden_size) # fab_cell[:, begin:end, :] = (n_layers, beam_size, hidden_size) # fab_h_t_tilde[begin:end] = (beam_size, 1, hidden_size) cnt = 0 for board in boards: if board.is_done() == 0: # Decide a range of each sample. begin = cnt * beam_size end = begin + beam_size # pick k-best results for each sample. board.collect_result( y_hat[begin:end], { 'hidden_state': fab_hidden[:, begin:end, :], 'cell_state' : fab_cell[:, begin:end, :], 'h_t_1_tilde' : fab_h_t_tilde[begin:end], }, ) cnt += 1 is_done = [board.is_done() for board in boards] length += 1 # pick n-best hypothesis. batch_sentences, batch_probs = [], [] # Collect the results. for i, board in enumerate(boards): sentences, probs = board.get_n_best(n_best, length_penalty=length_penalty) batch_sentences += [sentences] batch_probs += [probs] return batch_sentences, batch_probsโ
7. Performance Metric [PPL / BLEU / METEOR / ROUGE]
7.1 ์ ์ฑ์ ํ๊ฐ (Intrinsic Evaluation)
๋ณดํต ์ฌ๋์ด ๋ฒ์ญ๋ ๋ฌธ์ฅ์ ์ฑ์ ํ๋ ํํ
์ด๋, ์ฌ๋์ ์ ์ ๊ฒฌ์ด ๋ฐ์๋ ์ ์๊ธฐ์ ๋ณดํต blind test์ ๊ฐ์ ๋ฐฉ์์ ๊ณ ์ํ๋ค.
๊ฐ์ฅ ์ ํํ ์๋ ์์ง๋ง ์์๊ณผ ์๊ฐ์ด ๋ง์ด ๋ ๋ค๋ ๋จ์ ์ด ์กด์ฌํ๋ค.
์ดํ ๊ตฌ๊ธ์ ๋ฒ์ญ์์คํ ์ ์์๋ณผ ๋, ๊ตฌ๊ธ์์ ์ ์ฑ์ ํ๊ฐ๋ฅผ ํตํด ์ป์ ์ ์์ ๋ํด ์์๋ณผ ๊ฒ์ด๋ค.
7.2 ์ ๋์ ํ๊ฐ (Extrinsic Evaluation)
์์์ ์ธ๊ธํ ์ ์ฑ์ ํ๊ฐ์ ๋จ์ ์ผ๋ก ์ธํด ๋ณดํต ์๋ํ๋ ์ ๋ํ๊ฐ๋ฅผ ์ฃผ๋ก ์ฌ์ฉํ๋ค.
์ด๋, ์ต๋ํ ๋น์ทํ ์ผ๊ด์ฑ์ ๊ฐ๋ ํ๊ฐ๋ฅผ ํด์ผํ๋ฉฐ, ์ธ์ด์ ํน์ง์ด ๋ฐ์๋ ํ๊ฐ๋ฐฉ๋ฒ์ด๋ผ๋ฉด ๋๋์ฑ ์ข์ ๊ฒ์ด๋ค.
PPL (Perplexity)
NMT๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋งค time-step๋ง๋ค ์ต๊ณ ํ๋ฅ ์ ๊ฐ๋ ๋จ์ด๋ฅผ ์ ํ(classification)ํ๋ ์์ ์ด๊ธฐ์ Cross Entropy๋ฅผ ๊ธฐ๋ณธ์ ์ธ Loss Function์ผ๋ก ์ฌ์ฉํ๋ค.
NMT๋ํ ์กฐ๊ฑด๋ถ ์ธ์ด๋ชจ๋ธ์ด๊ธฐ์ PPL๋ฅผ ํตํ ์ฑ๋ฅ์ธก์ ์ด ๊ฐ๋ฅํ๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก Cross-Entropy์ exp๋ฅผ ์ทจํ PPL์ ํ๊ฐ์งํ๋ก ํ์ฉ๊ฐ๋ฅํ๋ค.
(PPL์ด ๋ฎ์์๋ก, N-gram์ N์ด ํด์๋ก ์ข์ ๋ชจ๋ธ ; https://chan4im.tistory.com/200#n3)
CE์ ์ง๊ฒฐ๋์ด ๊ฐํธํจ์ด ์๋ค๋ ์ฅ์ ์ด ์์ง๋ง, ์ค์ ๋ฒ์ญ๊ธฐ ์ฑ๋ฅ๊ณผ ์๋ฒฝํ ๋น๋ก๊ด๊ณ์ ์๋ค ํ ์๋ ์๋ค.
๊ฐ time-step๋ณ ์ค์ ์ ๋ต์ ํด๋นํ๋ ๋จ์ด์ ํ๋ฅ ๋ง ์ฑ์ ํ๊ธฐ ๋๋ฌธ์ด๋ค.
ํ์ง๋ง ์ธ์ด๋ ๊ฐ์ ์๋ฏธ์ ๋ฌธ์ฅ๋ค์ด๋ผ๋ ์ด์์ด ๋ฐ๋์๋ ์๊ณ ,
๋น์ทํ ์๋ฏธ์ ๋จ์ด๋ก ์นํ๋ ์๋ ์๊ธฐ์ ์์ ํ ์๋ชป๋ ๋ฒ์ญ์ด๋๋ผ๋ Loss๊ฐ์ด ๋ฎ์์๋ ์๋ค.
๋ฐ๋ผ์ ์ค์ ๋ฒ์ญ๋ฌธ์ ํ์ง๊ณผ CE์ฌ์ด์๋ ๊ดด๋ฆฌ๊ฐ ์กด์ฌํ๋ค.
(ํนํ๋ Teacher Forcing๋ฐฉ์์ด ๋ํด์ง๊ธฐ์ ๋๋์ฑ ๊ดด๋ฆฌ๊ฐ ์กด์ฌํ๋ค.)
BLEU (Bi-Lingual Evaluation Understudy)
์์์ ๋งํ PPL์์ CE์ ๊ดด๋ฆฌ๋ฅผ ์ค์ด๊ธฐ ์ํด ์ฌ๋ฌ ๋ฐฉ๋ฒ๋ค์ด ์ ์๋์๋๋ฐ, ๊ฐ์ฅ ๋ํ์ ์ธ BLEU์ ๋ํด ์์๋ณด๊ณ ์ ํ๋ค.
BLEU Score๋ ์ ๋ต๊ณผ ์์ธก๋ฌธ์ฅ๊ฐ์ ์ผ์นํ๋ N-gram ๊ฐ์์ ๋น์จ์ ๊ธฐํํ๊ท ์ ๋ฐ๋ผ ์ ์๋ฅผ ๋งค๊ธด๋ค.
์ฆ, ๊ฐ N-gram๋ณ precision์ ํ๊ท ์ ๋ฐฑ๋ถ์จ๋ก ๋ํ๋ด๋ ๊ฒ์ด๊ณ
์งง์๋ฌธ์ฅ์๋ํ ํ๋ํฐ(brevity_penalty)๋ ์์ธก๋ ๋ฒ์ญ๋ฌธ์ด ์ ๋ต๋ฌธ์ฅ๋ณด๋ค ์งง์๊ฒฝ์ฐ, ์ ์๊ฐ ์ข์์ง๋ ๊ฒ์ ๋ฐฉ์งํ๊ธฐ ์ํ ๊ฒ์ด๋ค.
BLEU Score๊ฐ ๋์์๋ก ์ข์ ๋ชจ๋ธ์์ ์๋ฏธํ๋ค.
๋ํ, ์ค์ ์ฑ๋ฅ์ธก์ ์, BLEU๋ฅผ ์ง์ ๊ตฌํํ๊ธฐ๋ณด๋จ SMT Framework์ธ MOSES์ ํฌํจ๋ multi-bleu.perl์ ์ฃผ๋ก ์ฌ์ฉํ๋ค.
METEOR (Metric for Evaluation of Translation with Explicit ORdering)
METEOR๋ NMT ๋ฐ NLP์์ฑ ์์ ์์ ์ฌ์ฉ๋๋ ์๋ ํ๊ฐ์งํ ์ค ํ๋์ด๋ค.
METEOR๋ ๋ฒ์ญ ํ์ง์ ์ธก์ ํ๊ณ ์ฐธ์กฐ(reference) ๋ฒ์ญ๊ณผ ์์ฑ๋ ๋ฒ์ญ ๊ฐ์ ์ ์ฌ์ฑ์ ํ๋จํ๋ค.
METEOR๋ BLEU์ ์ ์ฌํ ๋ชฉ์ ์ ๊ฐ์ง๊ณ ์์ง๋ง ๋ช ๊ฐ์ง ์ค์ํ ์ฐจ์ด์ ์ด ์๋ค.
Main ์ฃผ์์ )
BLEU๋ก๋ถํฐ ํ์๋ ๋ฐฉ๋ฒ์ผ๋ก BLEU์ ๋ถ์์ ํจ์ ๋ณด์ํ๊ณ ๋ ๋์ ์งํ๋ฅผ ์ค๊ณํ๊ณ ์
BLEU์์ ๊ฐ์ฅ ํฐ ์ฐจ์ด์ ์ผ๋ก precision๋ง ๊ณ ๋ คํ๋ BLEU์๋ ๋ฌ๋ฆฌ
- recall๋ ํจ๊ป ๊ณ ๋ คํ๋ค๋ ๊ฒ์ด๋ค.
- ์ถ๊ฐ์ ์ผ๋ก ๋ค๋ฅธ ๊ฐ์ค์น๋ฅผ ์ ์ฉํ ์ด ๋์ ์กฐํํ๊ท ์ ์ฑ๋ฅ ๊ณ์ฐ์ ํ์ฉํ๊ณ ,
- ์ค๋ต์ ๋ํด ๋ณ๋์ penalty๋ฅผ ๋ถ๊ณผํ๋ ๋ฐฉ์์ ์ฑํํ๊ฑฐ๋,
- ์ฌ๋ฌ ๋จ์ด๋ ๊ตฌ ๋ฑ์ ์ ๋ต์ผ๋ก ์ฒ๋ฆฌํ๋ ๋ฑ BLEU๋ฅผ ๋ณด์ํ๊ณ ์ ํ๋ค.
METEOR์ ์ฃผ์ ํน์ง๊ณผ ์๋ ๋ฐฉ์์ ๋ค์๊ณผ ๊ฐ๋ค:
โ ํญ๋ชฉ ์ ํ๋ (Precision): METEOR๋ ๋จ์ด๋ ๊ตฌ์ ์์ค์ ์ผ์น๋ฅผ ์ธก์ ํ๋ค.
๋ฒ์ญ ํ๋ณด์ ์ฐธ์กฐ ๋ฒ์ญ ๊ฐ์ ๊ณตํต๋ ํ ํฐ (๋จ์ด ๋๋ ๊ตฌ์ ) ์๋ฅผ ๊ณ์ฐํ๋ค.
์ด๊ฒ์ "ํญ๋ชฉ ์ ํ๋" ๋๋ "Precision"์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค.
์์์ ์๋์ ๊ฐ๋ค.
โ ๊ธฐํ ํ๊ท F1 ์ ์: METEOR๋ ์ ํ๋์ ๋ฆฌ์ฝ(recall) ๊ฐ์ ๊ท ํ์ ์ธก์ ํ๊ธฐ ์ํด ์ ํ๋์ ๋ฆฌ์ฝ์ ์กฐํ ํ๊ท ์ธ F1 ์ ์๋ฅผ ๊ณ์ฐํฉ๋๋ค. ์ด๊ฒ์ ๋ฒ์ญ ํ๋ณด์ ์ ํ์ฑ๊ณผ ์ฐธ์กฐ ๋ฒ์ญ๊ณผ์ ์ ์ฌ์ฑ์ ๋ชจ๋ ๊ณ ๋ คํฉ๋๋ค.
์์์ ์๋์ ๊ฐ๋ค.
penalty์ ๊ฒฝ์ฐ, ์๋์ ๊ฐ์ด ๊ณ์ฐ๋๋ฉฐ
์ต์ข ์ ์ธ METEOR Score๋ ์๋์ ๊ฐ๋ค.
โ ๋ฒ์ญ ํ๋ณด์ ์ฐธ์กฐ ๋ฒ์ญ ์ฌ์ด์ ์ดํ์ ๊ตฌ์กฐ์ ๋ณํ: METEOR์ ๋จ์ด์ ์์์ ๊ตฌ์กฐ์ ๋ณ๊ฒฝ์ ํฌํจํ ๋ค์ํ ์ดํ์ ๋ฌธ๋ฒ์ ์ธ ๋ณํ๋ฅผ ๊ณ ๋ คํฉ๋๋ค. ์ด๊ฒ์ ๋จ์ด ์์๋ฅผ ๋ณด๋ค ๊ฐ์กฐํ๋ ํน์ง์ด ์์ผ๋ฉฐ, ๋ฒ์ญ ํ๋ณด์ ์ฐธ์กฐ ๋ฒ์ญ ๊ฐ์ ๊ณตํต ์ดํ ๋ฐ ๊ตฌ์กฐ๋ฅผ ๋น๊ตํฉ๋๋ค.
โ ์ดํ, ๊ตฌ์ ์ ๋ ฌ ๋ฐ ๋์์ด ์ฒ๋ฆฌ: METEOR์ ์ดํ์ ๋์์ด ์ฒ๋ฆฌ๋ฅผ ์ํํ๋ฉฐ, ๊ตฌ์ ์ ๋ ฌ๊ณผ ์ ์ฌ์ฑ์ ๊ณ์ฐํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
METEOR์ ๊ฒฐ๊ณผ๋ 0๊ณผ 1 ์ฌ์ด์ ์ ์๋ก ๋ํ๋ฉ๋๋ค. ๋์ METEOR ์ ์๋ ๋ฒ์ญ์ด ์ฐธ์กฐ ๋ฒ์ญ๊ณผ ์ ์ฌํ๋ค๋ ๊ฒ์ ๋ํ๋ด๋ฉฐ, ๋์ ๋ฒ์ญ ํ์ง์ ์๋ฏธํฉ๋๋ค. METEOR๋ ์ฃผ๋ก ๋ฒ์ญ ํ์ง ํ๊ฐ๋ฅผ ์ํด ์ฌ์ฉ๋๋ฉฐ, ๋ค์ํ ์์ฐ์ด ์ฒ๋ฆฌ ์์ ์์๋ ์ ์ฉ๋ ์ ์์ต๋๋ค.
ROUGE (Recall-Oriented Understudy for Gisting Evaluation)
ROUGE๋ NLP์์ ์ฌ์ฉ๋๋ ์๋ ํ๊ฐ์งํ ์ค ํ๋๋ก, ํ ์คํธ ์์ฑ ๋ฐ ์๋ ์์ฝ ์์ ์์ ๋ง์ด ์ฌ์ฉ๋๋ค.
ROUGE๋ ์์ฑ๋ ํ ์คํธ ๋๋ ์์ฝ๊ณผ ๊ธฐ์ค(reference) ํ ์คํธ ๊ฐ์ ์ ์ฌ์ฑ์ ์ธก์ ํ๊ณ ํ๊ฐํ๋ ๋ฐ ์ฌ์ฉ๋๋ค.
ROUGE๋ 5๊ฐ์ ํ๊ฐ ์งํ๊ฐ ์๋ค.
- ROUGE-N
- ROUGE-L
- ROUGE-W
- ROUGE-S
- ROUGE-SU
โ ROUGE-N (ROUGE-Ngram)
ROUGE-N ๋ฉํธ๋ฆญ์ N-gram (์ฐ์๋ n๊ฐ์ ๋จ์ด ๋๋ ๋ฌธ์) ์ผ์น๋ฅผ ์ธก์ ํ๋ค.
์ผ๋ฐ์ ์ผ๋ก ROUGE-1, ROUGE-2, ROUGE-3 ๋ฑ๊ณผ ๊ฐ์ด ์ง์ ๋ n-gram ๊ธธ์ด๋ฅผ ๋ํ๋ธ๋ค.
์ฆ, ROUGE-N์ ์์ธกํ ์์ฝ๋ฌธ๊ณผ ์ค์ ์์ฝ๋ฌธ๊ฐ์ N-gram์ Recall๊ฐ์ผ๋ก ์ฝ๊ฒ ๋ํ๋ด๋ฉด ์๋์ ๊ฐ๋ค.
โฃ ROUGE-1: 1-gram (unigram) = ๋จ์ด ๋จ์์ ์ผ์น๋ฅผ ๊ณ์ฐ
โฃ ROUGE-2: 2-gram (bigram) ๋จ์์ ์ผ์น๋ฅผ ๊ณ์ฐ
โฃ ROUGE-3: 3-gram (trigram) ๋จ์์ ์ผ์น๋ฅผ ๊ณ์ฐ
ROUGE-N์ ์ ์ฌ์ฑ์ ์ธก์ ํ๊ณ ๋จ์ด ์์๋ฅผ ๊ณ ๋ คํ์ง ์๋๋ค.
ex) ROUGE-1
โ์ค์ ์์ฝ๋ฌธ uni-gram: 'Korea', 'won', 'the', 'world', 'cup'
โ์์ธก ์์ฝ๋ฌธ uni-gram: 'Korea', 'won', 'the', 'soccer', 'world', 'cup', 'final'
์์ธก ์์ฝ๋ฌธ๊ณผ ์ค์ ์์ฝ๋ฌธ ์ฌ์ด์ ๊ฒน์น๋ uni-gram ์๋ 5์ด๊ณ ,
์ค์ ์์ฝ๋ฌธ์ uni-gram์๋ 5์ด๋ฏ๋ก
โROUGE-1 = 5/5 = 1
ex) ROUGE-2
โ์ค์ ์์ฝ๋ฌธ bi-gram: 'Korea won', 'won the', 'the world', 'world cup'
โ์์ธก ์์ฝ๋ฌธ bi-gram: 'Korea won', 'won the', 'the soccer', 'soccer world', 'world cup', 'cup final'
์์ธก ์์ฝ๋ฌธ๊ณผ ์ค์ ์์ฝ๋ฌธ ์ฌ์ด์ ๊ฒน์น๋ bi-gram ์๋ 3์ด๊ณ ,
์ค์ ์์ฝ๋ฌธ์ bi-gram์๋ 4์ด๋ฏ๋ก
โROUGE-2 = 3/4 = 0.75