📌 Table of Contents

1. Preview
2. Reinforcement Learning Basics

3. Policy-based RL
4. Applying RL to NLG
5. Supervised Learning with RL
6. Unsupervised Learning with RL

😚 Closing Remarks...

Reinforcement learning is a vast field, so here we focus on one part of it — policy gradients — and how they can be used to improve natural language generation (NLG).

We first walk step by step through what reinforcement learning is and why NLG needs it.

1. Preview

1.1 GAN (Generative Adversarial Network)
Generative adversarial networks (GANs) drew attention from 2016 and became one of the hottest topics in vision in 2017; together with the variational autoencoder (VAE), they are one of the representative approaches to training generative models. (These days diffusion models such as Stable Diffusion have largely taken over that spotlight in vision.)
Generated images are expected to be a big help for problems that matter in practice but for which training data is hard to obtain.
As the figure above shows, a GAN trains two models simultaneously, each with its own objective: a generator G and a discriminator D.
The two models play a min/max game against each other, and once they reach equilibrium, G is able to generate convincing images.

 

1.2 Applying GAN to Natural Language Generation
Let us try applying a GAN to NLG.
For example, instead of training directly with cross-entropy (CE), we could place a discriminator D that decides

 - whether a sentence came from the real corpus, or
 - whether it came from the seq2seq model,
and train the seq2seq model so that its generated sentences become indistinguishable from real ones.



Unfortunately, this appealing idea cannot be applied as-is,
because the output of a seq2seq model is a discrete probability distribution.
The result obtained from it by sampling or argmax is a discrete value, which has to be represented as a one-hot vector.
But sampling is a stochastic process and argmax is non-differentiable (its gradient is zero or undefined), so whether or not D was fooled cannot be propagated back into the seq2seq generator G — and training is therefore impossible.
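A minimal PyTorch sketch of where the gradient path breaks (the vocabulary size and tensor shapes here are illustrative assumptions, not from the original post):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 10, requires_grad=True)  # pretend seq2seq output over a 10-word vocab
probs = F.softmax(logits, dim=-1)                # still continuous and differentiable

token = probs.argmax(dim=-1)                     # discrete selection
print(token.requires_grad)                       # False: the gradient path is cut here

one_hot = F.one_hot(token, num_classes=10).float()
# There is no backward path from one_hot to logits, so a discriminator's feedback on this
# one-hot representation can never reach the generator's parameters.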

 

1.3 GAN and Natural Language Generation
As discussed in 1.1 and 1.2, GANs were a huge success in vision but are hard to apply to NLG, and the reason lies in the nature of language itself.
∙ An image is a matrix filled with continuous values.
∙ Language is a sequential arrangement of discrete values, which we handle by mapping them to continuous variables in a latent space through a language model.
In the end, representing language externally requires discrete probability distributions and discrete variables,
and representing it as an actual sample rather than a distribution requires sampling from that discrete distribution.
For this reason D's loss cannot be passed to G, and the prevailing view became that adversarial training cannot be applied to NLG.
But!! Reinforcement learning gives us a way to use adversarial-style training indirectly.

 

1.4 Why Reinforcement Learning
Because many complex objectives cannot be written down as the cross-entropy used for classification or the MSE used for continuous variables.

In other words, when we solved problems with CE or MSE, we were in fact simplifying them.
Reinforcement learning lets us tackle such problems directly and push performance further.

With a well-designed reward, we can solve more complex and more refined problems.


2. Reinforcement Learning Basics

Reinforcement learning has long existed as a branch of machine learning; it is a vast field with a long history, far too large to cover in a single post.

 

๋”ฐ๋ผ์„œ ์ด๋ฒˆ์‹œ๊ฐ„์— ๋‹ค๋ฃฐ, Policy Gradient์— ๋Œ€ํ•œ ์ดํ•ด๋ฅผ ์œ„ํ•ด ํ•„์š”ํ•œ ์ •๋„๋งŒ ์ดํ•ดํ•˜๊ณ  ๋„˜์–ด๊ฐ€๋ณด๋ ค ํ•œ๋‹ค.

์ถ”๊ฐ€์ ์ธ Reinforcement Learning์— ๋Œ€ํ•œ ๊ธฐํƒ€ ์ถ”์ฒœ์€ ์•„๋ž˜ ๋งํฌ ์ฐธ๊ณ .

⌈Reinforcement Learning: An Introduction ; (MIT Press, 2018)⌋

2.1 Universe
Let us first look at how reinforcement learning operates.

Q: What is reinforcement learning??
A: A method by which an agent learns how to act, depending on the situation, in a given environment.
How reinforcement learning operates
Starting from the initial state S_0 (t = 0),
∙ the Agent selects an action A_t according to its own policy;
∙ the Environment receives the selected A_t from the Agent
and returns the reward R_{t+1} and the newly changed state S_{t+1}.

This loop repeats until some terminal condition is met, at which point the environment ends the sequence.
One such sequence is called an episode.

Goal: over repeated episodes, use RL to train the agent to act appropriately (so that its reward is maximized).

 

2.2 MDP (Markov Decision Process)
On top of this we introduce the concept of a Markov decision process (MDP).
We can define the present moment of the whole world, T = t, as a single state.

Assumption) given the present state, the future (T > t) is "independent" of the past (T < t).
Under this assumption, the world can be said to operate as a Markov process.
The probability of moving from the current situation to a future situation can then be expressed as P(S' | S).

An MDP adds to this a decision-making step, i.e., the process of choosing an action.
That is, the probability of moving to a future situation when a certain action is chosen in the current situation is written P(S' | S, A).


A simple illustration is the rock-paper-scissors game below.
Every opponent has their own pattern, so the strategy should change depending on whom we play against:
∙ the opponent determines the initial state S0;
∙ the Agent chooses an action A0 based on S0;
∙ the opponent plays rock/paper/scissors according to their own policy;
∙ the outcome is decided, and the environment returns a reward R1 for the action just taken;
∙ we obtain the updated state S1.

 

2.3 Reward
As described above, when the Agent takes an action it receives a "reward" from the environment.
Let us define Gt as the cumulative sum of rewards received from some point in time onward.
Gt is then defined as below.

Introducing a discount factor γ (a value between 0 and 1) modifies the formula slightly.
With γ, rewards in the near future are weighted more heavily than rewards in the far future.
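Written out (the standard Sutton & Barto form, reconstructed here because the original equation images are missing):

G_t = R_{t+1} + R_{t+2} + ... + R_T
G_t = R_{t+1} + γ R_{t+2} + γ² R_{t+3} + ... = Σ_{k=0}^{∞} γᵏ R_{t+k+1},  with 0 ≤ γ ≤ 1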

 

2.4 Policy
The Agent should act so as to maximize the cumulative sum of the rewards it will receive from the given state.
That is, what matters is not avoiding a small immediate loss but maximizing the total reward including the far future.
(😂 Much like studying during exam season instead of going out...?)

A policy is the Agent's criterion, "expressed probabilistically", for how to act in each situation.
In other words, it is a probability function over which action to choose in a given situation, so the process by which we act is itself a stochastic process. The probability of taking a given action can be computed with the function below.
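In symbols (the standard definition, added because the original equation image is missing):

π(a | s) = P(A_t = a | S_t = s)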

 

2.5 Value Function
The value function is the expected cumulative sum of the rewards obtainable from a particular state s onward, under a given policy π.
It can be written as the equation below; this expected sum of future rewards is also called the expected cumulative reward.
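In symbols (standard definition, reconstructed since the equation image is missing):

V_π(s) = E_π[ G_t | S_t = s ]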
Action-value function (Q-function)

The action-value function (Q-function) expresses the expected cumulative reward obtainable when action a is chosen in state s under a given policy π.

If the value function is the expected cumulative reward obtainable from a state s regardless of which action is chosen,
the Q-function adds the notion of which action a is taken.
That is, the Q-function, the expected cumulative reward as a function of state and action, is written as below.
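Reconstructed in the same notation (the original image is missing):

Q_π(s, a) = E_π[ G_t | S_t = s, A_t = a ]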

 

2.6 Bellman Equation
What if we want to define the optimal value function and the optimal Q-function, following the definitions above?
→ They can be written as the Bellman equations below.
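In terms of the transition probability P(s', r | s, a) that section 2.7 refers back to, the optimal value functions satisfy (standard form, reconstructed here):

V*(s) = max_a Σ_{s', r} P(s', r | s, a) [ r + γ V*(s') ]
Q*(s, a) = Σ_{s', r} P(s', r | s, a) [ r + γ max_{a'} Q*(s', a') ]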
Left) Back-Tracking / Right) Dynamic Programming — with back-tracking, every possible case has to be explored.

DP: a technique that decomposes a problem into overlapping sub-problems with optimal substructure. Solving the large problem reduces to solving small sub-problems and combining their results; the key point is that instead of recomputing the same sub-problem many times, each result is computed once, stored and reused, which cuts the computational cost. This is the core idea of DP, and it can be applied effectively whenever overlapping sub-problems exist.

The Bellman equation can be approached as a dynamic programming (DP) problem.
It is much like a shortest-path problem: as the figure below shows, it is far more efficient and faster than back-tracking, which explores every possible case.

 

2.7 Monte Carlo Method
If the problem could be solved that easily, however, reinforcement learning would not be needed.
Prob) in most cases we do not know the most important term of the equation in 2.6, P(s', r | s, a).
That is, we only find out with what probability a given state and action lead to a next state s' and reward r by "actually trying it".
∴ In most cases DP cannot be applied.


The Agent must therefore be trained through experience such as simulation, which is what RL does.
For this reason Bellman's expectation equation can be solved via sampling, as in the Monte Carlo method.
Prob) sampling-based learning has its own problem with long episodes.
Gt can only be computed once the episode actually ends, so we must wait until the end of each episode.
For very long episodes — think of the familiar AlphaGo — this takes an enormous amount of time.

 

2.8 TD Learning (Temporal Difference Learning)
This is where temporal-difference (TD) learning is useful.
TD learning can update the value function directly, without the episode's cumulative reward, as in the equation below.
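A typical TD(0) update looks like this (standard notation; α is a step size):

V(S_t) ← V(S_t) + α [ R_{t+1} + γ V(S_{t+1}) − V(S_t) ]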
Q-Learning
If we knew the correct Q-function, we could always make very good choices that maximize the expected cumulative reward, whatever the situation.
Learning such a Q-function well is called Q-learning.

Reducing the difference between the target and the current value estimate, as in the equation below, lets us learn the correct Q-function.
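In standard form — the bracketed term is exactly the target minus the current estimate:

Q(S_t, A_t) ← Q(S_t, A_t) + α [ R_{t+1} + γ max_a Q(S_{t+1}, a) − Q(S_t, A_t) ]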
DQN (Deep Q-Network)
When learning the Q-function, a problem arises if the state space and action space are so large that states and actions become sparse:
because of this sparsity, most of them are rarely seen during training.
Even though states and actions are discrete, separate values, approximating the Q-function solves this problem.

DeepMind presented a reinforcement learning method that plays Atari games remarkably well using Q-learning with a neural-network approximation; this is called DQN (Deep Q-Network).
Overview of Q-learning with a neural network
As the equation below shows, approximating the Q-function part with a neural network solved the sparsity problem — and even produced agents that play Atari games better than humans.
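The approximated objective is usually written along these lines (a sketch of the standard DQN loss; θ⁻ denotes a periodically frozen copy of the network, a detail not covered in this post):

L(θ) = E[ ( r + γ max_{a'} Q(s', a'; θ⁻) − Q(s, a; θ) )² ]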


3. Policy-based Reinforcement Learning

3.1 Policy Gradient
Policy gradient belongs to the family of policy-based reinforcement learning.
cf) DQN belongs to value-based reinforcement learning — e.g., the method DeepMind used for Atari games.

The difference between policy-based and value-based reinforcement learning is as follows.
∙ Value-based learning: train a neural network to predict the reward obtainable from choosing each action.
∙ Policy-based learning: deliver the reward for the network's actions through the backpropagation algorithm to train it.
∴ With DQN, action selection does not come out stochastically,
whereas policy gradient goes through a stochastic process when choosing an action.


The policy-gradient formulation is written as below.

In this formula, π denotes the policy.
That is, the network with weights θ returns the probability of choosing each action a given the current state s.


Our goal is to find a policy θ that maximizes the expected cumulative reward from the initial state,
and since the reward must be maximized — unlike a loss, which must be minimized —
we use gradient ascent instead of the usual gradient descent.


Following gradient ascent, we need to compute ∇θJ(θ) and update θ; the objective J(θ) is written as below.
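One common way to write this (reconstructed in the notation used just below):

J(θ) = E_{π_θ}[ r ] = Σ_{s∈S} d(s) Σ_{a∈A} π_θ(a | s) · R_{s,a}
θ ← θ + γ ∇_θ J(θ)   (gradient ascent)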
Here d(s) is the stationary distribution of the Markov chain, i.e., the fraction of time spent in state s over the whole trajectory, regardless of the starting point.

Using the properties of the log-derivative, ∇θJ(θ) can be obtained as in the equation below.
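Reconstructed (the original equation image is missing), using ∇π = π ∇log π:

∇_θ J(θ) = Σ_{s∈S} d(s) Σ_{a∈A} π_θ(a | s) ∇_θ log π_θ(a | s) · R_{s,a} = E_{π_θ}[ ∇_θ log π_θ(a | s) · r ]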
This equation reads as follows:
∙ at each time-step, it is the expectation of the gradient of the log-probability of choosing action a given state s, multiplied by the resulting reward.

According to the policy gradient theorem,
instead of the immediate reward r at that time-step we may use the expected cumulative reward until the end of the episode.
That is, we may use the Q-function — and this is where policy gradients really shine.
We do need to differentiate through the policy network itself,
but "the Q-function never needs to be differentiated!!"
In other words, regardless of differentiability, any arbitrary function can be used as the reward function!!


Since any function can now serve as the reward, instead of fitting with a loss such as cross-entropy or MSE, we can train θ with a function that better matches what we actually care about in practice (e.g., BLEU for translation).

Furthermore, replacing the expectation in the formula above with Monte Carlo sampling lets us update the network parameters θ as below.
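A sketch of the resulting update, following the text below in using γ as the learning rate:

θ ← θ + γ · Q^{π_θ}(s_t, a_t) · ∇_θ log π_θ(a_t | s_t)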
Let us unpack this formula a bit more.
∙ log πθ(at | st): the log-probability that the action sampled from the distribution given by policy parameters θ in state st is at.
Differentiating this value with respect to θ gives ∇θ log πθ(at | st).
Gradient ascent along that gradient therefore maximizes log πθ(at | st),
∙ i.e., it raises the probability of at so that the same action is chosen more often in the same state from now on.


Because the gradient ∇θ log πθ(at | st) is multiplied by the reward, if the sampled actions received a large reward,
the extra factor on top of the learning rate γ makes the gradient-ascent step correspondingly larger.

If a negative reward is received instead, the multiplied value flips the step to the opposite direction, which has the same effect as performing gradient descent:
the network parameters θ are updated so that those sampled actions are less likely to appear in the future.
The net effect is to find parameters θ that maximize the probability of the actions that actually maximize the reward.

One caveat: an ordinary gradient carries both a direction and a magnitude,
whereas policy gradient only multiplies the existing gradient direction by a scalar magnitude.
It cannot point directly toward higher reward, so in practice training is difficult and inefficient.

 

3.2 MLE vs. Policy Gradient
Let us understand policy gradients better by comparing them with maximum likelihood estimation (MLE) in the following example.

∙ We take as input data made up of a sequence of length n,
∙ and the goal is to approximate a function that outputs data made up of a sequence of length m.

The sequences x1:n and y1:m then exist in a dataset B.


Goal) find neural-network parameters θ that approximate the true function f: x → y.


We now need to learn the parameters θ to approximate this function.
θ can be obtained via MLE, as below.

To obtain a θ that explains the relationships in dataset B well, we define the objective function as below.
The formula below defines cross-entropy as the objective.
Goal) minimize the value of this loss function.
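In symbols (a sketch in this section's notation):

θ̂ = argmax_θ Σ_{(x,y)∈B} log P(y | x; θ)
J(θ) = − Σ_{(x,y)∈B} Σ_{t=1}^{m} log P(y_t | x_{1:n}, y_{<t}; θ)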

Since the objective defined above must be minimized, we can approximate θ with an optimizer
(the formula below uses gradient descent as the optimizer),
where the learning rate γ controls the size of each update.
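That is (reconstructed, since the equation image is missing):

θ ← θ − γ ∇_θ J(θ)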

๋‹ค์Œ ์ˆ˜์‹์€ Policy Gradient์— ๊ธฐ๋ฐ˜ํ•ด ๋ˆ„์ ๊ธฐ๋Œ€๋ณด์ƒ์„ ์ตœ๋Œ€๋กœํ•˜๋Š” ๊ฒฝ์‚ฌ์ƒ์Šน๋ฒ• ์ˆ˜์‹์ด๋‹ค.


In this equation, just as γ acted as the learning rate in the earlier MLE gradient-descent update, Qπθ(st, at) is additionally multiplied in front of the gradient and plays a learning-rate-like role.
Depending on the size of the reward, the corresponding action is therefore reinforced more strongly, or negated by moving in the opposite direction.
In other words, the learning rate is adjusted dynamically according to the outcome.

Policy gradient: the gradient is computed in the direction that maximizes the probability of the sampled actions.

 

3.3 The REINFORCE Algorithm with a Baseline
How does the policy gradient described above behave when the rewards are always positive?

ex) Suppose exam scores range from 0 to 100 and, being roughly normally distributed, most scores fall near the mean.
Most students therefore receive a positive reward.
The vanilla policy gradient above would then always see a positive reward and keep encouraging the current policy for every Agent (student).
But! when the average score is 50, a score of 10 is a very bad score in relative terms, so once the baseline (the average) is subtracted, learning moves in the opposite direction from the current policy.

In other words, every situation comes with a reward level that is merely "par for the course", and we can use it to evaluate how good the current policy really is.
This can be expressed in the policy-gradient formula below.
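A common way to write it (a sketch; b is the baseline, e.g., the average reward observed for that situation — the MRT code later uses the mean reward of extra samples for exactly this role):

θ ← θ + γ ( G_t − b(s_t) ) · ∇_θ log π_θ(a_t | s_t)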
์ด์ฒ˜๋Ÿผ Reinforce ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ baseline์„ ๊ณ ๋ คํ•ด ์ข€ ๋” ์•ˆ์ •์  ๊ฐ•ํ™”ํ•™์Šต์ˆ˜ํ–‰์ด ๊ฐ€๋Šฅํ•˜๋‹ค.


4. Applying Reinforcement Learning to Natural Language Generation

Reinforcement learning is defined and operates on a Markov decision process (MDP).
A series of decisions (actions) moves the agent through a series of situations (transitions), which together form an episode, and rewards are given according to the chosen actions and states.
These accumulate, and when the episode ends we obtain the cumulative reward.

This maps onto NLP tasks that predict sequential data — natural language generation (NLG) — rather than text classification.

∙ the word sequence generated so far = the current state;
∙ the new word chosen, based on the words generated so far, = the action.
The process of generating one sentence — choosing tokens from (BOS) to (EOS) — constitutes one episode.

∙ Repeating episodes accumulates sentence-generation experience → comparison against the reference → θ is trained to maximize the expected cumulative reward.

 
What if we apply RL to NMT?
Let us plug RL into NMT concretely.
∙ Current state = the given source sentence plus the sequence of words translated so far.
∙ Choosing an action = choosing a new word based on the current state.
∙ Once the action at the current time-step is chosen → the next time-step's state is the source sentence plus the previously generated words, with the word chosen at the current time-step appended.
❗️The important point:
after choosing an action, no immediate reward comes from the environment;
only when every word has been chosen, EOS is finally selected, decoding stops and the episode ends
can the BLEU score be computed and received as the cumulative reward.
That is, the reward received at termination = the episode's cumulative reward.
Training a model with reinforcement learning alone from the very start is inefficient and difficult,
so RL is usually applied to a network θ that has already been trained to a reasonable level with conventional MLE.
Reinforcement learning then finds potentially better policies through exploration and develops them through exploitation.

 

4.1 Characteristics of Reinforcement Learning in NLG
We briefly looked at policy gradients, the policy-based branch of RL, above.
Beyond what was described, many more advanced policy-gradient methods exist,
ex) Actor-Critic, A3C, ...
∙ Actor-Critic: keeps a separate value network W alongside the policy network θ and learns with TD updates, without waiting for the episode to end.
∙ A3C (Asynchronous Advantage Actor-Critic): a further development of Actor-Critic that fixes its remaining shortcomings.

In NLP, however, there is no real need for these more elaborate methods: the simple REINFORCE algorithm works well enough, thanks to the characteristics of the domain. Applying reinforcement learning to NLP has the following properties.
1. There are very many possible actions at.
∙ Choosing the next word = choosing an action.
∙ The size of the set of possible actions = the size of the vocabulary.
∴ That set usually contains tens of thousands of entries.



2. There are very many possible states.
If choosing a word is the action,
∙ the sequence of words chosen so far = the state;
∙ after many time-steps and many chosen actions (words), the number of possible states becomes enormous.

3. It is therefore essentially impossible to encounter every action and every state during training.
∙ Naturally, unseen samples will appear at inference time.
∙ This sparsity could be a serious obstacle, but deep neural networks handle it well.



4. Some aspects also make RL easier to apply in NLP.
∙ In most cases, generating one sentence = one episode (and sentences are usually under 100 words),
→ which makes the problem much easier than reinforcement learning in other domains.

ex) For DeepMind's AlphaGo or StarCraft agents, a single episode takes a very long time to finish.
→ For the actions chosen within an episode to update the policy, one must wait for that very long episode to end.
→ Worse, it is very hard to tell how much an action taken ten minutes ago contributed to winning or losing the game.
❗️The fact that NLG episodes are short compared with other domains is therefore a huge advantage and makes the policy network much easier to train.

5. On the other hand, with sentence-level episodes it is usually hard to obtain rewards in the middle of an episode.
ex) In translation, no immediate reward is available as each word is chosen at each time-step; only after the whole translation is finished is the BLEU score between the completed sentence and the reference used as the cumulative reward.
As before, this would be a serious problem if episodes were very long, but fortunately it matters little for sentence-level episodes.

 

4.2 Advantages of Applying RL
Fixing the problem caused by teacher forcing
Autoregressive models such as seq2seq are trained with teacher forcing.
This creates a mismatch between training and inference: the model has to be trained in a way that differs from how it actually decodes.

❗️With RL, however, the model is trained through sampling, just as it decodes at inference time,
→ so the difference between "train and inference" disappears.
Using a more accurate objective function
BLEU reflects translation quality better than PPL does. [Gain_NLP_07]
However, such metrics are mostly impossible to train with via differentiation, so they are hard to use for training neural networks.

❗️But by applying RL's policy gradient, the reward function no longer needs to be differentiated, and such accurate metrics become usable.


5. Supervised Learning with RL

If only we could use BLEU as the objective during training, we would surely get better results...

And it would be nice if a similar approach worked for other NLG problems as well...

 

5.1 MRT (Minimum Risk Training)
Starting from exactly this wish, the Minimum Risk Training (MRT) paper was proposed.
Remarkably, the authors did not use policy gradients directly, yet a very similar formula was derived.



Conventional maximum likelihood estimation (MLE) uses a loss function like the one above: compute the loss over the |S| input-output pairs and find the θ that minimizes it.
This paper instead defines the risk as below and proposes MRT, the training scheme that minimizes it.
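In the paper's notation the risk is (reconstructed from Shen et al., 2016):

R(θ) = Σ_{s=1}^{S} E_{y | x^(s); θ}[ Δ(y, y^(s)) ] = Σ_{s=1}^{S} Σ_{y ∈ 𝒴(x^(s))} P(y | x^(s); θ) · Δ(y, y^(s))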

์œ„์˜ ์ˆ˜์‹์—์„œ y(x(s))๋Š” search scape(ํƒ์ƒ‰๊ณต๊ฐ„)์˜ ์ „์ฒด ์ง‘ํ•ฉ์ด๋‹ค.
์ด๋Š” S๋ฒˆ์งธ x(s)๊ฐ€ ์ฃผ์–ด์กŒ์„ ๋•Œ ๊ฐ€๋Šฅํ•œ ์ •๋‹ต์ง‘ํ•ฉ์„ ์˜๋ฏธํ•œ๋‹ค.

๋˜ํ•œ, Δ(y, y(s))๋Š” ์ž…๋ ฅ ํŒŒ๋ผ๋ฏธํ„ฐ θ๊ฐ€ ์ฃผ์–ด์กŒ์„ ๋•Œ, samplingํ•œ y์™€ ์‹ค์ œ ์ •๋‹ต y(s)์˜ ์ฐจ์ด(= error)๊ฐ’์„ ์˜๋ฏธํ•œ๋‹ค.

์ฆ‰, ์ด ์ˆ˜์‹์— ๋”ฐ๋ฅด๋ฉด risk R์€ ์ฃผ์–ด์ง„ ์ž…๋ ฅ๊ณผ ํ˜„์žฌ ํŒŒ๋ผ๋ฏธํ„ฐ ์ƒ์—์„œ ์–ป์€ y๋ฅผ ํ†ตํ•ด ํ˜„์žฌ ๋ชจ๋ธํ•จ์ˆ˜๋ฅผ ๊ตฌํ•˜๊ณ , ๋™์‹œ์— ์ด๋ฅผ ์‚ฌ์šฉํ•ด Risk์˜ ๊ธฐ๋Œ€๊ฐ’์„ ๊ตฌํ•œ๋‹ค ๋ณผ ์ˆ˜ ์žˆ๋‹ค.

The goal is to minimize the risk defined this way.
Conversely, if we think of a reward instead of a risk, the goal is to maximize the reward.
Minimizing risk uses gradient descent, and maximizing reward uses gradient ascent,
so unrolling the formulas shows that the two are exactly the same thing.
In an actual implementation, therefore, the reward function BLEU is multiplied by -1 and used as Δ(y, y(s)), turning it into a risk function.

However, since the whole space of possible outputs for a given input cannot be searched,
the expectation is taken over a sampled sub-space of the full search space instead.



The formula above is then differentiated with respect to θ.
Now let us interpret the final MRT formula obtained through this differentiation.
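Applying the same log-derivative trick as in section 3.1 gives a gradient of this general form (a sketch; the paper itself works with a sampled, sharpened approximation of the expectation):

∇_θ R(θ) = Σ_{s=1}^{S} E_{y | x^(s); θ}[ Δ(y, y^(s)) · ∇_θ log P(y | x^(s); θ) ]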


Finally, the expectation in the formula can be removed via Monte Carlo sampling.
The formula below is the REINFORCE policy-gradient formula — compare it with the MRT formula above.


[The MRT formula, written in REINFORCE form]


MRT never frames itself as reinforcement learning, yet mathematically derives what is essentially the REINFORCE policy-gradient formula and uses it to improve performance — a very impressive approach.



[The REINFORCE policy-gradient formula]

 

PyTorch example: NMT with MRT

Implementation steps

1. For a given input sentence, sample a translation using the policy θ.
2. Compute BLEU between the sampled sentence and the reference, and multiply by -1 to turn it into a risk.
3. Multiply the whole log-probability distribution by the risk.
4. Multiply the sum of the NLL values computed per sample and per time-step by -1 (i.e., the sum of log-probabilities).
5. Differentiate that sum of log-probabilities with respect to θ; backpropagation yields gradients for the whole network θ.
6. Since the risk has already been multiplied into the distribution, these gradients are used directly for backpropagation and optimization.
from nltk.translate.gleu_score import sentence_gleu
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

import numpy as np

import torch
from torch import optim
from torch.nn import functional as F
import torch.nn.utils as torch_utils

from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar

import simple_nmt.data_loader as data_loader
from simple_nmt.trainer import MaximumLikelihoodEstimationEngine
from simple_nmt.utils import get_grad_norm, get_parameter_norm

VERBOSE_SILENT = 0
VERBOSE_EPOCH_WISE = 1
VERBOSE_BATCH_WISE = 2


class MinimumRiskTrainingEngine(MaximumLikelihoodEstimationEngine):

    @staticmethod
    def _get_reward(y_hat, y, n_gram=6, method='gleu'):
        # This method gets the reward based on the sampling result and reference sentence.
        # For now, we use GLEU from NLTK, but you can use your own well-defined reward function.
        # In addition, GLEU is a variation of BLEU that is better suited to reinforcement learning.
        sf = SmoothingFunction()
        score_func = {
            'gleu':  lambda ref, hyp: sentence_gleu([ref], hyp, max_len=n_gram),
            'bleu1': lambda ref, hyp: sentence_bleu([ref], hyp,
                                                    weights=[1./n_gram] * n_gram,
                                                    smoothing_function=sf.method1),
            'bleu2': lambda ref, hyp: sentence_bleu([ref], hyp,
                                                    weights=[1./n_gram] * n_gram,
                                                    smoothing_function=sf.method2),
            'bleu4': lambda ref, hyp: sentence_bleu([ref], hyp,
                                                    weights=[1./n_gram] * n_gram,
                                                    smoothing_function=sf.method4),
        }[method]

        # Since we don't calculate the reward score exactly the same way as multi-bleu.perl
        # (especially since the tokenization differs), I recommend setting n_gram to 6.

        # |y| = (batch_size, length1)
        # |y_hat| = (batch_size, length2)

        with torch.no_grad():
            scores = []

            for b in range(y.size(0)):
                ref, hyp = [], []
                for t in range(y.size(-1)):
                    ref += [str(int(y[b, t]))]
                    if y[b, t] == data_loader.EOS:
                        break

                for t in range(y_hat.size(-1)):
                    hyp += [str(int(y_hat[b, t]))]
                    if y_hat[b, t] == data_loader.EOS:
                        break
                # Below lines are slower than naive for loops in above.
                # ref = y[b].masked_select(y[b] != data_loader.PAD).tolist()
                # hyp = y_hat[b].masked_select(y_hat[b] != data_loader.PAD).tolist()

                scores += [score_func(ref, hyp) * 100.]
            scores = torch.FloatTensor(scores).to(y.device)
            # |scores| = (batch_size)

            return scores


    @staticmethod
    def _get_loss(y_hat, indice, reward=1):
        # |indice| = (batch_size, length)
        # |y_hat| = (batch_size, length, output_size)
        # |reward| = (batch_size,)
        batch_size = indice.size(0)
        output_size = y_hat.size(-1)

        '''
        # Memory inefficient but more readable version
        mask = indice == data_loader.PAD
        # |mask| = (batch_size, length)
        indice = F.one_hot(indice, num_classes=output_size).float()
        # |indice| = (batch_size, length, output_size)
        log_prob = (y_hat * indice).sum(dim=-1)
        # |log_prob| = (batch_size, length)
        log_prob.masked_fill_(mask, 0)
        log_prob = log_prob.sum(dim=-1)
        # |log_prob| = (batch_size, )
        '''

        # Memory efficient version
        log_prob = -F.nll_loss(
            y_hat.view(-1, output_size),
            indice.view(-1),
            ignore_index=data_loader.PAD,
            reduction='none'
        ).view(batch_size, -1).sum(dim=-1)

        loss = (log_prob * -reward).sum()
        # Following two equations are eventually same.
        # \theta = \theta - risk * \nabla_\theta \log{P}
        # \theta = \theta - -reward * \nabla_\theta \log{P}
        # where risk = -reward.

        return loss

    @staticmethod
    def train(engine, mini_batch):
        # You have to reset the gradients of all model parameters
        # before taking another step of gradient descent.
        engine.model.train()
        if engine.state.iteration % engine.config.iteration_per_update == 1 or \
            engine.config.iteration_per_update == 1:
            if engine.state.iteration > 1:
                engine.optimizer.zero_grad()

        device = next(engine.model.parameters()).device
        mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
        mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])

        # Raw target variable has both BOS and EOS token.
        # The output of sequence-to-sequence does not have BOS token.
        # Thus, remove BOS token for reference.
        x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
        # |x| = (batch_size, length)
        # |y| = (batch_size, length)

        # Take sampling process because set False for is_greedy.
        y_hat, indice = engine.model.search(
            x,
            is_greedy=False,
            max_length=engine.config.max_length
        )

        with torch.no_grad():
            # Based on the result of sampling, get reward.
            actor_reward = MinimumRiskTrainingEngine._get_reward(
                indice,
                y,
                n_gram=engine.config.rl_n_gram,
                method=engine.config.rl_reward,
            )
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            # |actor_reward| = (batch_size)

            # Take samples as many as n_samples, and get average rewards for them.
            # I figured out that n_samples = 1 would be enough.
            baseline = []

            for _ in range(engine.config.rl_n_samples):
                _, sampled_indice = engine.model.search(
                    x,
                    is_greedy=False,
                    max_length=engine.config.max_length,
                )
                baseline += [
                    MinimumRiskTrainingEngine._get_reward(
                        sampled_indice,
                        y,
                        n_gram=engine.config.rl_n_gram,
                        method=engine.config.rl_reward,
                    )
                ]

            baseline = torch.stack(baseline).mean(dim=0)
            # |baseline| = (n_samples, batch_size) --> (batch_size)

            # Now, we have relatively expected cumulative reward.
            # Which score can be drawn from actor_reward subtracted by baseline.
            reward = actor_reward - baseline
            # |reward| = (batch_size)

        # calculate gradients with back-propagation
        loss = MinimumRiskTrainingEngine._get_loss(
            y_hat,
            indice,
            reward=reward
        )
        backward_target = loss.div(y.size(0)).div(engine.config.iteration_per_update)
        backward_target.backward()

        p_norm = float(get_parameter_norm(engine.model.parameters()))
        g_norm = float(get_grad_norm(engine.model.parameters()))

        if engine.state.iteration % engine.config.iteration_per_update == 0 and \
            engine.state.iteration > 0:
            # In order to avoid exploding gradients, we apply gradient clipping.
            torch_utils.clip_grad_norm_(
                engine.model.parameters(),
                engine.config.max_grad_norm,
            )
            # Take a step of gradient descent.
            engine.optimizer.step()

        return {
            'actor': float(actor_reward.mean()),
            'baseline': float(baseline.mean()),
            'reward': float(reward.mean()),
            '|param|': p_norm if not np.isnan(p_norm) and not np.isinf(p_norm) else 0.,
            '|g_param|': g_norm if not np.isnan(g_norm) and not np.isinf(g_norm) else 0.,
        }

    @staticmethod
    def validate(engine, mini_batch):
        engine.model.eval()

        with torch.no_grad():
            device = next(engine.model.parameters()).device
            mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
            mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])

            x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
            # |x| = (batch_size, length)
            # |y| = (batch_size, length)

            # feed-forward
            y_hat, indice = engine.model.search(
                x,
                is_greedy=True,
                max_length=engine.config.max_length,
            )
            # |y_hat| = (batch_size, length, output_size)
            # |indice| = (batch_size, length)
            reward = MinimumRiskTrainingEngine._get_reward(
                indice,
                y,
                n_gram=engine.config.rl_n_gram,
                method=engine.config.rl_reward,
            )

        return {
            'BLEU': float(reward.mean()),
        }

    @staticmethod
    def attach(
        train_engine,
        validation_engine,
        training_metric_names = ['actor', 'baseline', 'reward', '|param|', '|g_param|'],
        validation_metric_names = ['BLEU', ],
        verbose=VERBOSE_BATCH_WISE
    ):
        # Attaching would be repeated for several metrics.
        # Thus, we can reduce the repeated code by using this function.
        def attach_running_average(engine, metric_name):
            RunningAverage(output_transform=lambda x: x[metric_name]).attach(
                engine,
                metric_name,
            )

        for metric_name in training_metric_names:
            attach_running_average(train_engine, metric_name)

        if verbose >= VERBOSE_BATCH_WISE:
            pbar = ProgressBar(bar_format=None, ncols=120)
            pbar.attach(train_engine, training_metric_names)

        if verbose >= VERBOSE_EPOCH_WISE:
            @train_engine.on(Events.EPOCH_COMPLETED)
            def print_train_logs(engine):
                avg_p_norm = engine.state.metrics['|param|']
                avg_g_norm = engine.state.metrics['|g_param|']
                avg_reward = engine.state.metrics['actor']

                print('Epoch {} - |param|={:.2e} |g_param|={:.2e} BLEU={:.2f}'.format(
                    engine.state.epoch,
                    avg_p_norm,
                    avg_g_norm,
                    avg_reward,
                ))

        for metric_name in validation_metric_names:
            attach_running_average(validation_engine, metric_name)

        if verbose >= VERBOSE_BATCH_WISE:
            pbar = ProgressBar(bar_format=None, ncols=120)
            pbar.attach(validation_engine, validation_metric_names)

        if verbose >= VERBOSE_EPOCH_WISE:
            @validation_engine.on(Events.EPOCH_COMPLETED)
            def print_valid_logs(engine):
                avg_bleu = engine.state.metrics['BLEU']
                print('Validation - BLEU={:.2f} best_BLEU={:.2f}'.format(
                    avg_bleu,
                    -engine.best_loss,
                ))

    @staticmethod
    def resume_training(engine, resume_epoch):
        resume_epoch = max(1, resume_epoch - engine.config.n_epochs)
        engine.state.iteration = (resume_epoch - 1) * len(engine.state.dataloader)
        engine.state.epoch = (resume_epoch - 1)

    @staticmethod
    def check_best(engine):
        loss = -float(engine.state.metrics['BLEU'])
        if loss <= engine.best_loss:
            engine.best_loss = loss

    @staticmethod
    def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
        avg_train_bleu = train_engine.state.metrics['actor']
        avg_valid_bleu = engine.state.metrics['BLEU']

        # Set a filename for model of last epoch.
        # We need to put every information to filename, as much as possible.
        model_fn = config.model_fn.split('.')

        model_fn = model_fn[:-1] + ['mrt',
                                    '%02d' % train_engine.state.epoch,
                                    '%.2f-%.2f' % (avg_train_bleu,
                                                   avg_valid_bleu),
                                    ] + [model_fn[-1]]

        model_fn = '.'.join(model_fn)

        # Unlike other tasks, we need to save current model, not best model.
        torch.save(
            {
                'model': engine.model.state_dict(),
                'opt': train_engine.optimizer.state_dict(),
                'config': config,
                'src_vocab': src_vocab,
                'tgt_vocab': tgt_vocab,
            }, model_fn
        )


6. Unsupervised Learning with RL

์ง€๋„ํ•™์Šต๋ฐฉ์‹์€ ๋†’์€ ์ •ํ™•๋„๋ฅผ ์ž๋ž‘ํ•œ๋‹ค. ๋‹ค๋งŒ, labeled data๊ฐ€ ํ•„์š”ํ•ด dataํ™•๋ณด๋‚˜ cost๊ฐ€ ๋†’๋‹ค.

๋น„์ง€๋„ํ•™์Šต๋ฐฉ์‹์€ dataํ™•๋ณด์— ๋Œ€ํ•œ cost๊ฐ€ ๋‚ฎ๊ธฐ์— ๋” ์ข‹์€ ๋Œ€์•ˆ์ด ๋  ์ˆ˜ ์žˆ๋‹ค.

(๋ฌผ๋ก , ์ง€๋„ํ•™์Šต์— ๋น„ํ•ด ์„ฑ๋Šฅ์ด๋‚˜ ํšจ์œจ์ด ๋–จ์–ด์งˆ ๊ฐ€๋Šฅ์„ฑ์€ ๋†’์Œ)

 

์ด๋Ÿฐ ์ ์—์„œ parallel corpus์— ๋น„ํ•ด monolinugal corpus๋ฅผ ํ™•๋ณดํ•˜๊ธฐ ์‰ฝ๋‹ค๋Š” NLP์˜ ํŠน์„ฑ์ƒ, ์ข‹์€ ๋Œ€์•ˆ์ด ๋  ์ˆ˜ ์žˆ๋‹ค.

์†Œ๋Ÿ‰์˜ parallel corpus์™€ ๋‹ค๋Ÿ‰์˜ monolingual corpus๋ฅผ ๊ฒฐํ•ฉ ๋” ๋‚˜์€ ์„ฑ๋Šฅ์„ ํ™•๋ณดํ•  ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋‹ค.

 

 

6.1 Unsupervised NMT
The paper introduced here proposes a way to build a translator using only monolingual corpora [Lample et al., 2018],
so it can be called truly unsupervised NMT.

[Key idea]
∙ Train the encoder so that sentences with the same meaning are embedded to the same values, regardless of language.
A GAN is brought in for this!! → ❓Wait — didn't we say earlier that GANs cannot be used in NLP...??
❗️It works here because the encoder's output is a continuous value.
[Encoder]
Trained to encode sentences with the same content to the same vector, regardless of language.
∙ For this a discriminator D is needed (its job: guess the language of an encoded sentence),
∙ and the Encoder is trained to fool D.

[Decoder]
Takes the encoder's output and is trained to recover the original sentence well through the Decoder.

That is, instead of separate encoders and decoders per language, a single Encoder and a single Decoder are shared across languages.

As a result, the loss function consists of the three parts below.

 

The three components of the loss

De-noising Auto-Encoder
A seq2seq model can itself be seen as a kind of auto-encoder,
and plain auto-encoding is a very easy problem.

So instead of simply asking the model to copy its input,
noise is mixed into the source sentence and the model must de-noise it, reconstructing the original input at the output. This is called a "de-noising auto-encoder" and is written as below.
In this formula, x_hat means that the input sentence x has noise added through the noise model C and is then encoded and decoded in the same language ℓ.
Δ(x_hat, x) is, as in MRT, the difference between the original sentence and the reconstructed one.
The noise model C(x) randomly drops words from the sentence or shuffles their order.
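In simplified notation (a sketch of the objective in Lample et al., 2018, with e/d the shared encoder/decoder):

L_auto(θ; ℓ) = E_{x∼D_ℓ}[ Δ( d(e(C(x), ℓ), ℓ), x ) ]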
Cross-domain training
❓What is cross-domain training?
A translation model M of modest quality, obtained from pre-training, translates the sentence into language ℓ2 and noise is added to this translated sentence y;
the model then learns to restore the original language-ℓ1 source sentence from it.

Adversarial Learning
A discriminator D is added that watches whether the Encoder always embeds sentence vectors into the latent space with the same distribution regardless of language, and adversarial training is performed.
∙ D is trained to predict the original language of the latent variable z.
∙ (xi, ℓi) denotes a sentence xi and the language ℓi it is written in.

As in a GAN, the Encoder must therefore be trained so that it can fool the discriminator D.
Here j = -(i - 1), i.e., with the two languages indexed 0 and 1, j is the other language.

 

Final objective
Combining the three objectives above gives the final objective.
Each λ adjusts the weight of its term in the loss, and the optimal parameters θ are found by minimizing the total.


The paper deals with how to build NMT when only monolingual corpora exist.
Being able to build NMT even without a parallel corpus is very encouraging.

On its own, however, this method is hard to use in the field. In practice, it is rare to build a translator with no parallel corpus at all, and even then it is better to spend the money to build a parallel corpus and combine it with plenty of monolingual corpora than to settle for the lower quality of a translator built from monolingual corpora alone.


๋งˆ์น˜๋ฉฐ...

In this post we looked at reinforcement learning and at how to use it to solve natural language generation (NLG) problems.
Various RL algorithms can raise the performance of NLG; in particular, we explained how to apply policy gradients to NLG.



Applying the policy-gradient method to NLG brings two main benefits.
① It escapes teacher forcing (where, due to the autoregressive property, training differs from actual inference):
generation quality improves because training uses sampling, just like actual inference.

② It becomes possible to train against a more accurate objective.
 - The usual PPL does not accurately reflect translation or NLG quality.
 - So metrics such as BLEU should be used instead.
 - But metrics such as BLEU cannot be differentiated.
  ∴ Previously we had no choice but to train the network with cross-entropy, which is equivalent to PPL.

Policy gradients also have drawbacks, however.
① Training is sampling-based, so it needs many iterations.
The cost is therefore higher and learning proceeds less efficiently.

② The reward function returns a scalar with no direction,
so the exact direction that maximizes the reward function cannot be known.

This differs from conventional MLE, where the loss is differentiated with respect to the network parameters θ and the resulting gradient updates them in the direction that minimizes the loss itself.
In the end, this too leads to far less efficient learning than conventional MLE.
