๐ ๋ชฉ์ฐจ
1. preview
2. Reinforcement Learning ๊ธฐ์ด
3. Policy based RL
4. NLG์ RL ์ ์ฉ
5. RL์ ํ์ฉํ Supervised Learning
6. RL์ ํ์ฉํ Unsupervised Learning
๐ ๊ธ์ ๋ง์น๋ฉฐ...
Reinforcement Learning์ ๊ฒฝ์ฐ, ๋งค์ฐ ๋ฐฉ๋ํ ๋ถ์ผ์ด๊ธฐ์ ๊ทธ ๋ฐฉ๋ํ ์์ญ์ ์ผ๋ถ๋ถ์ธ Policy Gradient๋ฅผ ํ์ฉํด ์์ฐ์ด์์ฑ(NLG; Natural Language Generation)์ ์ฑ๋ฅ์ ๋์ด์ฌ๋ฆฌ๋ ๋ฐฉ๋ฒ์ ๋ค๋ค๋ณผ ๊ฒ์ด๋ค.
๋จผ์ , Reinforcement Learning์ด ๋ฌด์์ธ์ง, NLG์ ์ ํ์ํ์ง ์ฐจ๊ทผ์ฐจ๊ทผ ๋ค๋ค๋ณผ ๊ฒ์ด๋ค.
1. Preview
1.1 GAN (Generative Adversarial Network)
2016๋ ๋ถํฐ ์ฃผ๋ชฉ๋ฐ์ผ๋ฉฐ 2017๋ ๊ฐ์ฅ ํฐ ํ์ ๊ฐ ๋์๋ ๋ถ์ผ์ด๊ณ , ํ์ฌ Stable Diffusion ๋ฑ์ผ๋ก ์ธํด Vision์์ ๊ฐ์ฅ ํฐ ํ์ ์ ๋ถ์ผ๋ ๋จ์ฐ์ฝ ์์ฑ์ ๋์ ๊ฒฝ๋ง(GAN)์ด๋ค. ์ด๋ ๋ณ๋ถ์คํ ์ธ์ฝ๋(VAE)์ ํจ๊ป ์์ฑ๋ชจ๋ธํ์ต์ ๋ํํ๋ ๋ฐฉ๋ฒ ์ค ํ๋์ด๋ค.
์ด๋ ๊ฒ ์์ฑ๋ ์ด๋ฏธ์ง๋ ์ค์ํ์ ์ค์ํ์ง๋ง trainset์ ์ป๊ธฐ ํ๋ ๋ฌธ์ ๋ค์ ํด๊ฒฐ์ ํฐ ๋์์ ์ค ๊ฒ์ด๋ผ ๊ธฐ๋๋๊ณ ์๋ค.
์์ ๊ทธ๋ฆผ์ฒ๋ผ GAN์ ์์ฑ์(Generator)G์ ํ๋ณ์(Discriminator) D๋ผ๋ 2๊ฐ์ ๋ชจ๋ธ์ ๊ฐ๊ธฐ๋ค๋ฅธ ๋ชฉํ๋ฅผ ๊ฐ๊ณ "๋์์ ํ๋ จ"์ํจ๋ค.
๋ ๋ชจ๋ธ์ด ๊ท ํ์ ์ด๋ฃจ๋ฉด min/max ๊ฒ์์ ํผ์น๊ฒ ๋๋ฉด ์ต์ข ์ ์ผ๋ก G๋ ํ๋ฅญํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์ ์๊ฒ ๋๋ค.
1.2 GAN์ ์์ฐ์ด์์ฑ์ ์ ์ฉ
ํ๋ฒ GAN์ NLG์ ์ ์ฉํด๋ณด์.
์๋ฅผ๋ค์ด, CE๋ฅผ ์ฌ์ฉํด ๋ฐ๋ก ํ์ตํ๊ธฐ๋ณด๋จ
- ์ค์ corpus์์ ๋์จ ๋ฌธ์ฅ์ธ์ง
- seq2seq์์ ๋์จ ๋ฌธ์ฅ์ธ์ง
์์ ๋ ๊ฒฝ์ฐ์ ๋ํ ํ๋ณ์ D๋ฅผ ๋์ด seq2seq์์ ๋์จ ๋ฌธ์ฅ์ด ์ง์ง ๋ฌธ์ฅ๊ณผ ๊ฐ์์ง๋๋ก ํ๋ จํ๋ ๋ฑ์ ์์๋ก ๋ค ์ ์๋ค.
๋ค๋ง, ์์ฝ๊ฒ๋ ์ด ์ข์๋ณด์ด๋ ์์ด๋์ด๋ ๋ฐ๋ก ์ ์ฉํ ์ ์๋๋ฐ, seq2seq์ ๊ฒฐ๊ณผ๋ ์ด์ฐํ๋ฅ ๋ถํฌ์ด๊ธฐ ๋๋ฌธ์ด๋ค.
๋ฐ๋ผ์ ์ฌ๊ธฐ์ sampling์ด๋ argmax๋ก ์ป์ด์ง๋ ๊ฒฐ๊ณผ๋ฌผ์ discreteํ ๊ฐ์ด๊ธฐ์ one-hot๋ฒกํฐ๋ก ํํ๋์ด์ผ ํ ๊ฒ์ด๋ค.
ํ์ง๋ง ์ด ๊ณผ์ ์ ํ๋ฅ ์ ์ธ ๊ณผ์ (stochastic process)์ผ๋ก ๊ธฐ์ธ๊ธฐ๋ฅผ ์ญ์ ํํ ์ ์๊ฑฐ๋ ๋ฏธ๋ถ ๊ฒฐ๊ณผ๊ฐ์ด 0์ด๋ ๋ถ์ฐ์์ ์ธ ๊ฒฝ์ฐ๊ฐ ๋๊ธฐ์ D๊ฐ ๋ง์ถ ์ฌ๋ถ๋ฅผ ์ญ์ ํ๋ฅผ ํตํด seq2seq G๋ก ์ ๋ฌ๋ ์ ์๊ณ , ๊ฒฐ๊ณผ์ ์ผ๋ก ํ์ต์ด ๋ถ๊ฐ๋ฅํ๋ค.
1.3 GAN๊ณผ ์์ฐ์ด์์ฑ
1.1๊ณผ 1.2์์ ๋งํ๋ฏ, GAN์ Vision์์๋ ๋์ฑ๊ณต์ ์ด๋ค์ง๋ง NLG์์๋ ์ ์ฉ์ด ์ด๋ ค์ด๋ฐ, ์ด๋ ์์ฐ์ด ๊ทธ ์์ฒด์ ํน์ง๋๋ฌธ์ด๋ค.
โ ์ด๋ฏธ์ง๋ ์ด๋ค "์ฐ์์ "์ธ ๊ฐ๋ค๋ก ์ฑ์์ง ํ๋ ฌ
โ ์ธ์ด๋ ๋ถ์ฐ์์ ์ธ ๊ฐ๋ค์ ์์ฐจ์ ๋ฐฐ์ด์ด๊ธฐ์ ์ฐ๋ฆฐ LM์ ํตํด latent space์ ์ฐ์์ ๋ณ์๋ก ๊ทธ ๊ฐ๋ค์ ์นํํด ๋ค๋ฃฌ๋ค.
๊ฒฐ๊ตญ, ์ธ๋ถ์ ์ผ๋ก ์ธ์ด๋ฅผ ํํํ๋ ค๋ฉด ์ด์ฐํ๋ฅ ๋ถํฌ์ ๋ณ์๋ก ๋ํ๋ด์ผํ๊ณ ,
๋ถํฌ๊ฐ ์๋ ์ด๋ค sample๋ก ํํํ๋ ค๋ฉด ์ด์ฐํ๋ฅ ๋ถํฌ์์ samplingํ๋ ๊ณผ์ ์ด ํ์ํ๋ค.
์ด๋ฐ ์ด์ ๋ก D์ loss๋ฅผ G์ ์ ๋ฌํ ์ ์๊ณ , ๋ฐ๋ผ์ ์ ๋์ ์ ๊ฒฝ๋ง๋ฐฉ๋ฒ์ NLG์ ์ ์ฉํ ์ ์๋ค๋ ์ธ์์ด ์ฃผ๋ฅผ ์ด๋ฃจ๊ฒ ๋์๋ค.
But!! ๊ฐํํ์ต์ ํตํด ์ ๋์ ํ์ต๋ฐฉ์์ ์ฐํ์ ์ผ๋ก ์ฌ์ฉํ ์ ์๊ฒ ๋์๋ค.
1.4 ๊ฐํํ์ต ์ฌ์ฉ์ด์
์ด๋ค taskํด๊ฒฐ์ ์ํด CE๋ฅผ ์ธ ์ ์๋ classification์ด๋ continuous๋ณ์๋ฅผ ๋ค๋ฃจ๋ MSE๋ฑ์ผ๋ก๋ ์ ์ํ ์ ์๋ ๋ณต์กํ ๋ชฉ์ ํจ์๊ฐ ๋ง๊ธฐ ๋๋ฌธ์ด๋ค.
์ฆ, ์ฐ๋ฆฌ๋ CE๋ MSE๋ก ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ์ง๋ง, ์ด๋ ๋ฌธ์ ๋ฅผ ๋จ์ํํด ์ ๊ทผํ๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
์ด๋ฐ ๋ฌธ์ ๋ค์ ๊ฐํํ์ต์ ํตํด ํด๊ฒฐํ๊ณ , ์ฑ๋ฅ์ ๊ทน๋ํํ ์ ์๋ค.
์ด๋ฅผ ์ํด ์ ์ค๊ณ๋ ๋ณด์(reward)์ ์ฌ์ฉํด ๋ ๋ณต์กํ๊ณ ์ ๊ตํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ ์ ์๋ค.
2. Reinforcement Learning ๊ธฐ์ด
Reinforcement Learning์ ์ด๋ฏธ ์ค๋์ Machine Learning์ ํ ์ข ๋ฅ๋ก ๋์๋ ๋ฐฉ๋ํ๊ณ ๋ ์ ์๊น์ ํ๋ฌธ์ด๊ธฐ์ ์ด๋ฒ์ ํ๋์ ๊ธ๋ก๋ ๋ค๋ฃจ๊ธฐ ๋ฌด๋ฆฌ๊ฐ ์๋ค.
๋ฐ๋ผ์ ์ด๋ฒ์๊ฐ์ ๋ค๋ฃฐ, Policy Gradient์ ๋ํ ์ดํด๋ฅผ ์ํด ํ์ํ ์ ๋๋ง ์ดํดํ๊ณ ๋์ด๊ฐ๋ณด๋ ค ํ๋ค.
์ถ๊ฐ์ ์ธ Reinforcement Learning์ ๋ํ ๊ธฐํ ์ถ์ฒ์ ์๋ ๋งํฌ ์ฐธ๊ณ .
⌈Reinforcement Learning: An Introduction ; (MIT Press, 2018)⌋
2.1 Universe
๋จผ์ ๊ฐํํ์ต์ด ๋์ํ๋ ๊ณผ์ ๋ถํฐ ์์๋ณด์.
Q: ๊ฐํํ์ต์ด๋??
A: ์ด๋ค ๊ฐ์ฒด๊ฐ ์ฃผ์ด์ง ํ๊ฒฝ์์, ์ํฉ์ ๋ฐ๋ผ ์ด๋ป๊ฒ ํ๋ํด์ผํ ์ง ํ์ตํ๋ ๋ฐฉ๋ฒ
์ฒ์ ์ํ์ธ St(t=0)์ ๋ฐ๊ณ ,
โAgent๋ ์์ ์ policy์ ๋ฐ๋ผ action At๋ฅผ ์ ํํ๋ค.
โEnvironment๋ Agent๋ก๋ถํฐ ์ ํ๋ At๋ฅผ ๋ฐ์
๋ณด์ Rt+1๊ณผ ์๋กญ๊ฒ ๋ฐ๋ ์ํ St+1์ ๋ฐํํ๋ค.
์ด ๊ณผ์ ์ ํน์ ์กฐ๊ฑด์ด ๋ง์กฑ๋ ๋๊น์ง ๋ฐ๋ณตํ๋ฉฐ, ํ๊ฒฝ์ ์ด ์ํ์ค๋ฅผ ์ข ๋ฃํ๋ค.
์ด๋ฅผ ํ๋์ episode๋ผ ํ๋ค.
๋ชฉํ: ๋ฐ๋ณต๋๋ episode์์ agent๊ฐ RL์ ํตํด ์ ์ ํ ํ๋(๋ณด์์ด ์ต๋๊ฐ ๋๋๋ก)์ ํ๋๋ก ํ๋ จ์ํค๋ ๊ฒ
2.2 MDP (Markov Decision Process)
์ฌ๊ธฐ์ ๋ํด Markov๊ฒฐ์ ๊ณผ์ (MDP)๋ผ๋ ๊ฐ๋ ์ ๋์ ํ๋ค.
์ฐ๋ฆฐ ์จ ์ธ์์ ํ์ฌ T=t๋ผ๋ ์๊ฐ์ ํ๋์ ์ํ(state)๋ก ์ ์ํ ์ ์๋ค.
๊ฐ์ ) ํ์ฌ์ํ(present state)๊ฐ ์ฃผ์ด์ก์ ๋, ๋ฏธ๋(T>t)๋ ๊ณผ๊ฑฐ(T<t)๋ก๋ถํฐ "๋ ๋ฆฝ์ ".
์ด ๊ฐ์ ํ์, ์จ ์ธ์์ Markov๊ณผ์ ์์์ ๋์ํ๋ค ํ ์ ์๋ค.
์ด๋, ํ์ฌ์ํฉ์์ ๋ฏธ๋์ํฉ์ผ๋ก ๋ฐ๋ ํ๋ฅ ์ P(S' | S)๋ก ํํ๊ฐ๋ฅํ๋ค.
์ฌ๊ธฐ์ MDP๋ ๊ฒฐ์ ์ ๋ด๋ฆฌ๋ ๊ณผ์ (= ํ๋์ ์ ํํ๋ ๊ณผ์ )์ด ์ถ๊ฐ๋ ๊ฒ์ด๋ค.
์ฆ, ํ์ฌ์ํฉ์์ ์ด๋คํ๋์ ์ ํ ์, ๋ฏธ๋์ํฉ์ผ๋ก ๋ฐ๋ ํ๋ฅ ๋ก P(S' | S, A)์ด๋ค.
์ฝ๊ฒ ์ค๋ช ํ์๋ฉด, ์๋ ๊ฐ์๋ฐ์๋ณด๊ฒ์์ ์์๋ก ๋ค ์ ์๋ค.
์๋ฅผ๋ค์ด ์ฌ๋๋ง๋ค ๊ฐ์๋ฐ์๋ณด๋ฅผ ๋ด๋ ํจํด์ด ๋ค๋ฅด๊ธฐ์ ๊ฐ์๋ฐ์๋ณด ์๋๋ฐฉ์ ๋ฐ๋ผ ์ ๋ต์ด ๋ฐ๋์ด์ผ ํ๋ฏ๋ก
โ ์๋๋ฐฉ์ด ์ฒซ ์ํ S0๋ฅผ ๊ฒฐ์ ํ๋ค.
โ Agent๋ S0์ ๋ฐ๋ผ ํ๋ A0๋ฅผ ์ ํํ๋ค.
โ ์๋๋ฐฉ์ ์๋๋ฐฉ์ ์ ์ฑ ์ ๋ฐ๋ผ ๊ฐ์/๋ฐ์/๋ณด ์ค ํ๋๋ฅผ ๋ธ๋ค.
โ ์นํจ๊ฐ ๊ฒฐ์ , ํ๊ฒฝ์ผ๋ก๋ถํฐ ๋ฐฉ๊ธ ์ ํํ ํ๋์ ๋ํ ๋ณด์ R1์ ๋ฐ๋๋ค.
โ update๋ ์ํ S1์ ์ป๋๋ค.
2.3 Reward
์์ Agent๊ฐ ์ด๋ค ํ๋์ ํ์ ๋, ํ๊ฒฝ์ผ๋ก๋ถํฐ "๋ณด์"์ ๋ฐ๋๋ค.
์ด๋, ์ฐ๋ฆฐ Gt๋ฅผ ์ด๋ค ์์ ์ผ๋ก๋ถํฐ ๋ฐ๋ ๋ณด์์ ๋์ ํฉ์ด๋ผ ์ ์ํ์.
๋ฐ๋ผ์ Gt๋ ์๋์ ๊ฐ์ด ์ ์๋๋ค.
์ด๋ ๊ฐ์์จ(discount factor) γ(0~1๊ฐ)๋ฅผ ๋์ ํด ์์์ ์กฐ๊ธ ๋ณํํ ์ ์๋ค.
γ์ ๋์ ์ผ๋ก ๋จผ ๋ฏธ๋์ ๋ณด์๋ณด๋ค ๊ฐ๊น์ด ๋ฏธ๋์ ๋ณด์์ ๋ ์ค์ํด ๋ค๋ฃฐ ์ ์๊ฒ ๋๋ค.
2.4 Policy
Agent๋ ์ฃผ์ด์ง ์ํ์์ ์์ผ๋ก ๋ฐ์ ๋ณด์์ ๋์ ํฉ์ ์ต๋๋ก ํ๋๋ก ํ๋ํด์ผํ๋ค.
์ฆ, ๋์์ ์์ ์์ ์ํด๋ณด๋ค ๋จผ๋ฏธ๋๊น์ง ํฌํจํ ๋ณด์์ ์ดํฉ์ด ์ต๋๊ฐ ๋๋๊ฒ์ด ์ค์ํ๋ค.
(๐ ๋ง์น ์ฐ๋ฆฌ๊ฐ ์ํ๊ธฐ๊ฐ์ ๋์ง๋ชปํ๊ณ ๊ณต๋ถํ๋ ๊ฒ์ฒ๋ผ...?)
์ ์ฑ (policy)์ Agent๊ฐ ์ํฉ์ ๋ฐ๋ผ ์ด๋ป๊ฒ ํ๋์ ํด์ผํ ์ง, "ํ๋ฅ ์ ์ผ๋ก ๋ํ๋ธ ๊ธฐ์ค"์ด๋ค.
์ฆ, ๊ฐ์ ์ํฉ์ด ์ฃผ์ด์ก์ ๋, ์ด๋ค ํ๋์ ์ ํํ ์ง์ ๋ํ ํ๋ฅ ํจ์์ด๊ธฐ์ ๋ฐ๋ผ์ ์ฐ๋ฆฌ๊ฐ ํ๋ํ๋ ๊ณผ์ ์ ํ๋ฅ ์ ์ธ ํ๋ก์ธ์ค๋ผ ๋ณผ ์ ์๋ค. ์ด๋, ํจ์๋ฅผ ํตํด ์ฃผ์ด์ง ํ๋์ ์ทจํ ํ๋ฅ ๊ฐ์ ์๋ ํจ์๋ฅผ ํตํด ๊ตฌํ ์ ์๋ค.
2.5 Value Function (๊ฐ์นํจ์)
๊ฐ์นํจ์๋ ์ฃผ์ด์ง policy ๐ ๋ด์ ํน์ ์ํ s์์๋ถํฐ ์์ผ๋ก ์ป์ ์ ์๋ ๋ณด์์ ๋์ ์ดํฉ์ ๊ธฐ๋๊ฐ์ ์๋ฏธํ๋ค.
์๋ ์์๊ณผ ๊ฐ์ด ๋ํ๋ผ ์ ์๋๋ฐ, ์์ผ๋ก ์ป์ ์ ์๋ ๋ณด์์ ์ดํฉ์ ๊ธฐ๋๊ฐ์ ๊ธฐ๋๋์ ๋ณด์(Expected Cumulative Reward)๋ผ๊ณ ๋ ํ๋ค.
ํ๋๊ฐ์นํจ์ (Q-function ; Q ํจ์)
ํ๋๊ฐ์นํจ์(activation-value function ; Q-function)๋ ์ฃผ์ด์ง policy ๐ ์๋ ์ํฉ s์์ action a๋ฅผ ์ ํํ์ ๋, ์์ผ๋ก ์ป์ ์ ์๋ ๋ณด์์ ๋์ ํฉ์ ๊ธฐ๋๊ฐ(๊ธฐ๋๋์ ๋ณด์)์ ํํํ๋ค.
๊ฐ์นํจ์๊ฐ ์ด๋ค s์์ ์ด๋ค a๋ฅผ ์ ํํ ์ง์ ๊ด๊ณ์์ด ์ป์ ์ ์๋ ๋์ ๋ณด์์ ๊ธฐ๋๊ฐ์ด๋ผ ํ๋ค๋ฉด
Qํจ์๋ ์ด๋ค a๋ฅผ ์ ํํ๋๊ฐ์ ๋ํ ๊ฐ๋ ์ด ์ถ๊ฐ๋ ๊ฒ์ด๋ค.
์ฆ, ์ํ์ ํ๋์ ๋ฐ๋ฅธ ๊ธฐ๋๋์ ๋ณด์์ ๋ํ๋ด๋ Qํจ์์ ์์ ์๋์ ๊ฐ๋ค.
2.6 Bellman ๋ฐฉ์ ์
๊ฐ์นํจ์์ ํ๋๊ฐ์นํจ์์ ์ ์์ ๋ฐ๋ผ ์ด์์ ์ธ ๊ฐ์นํจ์์ ์ด์์ ์ธ Qํจ์๋ฅผ ์ ์ํด๋ณด๊ณ ์ ํ๋ค๋ฉด...?
→ Bellman Equation(๋ฒจ๋ง ๋ฐฉ์ ์)์ ๋ค์๊ณผ ๊ฐ์ด ๋ํ๋ผ ์ ์๋ค.
DP: ๋ฌธ์ ๋ฅผ ๊ฒน์น๋ ํ์ ๋ฌธ์ (sub-problems)๋ก ๋ถํดํ๊ณ ์ต์ ๋ถ๋ถ ๊ตฌ์กฐ(optimal substructure)๋ฅผ ๋ฐ๋ฅด๋ ๋ฐฉ๋ฒ์ผ๋ก ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๊ธฐ์ ๋ก ํฐ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๊ฒ์ ์์ ํ์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ๊ฒ์ผ๋ก ๋๋์ด์ง ์ ์์ต๋๋ค. ์ด ์์ ํ์ ๋ฌธ์ ๋ค์ ํด๊ฒฐํ ํ์๋ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ์กฐํฉํ์ฌ ์๋ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค. ์ค์ํ ์ ์ ๋์ผํ ํ์ ๋ฌธ์ ๊ฐ ์ฌ๋ฌ ๋ฒ ๊ณ์ฐ๋๋ ๋์ , ํ ๋ฒ ๊ณ์ฐ๋ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ๊ณ ์ฌ์ฌ์ฉํ์ฌ ๊ณ์ฐ ๋น์ฉ์ ์ค์ด๋ ๊ฒ์ ๋๋ค. ์ด๊ฒ์ด DP์ ํต์ฌ ์์ด๋์ด์ด๋ฉฐ, ๊ฒน์น๋ ํ์ ๋ฌธ์ ๊ฐ ์๋ค๋ฉด DP๋ฅผ ํจ๊ณผ์ ์ผ๋ก ํ์ฉํ ์ ์๋ค.
Bellman๋ฐฉ์ ์์ ๋์ ํ๋ก๊ทธ๋๋ฐ(DP; Dynamic Programming)์๊ณ ๋ฆฌ์ฆ ๋ฌธ์ ๋ก ์ ๊ทผ๊ฐ๋ฅํ๋ค.
์ฆ, ๋จ์ํ ์ต๋จ๊ฒฝ๋ก๋ฅผ ์ฐพ๋ ๋ฌธ์ ์ ๋น์ทํ๋ฐ, ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ๋ชจ๋ ๊ฒฝ์ฐ์ ๋ํด ํ์์ ์ํํ๋ Back-Tracking๋ณด๋ค ํจ์ฌ ํจ์จ์ ์ด๊ณ ๋น ๋ฅด๊ฒ ๋ฌธ์ ํ์ด์ ์ ๊ทผํ ์ ์๋ค.
2.7 Monte Carlo Method
ํ์ง๋ง ๊ทธ๋ ๊ฒ ์ฝ๊ฒ ๋ฌธ์ ๊ฐ ํ๋ฆฌ๋ฉด ๊ฐํํ์ต์ด ํ์ํ์ง ์์์ ๊ฒ์ด๋ค.
Prob) ๋๋ถ๋ถ์ ๊ฒฝ์ฐ, ์์ 2.6์ ์์์์ ๊ฐ์ฅ ์ค์ํ P(s', r | s, a)๋ถ๋ถ์ ๋ชจ๋ฅธ๋ค๋ ์ ์ด๋ค.
์ฆ, ์ด๋ค ์ํ→์ด๋คํ๋→์ด๋คํ๋ฅ ๋ก ๋ค๋ฅธ์ํ s'๊ณผ ๋ณด์ r'์ ๋ฐ๊ฒ๋๋์ง, "์ง์ ํด๋ด์ผ" ์๋ค๋ ์ ์ด๋ค.
∴ DP๊ฐ ์ ์ฉ๋ ์ ์๋ ๊ฒฝ์ฐ๊ฐ ๋๋ถ๋ถ์ด๋ค.
๋ฐ๋ผ์ RL์ฒ๋ผ simulation๋ฑ์ ๊ฒฝํ์ ํตํด Agent๋ฅผ ํ์ตํด์ผํ๋ค.
์ด๋ฐ ์ด์ ๋ก Monte Carlo Method์ฒ๋ผ sampling์ ํตํด Bellman์ Expectation Equation์ ํด๊ฒฐํ ์ ์๋ค.
Prob) ๊ธด episode์ ๋ํด sampling์ผ๋ก ํ์ตํ๋ ๊ฒฝ์ฐ์ด๋ค.
์ค์ episode๊ฐ ๋๋์ผ Gt๋ฅผ ๊ตฌํ ์ ์๊ธฐ์ episode์ ๋๊น์ง ๊ธฐ๋ค๋ ค์ผ ํ๋ค.
๋ค๋ง, ๊ทธ๊ฐ ์ตํ ๋ณด์์จ AlphaGo์ฒ๋ผ ๊ต์ฅํ ๊ธด episode์ ๊ฒฝ์ฐ, ๋งค์ฐ ๋ง์ ์๊ฐ๊ณผ ๊ธด ์๊ฐ์ด ํ์ํ๊ฒ ๋๋ค.
2.8 TD ํ์ต (Temporal Difference Learning)
์ด๋, ์๊ฐ์ฐจํ์ต(TD)๋ฐฉ๋ฒ์ด ์ ์ฉํ๋ค.
TDํ์ต๋ฒ์ ๋ค์์์์ฒ๋ผ episode๋ณด์์ ๋์ ํฉ ์์ด๋ ๋ฐ๋ก ๊ฐ์นํจ์๋ฅผ updateํ ์ ์๋ค.
Q-Learning
๋ง์ฝ ์ฌ๋ฐ๋ฅธ Qํจ์๋ฅผ ์๊ณ ์๋ค๋ฉด, ์ด๋ค์ํฉ์ด๋๋ผ๋ ํญ์ ๊ธฐ๋๋์ ๋ณด์์ ์ต๋ํํ๋ ๋งค์ฐ ์ข์ ์ ํ์ด ๊ฐ๋ฅํ๋ค.
์ด๋, Qํจ์๋ฅผ ์ ํ์ตํ๋ ๊ฒ์ Q-Learning์ด๋ผ ํ๋ค.
์๋์์์ฒ๋ผ target๊ณผ ํ์ฌ๊ฐ์นํจ์(current)์ ์ฐจ์ด๋ฅผ ์ค์ด๋ฉด ์ฌ๋ฐ๋ฅธ Qํจ์๋ฅผ ํ์ตํ ๊ฒ์ด๋ค.
DQN (Deep Q-Learning)
Qํจ์๋ฅผ ํ์ตํ ๋, state๊ณต๊ฐ์ ํฌ๊ธฐ์ action๊ณต๊ฐ์ ํฌ๊ธฐ๊ฐ ๋๋ฌด ์ปค ์ํฉ๊ณผ ํ๋์ด ํฌ์ํ ๊ฒฝ์ฐ, ๋ฌธ์ ๊ฐ ์๊ธด๋ค.
ํ๋ จ๊ณผ์ ์์ ํฌ์์ฑ์ผ๋ก ์ธํด ์ ๋ณผ ์ ์๊ธฐ ๋๋ฌธ์ด๋ค.
์ด์ฒ๋ผ ์ํฉ๊ณผ ํ๋์ด ๋ถ์ฐ์์ ์ธ ๋ณ๊ฐ์ ๊ฐ์ด๋๋ผ๋, Qํจ์๋ฅผ ๊ทผ์ฌํ๋ฉด ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์๋ค.
DeepMind๋ ์ ๊ฒฝ๋ง์ ์ฌ์ฉํด ๊ทผ์ฌํ Q-Learning์ ํตํด Atari๊ฒ์์ ํ๋ฅญํ ํ๋ ์ดํ๋ ๊ฐํํ์ต๋ฐฉ๋ฒ์ ์ ์ํ๋๋ฐ, ์ด๋ฅผ DQN(Deep Q-Learning)์ด๋ผ ํ๋ค.
์๋ ์์์ฒ๋ผ Qํจ์ ๋ถ๋ถ์ ์ ๊ฒฝ๋ง์ผ๋ก ๊ทผ์ฌํด ํฌ์์ฑ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ , ์ฌ์ง์ด Atari๊ฒ์์ ์ฌ๋๋ณด๋ค ๋ ์ ํ๋ ์ดํ๋ Agent๋ฅผ ํ์ตํ๊ธฐ๋ ํ๋ค.
3. Policy based Reinforcement Learning
3.1 Policy Gradient
Policy Gradient๋ Policy based Reinforcement Learning๋ฐฉ์์ ์ํ๋ค.
cf) DQN์ ๊ฐ์น๊ธฐ๋ฐ(Value based Reinforcement Learning)๋ฐฉ์์ ์ํ๋ค.-ex) DeepMind๊ฐ ์ฌ์ฉํ AlphaGo.
์ ์ฑ ๊ธฐ๋ฐ๊ณผ ๊ฐ์น๊ธฐ๋ฐ ๊ฐํํ์ต, ๋ ๋ฐฉ์์ ์ฐจ์ด๋ ๋ค์๊ณผ ๊ฐ๋ค.
โ ๊ฐ์น๊ธฐ๋ฐํ์ต: ANN์ ์ฌ์ฉํด ์ด๋ค ํ๋์ ์ ํ ์, ์ป์ ์ ์๋ ๋ณด์์ ์์ธกํ๋๋ก ํ๋ จ
โ ์ ์ฑ ๊ธฐ๋ฐํ์ต: ANN์ ํ๋์ ๋ํ ๋ณด์์ ์ญ์ ํ์๊ณ ๋ฆฌ์ฆ์ ํตํด ์ ๋ฌํด ํ์ต
∴ DQN์ ๊ฒฝ์ฐ, ํ๋์ ์ ํ์ด ํ๋ฅ ์ (stochastic)์ผ๋ก ๋์ค์ง ์์ง๋ง
Policy Gradient๋ ํ๋ ์ ํ ์, ํ๋ฅ ์ ์ธ ๊ณผ์ ์ ๊ฑฐ์น๋ค.
Policy Gradient์ ๋ํ ์์์ ์๋์ ๊ฐ๋ค.
์ด ์์์์ ์์ ๐๋ policy๋ฅผ ์๋ฏธํ๋ค.
์ฆ, ์ ๊ฒฝ๋ง ๊ฐ์ค์น θ๋ ํ์ฌ์ํ s๊ฐ ์ฃผ์ด์ก์๋, ์ด๋คํ๋ a๋ฅผ ์ ํํด์ผํ๋์ง์ ๊ดํ ํ๋ฅ ์ ๋ฐํํ๋ค.
์ฐ๋ฆฌ์ ๋ชฉํ๋ ์ต์ด์ํ(initial state)์์์ ๊ธฐ๋๋์ ๋ณด์์ ์ต๋๋ก ํ๋ ์ ์ฑ θ๋ฅผ ์ฐพ๋๊ฒ์ด๊ณ
์ต์ํํด์ผํ๋ ์์ค๊ณผ ๋ฌ๋ฆฌ ๋ณด์์ ์ต๋ํํด์ผํ๋ฏ๋ก
๊ธฐ์กด์ ๊ฒฝ์ฌํ๊ฐ๋ฒ๋์ , ๊ฒฝ์ฌ์์น๋ฒ(Gradient Ascent)๋ฅผ ์ฌ์ฉํ๋ค.
์ด๋ฐ ๊ฒฝ์ฌ์์น๋ฒ์ ๋ฐ๋ผ ∇θJ(θ)๋ฅผ ๊ตฌํด θ๋ฅผ updateํด์ผํ๋ค.
์ด๋, d(s)๋ Markov Chain์ ์ ์ ๋ถํฌ(stationary distribution)๋ก์จ ์์์ ์ ์๊ด์์ด ์ ์ฒด๊ฒฝ๋ก์์ s์ ๋จธ๋ฌด๋ฅด๋ ์๊ฐ๋น์จ์ ์๋ฏธํ๋ค.
์ด๋, ๋ก๊ทธ๋ฏธ๋ถ์ ์ฑ์ง์ ์ด์ฉํด ์๋ ์์์ฒ๋ผ ∇θJ(θ)๋ฅผ ๊ตฌํ ์ ์๋ค.
์ด ์์์ ํด์ํ์๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
โ ๋งค time-step๋ณ ์ํฉ s๊ฐ ์ฃผ์ด์ง ๋, a๋ฅผ ์ ํํ ๋ก๊ทธํ๋ฅ ์ ๊ธฐ์ธ๊ธฐ์ ๊ทธ์๋ฐ๋ฅธ ๋ณด์์ ๊ณฑํ ๊ฐ์ ๊ธฐ๋๊ฐ์ด ๋๋ค.
Policy Gradient Theorem์ ๋ฐ๋ฅด๋ฉด
์ฌ๊ธฐ์ ํด๋น time-step์ ๋ํ ์ฆ๊ฐ์ ๋ณด์ r๋์ , episode ์ข ๋ฃ๊น์ง์ ๊ธฐ๋๋์ ๋ณด์์ ์ฌ์ฉํ ์ ์๋ค.
์ฆ, Qํจ์๋ฅผ ์ฌ์ฉํ ์ ์๋ค๋ ๊ฒ์ธ๋ฐ, ์ด๋ Policy Gradient์ ์ง๊ฐ๊ฐ ๋๋ฌ๋๋ค.
์ฐ๋ฆฐ Policy Gradient ์ ๊ฒฝ๋ง์ ๋ํด ๋ฏธ๋ถ๊ณ์ฐ์ด ํ์ํ์ง๋ง,
"Qํจ์์ ๋ํด์๋ ๋ฏธ๋ถํ ํ์๊ฐ ์๋ค!!"
์ฆ, ๋ฏธ๋ถ๊ฐ๋ฅ์ฌ๋ถ๋ฅผ ๋ ๋, ์์์ ์ด๋ค ํจ์๋ผ๋ ๋ณด์ํจ์๋ก ์ฌ์ฉํ ์ ์๋ ๊ฒ์ด๋ค!!
์ด๋ ๊ฒ ์ด๋ค ํจ์๋ ๋ณด์ํจ์๋ก ์ฌ์ฉํ ์ ์๊ฒ๋๋ฉด์, ๊ธฐ์กด Cross-Entropy๋ MSE๊ฐ์ ์์คํจ์๋ก fittingํ๋ ๋์ , ์ข ๋ ์ค๋ฌด์ ๋ถํฉํ๋ ํจ์(ex. ๋ฒ์ญ์ ๊ฒฝ์ฐ, BLEU)๋ฅผ ์ฌ์ฉํด θ๋ฅผ ํ๋ จ์ํฌ ์ ์๊ฒ๋์๋ค.
์ถ๊ฐ์ ์ผ๋ก ์์ ์์์์ ๊ธฐ๋๊ฐ ์์์ Monte Carlo Sampling์ผ๋ก ๋์ฒดํ๋ฉด ์๋์ฒ๋ผ ์ ๊ฒฝ๋ง ํ๋ผ๋ฏธํฐ θ๋ฅผ updateํ ์ ์๋ค.
์ด ์์์ ๋ ํ์ด์ ์ค๋ช ํด๋ณด์.
โ log๐θ(at | st) : st๊ฐ ์ฃผ์ด์ก์ ๋, ์ ์ฑ ํ๋ผ๋ฏธํฐ θ์์ ํ๋ฅ ๋ถํฌ์์ sampling๋์ด ์ ํ๋ ํ๋์ด at์ผ ํ๋ฅ ๊ฐ์ด๋ค.
ํด๋นํ๋ฅ ๊ฐ์ θ์ ๋ํด ๋ฏธ๋ถํ ๊ฐ์ด ∇θlog๐θ(at | st)์ด๋ค.
๋ฐ๋ผ์ ํด๋น ๊ธฐ์ธ๊ธฐ๋ฅผ ํตํ ๊ฒฝ์ฌ์์น๋ฒ์ log๐θ(at | st)๋ฅผ ์ต๋ํํจ์ ์๋ฏธํ๋ค.
โ ์ฆ, at์ ํ๋ฅ ์ ๋์ด๋๋ก ํ์ฌ ์์ผ๋ก ๋์ผ์ํ์์ ํด๋นํ๋์ด ๋ ์์ฃผ ์ ํ๋๊ฒ ํ๋ค.
Gradient ∇θlog๐θ(at | st)์ ๋ณด์์ ๊ณฑํด์ฃผ์๊ธฐ์ ๋ง์ฝ sampling๋ ํด๋น ํ๋๋ค์ด ํฐ ๋ณด์์ ๋ฐ์๋ค๋ฉด,
ํ์ต๋ฅ γ์ ์ถ๊ฐ์ ์ธ ๊ณฑ์ ์ ํตํด ๋ ํฐ step์ผ๋ก ๊ฒฝ์ฌ์์น์ ์ํํ ์ ์๋ค.
ํ์ง๋ง ์์ ๋ณด์๊ฐ์ ๋ฐ๊ฒ๋๋ค๋ฉด, ๊ฒฝ์ฌ์ ๋ฐ๋๋ฐฉํฅ์ผ๋ก step์ ๊ฐ๊ฒ ๊ฐ์ด ๊ณฑํด์ง ๊ฒ์ด๋ฏ๋ก ๊ฒฝ์ฌํ๊ฐ๋ฒ์ ์ํํ๋ ๊ฒ๊ณผ ๊ฐ์ ํจ๊ณผ๊ฐ ๋ฐ์ํ ๊ฒ์ด๋ค.
๋ฐ๋ผ์ ํด๋น sampling๋ a๋ค์ด ์์ผ๋ก๋ ์ ๋์ค์ง ์๊ฒ ์ ๊ฒฝ๋ง ํ๋ผ๋ฏธํฐ θ๊ฐ update๋ ๊ฒ์ด๋ค.
๋ฐ๋ผ์ ์ค์ ๋ณด์์ ์ต๋ํํ๋ ํ๋์ ํ๋ฅ ์ ์ต๋๋กํ๋ ํ๋ผ๋ฏธํฐ θ๋ฅผ ์ฐพ๋๋ก ํ ๊ฒ์ด๋ค.
๋ค๋ง, ๊ธฐ์กด์ ๊ฒฝ์ฌ๋๋ ๋ฐฉํฅ๊ณผ ํฌ๊ธฐ๋ฅผ ๋ํ๋ผ ์ ์์์ง๋ง,
Policy Gradient๋ ๊ธฐ์กด ๊ฒฝ์ฌ๋์ ๋ฐฉํฅ์ ์ค์นผ๋ผ ํฌ๊ธฐ๊ฐ์ ๊ณฑํด์ฃผ๋ฏ๋ก
์ค์ ๋ณด์์ ์ต๋ํํ๋ ์ง์ ์ ์ธ ๋ฐฉํฅ์ ์ง์ ํ ์ ์๊ธฐ์ ์ฌ์ค์ ํ๋ จ์ด ์ด๋ ต๊ณ ๋นํจ์จ์ ์ด๋ผ๋ ๋จ์ ์ด ์กด์ฌํ๋ค.
3.2 MLE v.s Policy Gradient
๋ค์ ์์๋ก ์ต๋๊ฐ๋ฅ๋์ถ์ (MLE)๊ณผ์ ๋น๊ต๋ฅผ ํตํด Policy Gradient๋ฅผ ๋ ์ดํดํด๋ณด์.
โ n๊ฐ์ sequence๋ก ์ด๋ค์ง data๋ฅผ ์ ๋ ฅ๋ฐ์
โ m๊ฐ์ sequence๋ก ์ด๋ค์ง data๋ฅผ ์ถ๋ ฅํ๋ ํจ์๋ฅผ ๊ทผ์ฌ์ํค๋ ๊ฒ์ด ๋ชฉํ
๊ทธ๋ ๋ค๋ฉด sequence x1:n๊ณผ y1:m์ B๋ผ๋ dataset์ ์กด์ฌํ๋ค.
๋ชฉํ) ์ค์ ํจ์ f: x→y๋ฅผ ๊ทผ์ฌํ๋ ์ ๊ฒฝ๋ง parameter θ๋ฅผ ์ฐพ๋๊ฒ์ด๋ฏ๋ก
์ด์ ํด๋น ํจ์๋ฅผ ๊ทผ์ฌํ๊ธฐ ์ํด parameter θ๋ฅผ ํ์ตํด์ผํ๋ค.
θ๋ ์๋์ ๊ฐ์ด MLE๋ก ์ป์ ์ ์๋ค.
Dataset B์ ๊ด๊ณ๋ฅผ ์ ์ค๋ช ํ๋ θ๋ฅผ ์ป๊ธฐ์ํด, ๋ชฉ์ ํจ์๋ฅผ ์๋์ ๊ฐ์ด ์ ์ํ๋ค.
์๋ ์์์ Cross-Entropy ๋ฅผ ๋ชฉ์ ํจ์๋ก ์ ์ํ ๊ฒ์ด๋ค.
๋ชฉํ) ์์คํจ์์ ๊ฐ์ ์ต์ํ ํ๋ ๊ฒ.
์์์ ์ ์ํ ๋ชฉ์ ํจ์๋ฅผ ์ต์ํ ํด์ผํ๋ฏ๋ก Optimizer๋ฅผ ํตํด ๊ทผ์ฌํ ์ ์๋ค.
(์๋ ์์์ Optimizer๋ก Gradient Descent๋ฅผ ์ฌ์ฉํ์๋ค.)
ํด๋น ์์์์ ํ์ต๋ฅ γ๋ฅผ ํตํด update์ ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๋ค.
๋ค์ ์์์ Policy Gradient์ ๊ธฐ๋ฐํด ๋์ ๊ธฐ๋๋ณด์์ ์ต๋๋กํ๋ ๊ฒฝ์ฌ์์น๋ฒ ์์์ด๋ค.
์ด ์์์์๋ ์ด์ MLE์ ๊ฒฝ์ฌํ๊ฐ๋ฒ ์์์ฒ๋ผ γ์ ์ถ๊ฐ๋ก Q๐θ(st, at)๊ฐ ๊ธฐ์ธ๊ธฐ ์์ ๋ถ์ด ํ์ต๋ฅ ์ญํ ์ ํ๋ ๊ฒ์ ๋ณผ ์ ์๋ค.
๋ฐ๋ผ์ ๋ณด์์ ํฌ๊ธฐ์ ๋ฐ๋ผ ํด๋น ํ๋์ ๋์ฑ ๊ฐํํ๊ฑฐ๋ ๋ฐ๋ ๋ฐฉํฅ์ผ๋ก ๋ถ์ ํ ์ ์๊ฒ ๋๋ ๊ฒ์ด๋ค.
์ฆ, ๊ฒฐ๊ณผ์ ๋ฐ๋ผ ๋์ ์ผ๋ก ํ์ต๋ฅ ์ ์๋ง๊ฒ ์กฐ์ ํด์ค๋ค๊ณ ์ดํดํ ์ ์๋ค.
3.3 Baseline์ ๊ณ ๋ คํ Reinforce ์๊ณ ๋ฆฌ์ฆ
์์์ ์ค๋ช ํ Policy Gradient๋ฅผ ์ํํ ๋, ๋ณด์์ด ์์์ธ ๊ฒฝ์ฐ ์ด๋ป๊ฒ ๋์ํ ๊น?
ex) ์ํ์ด 0~100์ ์ฌ์ด ๋ถํฌํ ๋, ์ ๊ท๋ถํฌ์ ์ํด ๋๋ถ๋ถ ํ๊ท ๊ทผ์ฒ์ ์ ์๊ฐ ๋ถํฌํ ๊ฒ์ด๋ค.
๋ฐ๋ผ์ ๋๋ถ๋ถ์ ํ์๋ค์ ์์ ๋ณด์์ ๋ฐ๋๋ค.
๊ทธ๋ ๊ฒ๋๋ฉด, ์์ ๊ธฐ์กด policy gradient๋ ํญ์ ์์ ๋ณด์์ ๋ฐ์ Agent(ํ์)์๊ฒ ํด๋น policy๋ฅผ ๋์ฑ ๋ ๋ คํ ๊ฒ์ด๋ค.
But! ํ๊ท ์ ์๊ฐ 50์ ์ผ ๋, 10์ ์ ์๋์ ์ผ๋ก ๋งค์ฐ ๋์์ ์์ด๋ฏ๋ก ๊ธฐ์กด ์ ์ฑ ์ ๋ฐ๋๋ฐฉํฅ์ผ๋ก ํ์ตํ๊ฒ๋๋ค.
์ฆ, ์ฃผ์ด์ง ์ํฉ์ ๋ง๋ ํ ๋ณด์์ด ์๊ธฐ์ ์ฐ๋ฆฐ ์ด๋ฅผ ๋ฐํ์ผ๋ก ํ์ฌ policy๊ฐ ์ผ๋ง๋ ํ๋ฅญํ์ง ํ๊ฐ๋ฅผ ํ ์ ์๋ค.
์ด๋ฅผ ์๋ ์์์ฒ๋ผ policy gradient์์์ผ๋ก ํํํ ์ ์๋ค.
์ด์ฒ๋ผ Reinforce ์๊ณ ๋ฆฌ์ฆ์ baseline์ ๊ณ ๋ คํด ์ข ๋ ์์ ์ ๊ฐํํ์ต์ํ์ด ๊ฐ๋ฅํ๋ค.
4. Natural Language Generation์ Reinforcement Learning ์ ์ฉ
4.1 NLG์์ ๊ฐํํ์ต์ ํน์ง
์์์ RL์ Policy based Learning์ธ Policy Gradient๋ฐฉ์์ ๋ํด ๊ฐ๋จํ ์์๋ดค๋ค.
Policy Gradient์ ๊ฒฝ์ฐ, ์์์ ์ค๋ช ํ ๋ด์ฉ ์ด์ธ์๋ ๋ฐ์ ๋ ๋ฐฉ๋ฒ๋ค์ด ๋ง๋ค.
ex) Actor Critic, A3C, ...
โ Actor Critic: ์ ์ฑ ๋ง θ์ด์ธ์๋ ๊ฐ์ค์น๋ง W๋ฅผ ๋ฐ๋ก ๋์ด episode์ข ๋ฃ๊น์ง ๊ธฐ๋ค๋ฆฌ์ง ์๊ณ TDํ์ต๋ฒ์ผ๋ก ํ์ตํ๋ค.
โ A3C(Asunchronous Advantage Actor Critic): Actor Critic์์ ๋์ฑ ๋ฐ์ ๋ฐ ๊ธฐ์กด ๋จ์ ์ ๋ณด์
๋ค๋ง, NLP์ RL์ ์ด๋ฐ ๋ค์ํ ๋ฐฉ๋ฒ๋ค์ ๊ตณ์ด ์ฌ์ฉํ ํ์์์ด ๊ฐ๋จํ Reinforce์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํด๋ ํฐ ๋ฌธ์ ๊ฐ ์๋๋ฐ, ์ด๋ NLP๋ถ์ผ์ ํน์ง ๋๋ถ์ผ๋ก ๊ฐํํ์ต์ ์์ฐ์ด์ฒ๋ฆฌ์ ์ ์ฉ ์, ์๋์ ๊ฐ์ ํน์ง์ด ์กด์ฌํ๋ค.
1. ์ ํ ๊ฐ๋ฅํ ๋งค์ฐ ๋ง์ ํ๋(action) at๊ฐ ์กด์ฌ
โ ๋ณดํต ๋ค์ ๋จ์ด๋ฅผ ์ ํํ๋ ๊ฒ = ํ๋์ ์ ํํ๋ ๊ฒ
โ ์ ํ ๊ฐ๋ฅํ ํ๋์ ์งํฉ์ ํฌ๊ธฐ = ์ดํ ์ฌ์ ์ ํฌ๊ธฐ
∴ ๊ทธ ์งํฉ์ ํฌ๊ธฐ๋ ๋ณดํต ๋ช ๋ง ๊ฐ๊ฐ ๋๊ธฐ ๋ง๋ จ.
2. ๋งค์ฐ ๋ง์ ์ํ (state)๊ฐ ์กด์ฌ
๋จ์ด๋ฅผ ์ ํํ๋ ๊ฒ์ด ํ๋์ด์๋ค๋ฉด,
โ ์ด์ ๊น์ง ์ ํ๋ ๋จ์ด๋ค์ ์ํ์ค = ์ํ
โ ์ฌ๋ฌ time-step์ ๊ฑฐ์ณ ์๋ง์ ํ๋(๋จ์ด)์ด ์ ํ๋์๋ค๋ฉด ๊ฐ๋ฅํ ์ํ์ ๊ฒฝ์ฐ์ ์๋ ๋งค์ฐ ์ปค์ง ๊ฒ.
3. ๋ฐ๋ผ์ ๋งค์ฐ ๋ง์ ํ๋์ ์ ํํ๊ณ , ๋งค์ฐ ๋ง์ ์ํ๋ฅผ ํ๋ จ ๊ณผ์ ์์ ๋ชจ๋ ๊ฒช๋ ๊ฒ์ ๊ฑฐ์ ๋ถ๊ฐ๋ฅํ๋ค๊ณ ๋ณผ ์ ์๋ค.
โ ๊ฒฐ๊ตญ ์ถ๋ก ๊ณผ์ ์์ unseen sample์ ๋ง๋๋ ๊ฒ์ ๋งค์ฐ ๋น์ฐํ ๊ฒ์ด๋ค.
โ ์ด๋ฐ ํฌ์์ฑ ๋ฌธ์ ๋ ํฐ ๊ณจ์นซ๊ฑฐ๋ฆฌ๊ฐ ๋ ์ ์์ง๋ง DNN์ผ๋ก ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐ ํ ์ ์๋ค.
4. ๊ฐํํ์ต์ ์์ฐ์ด ์ฒ๋ฆฌ์ ์ ์ฉํ ๋ ์ฌ์ด ์ ๋ ์๋ค.
โ ๋๋ถ๋ถ ํ๋์ ๋ฌธ์ฅ ์์ฑ = ํ๋์ ์ํผ์๋ (์ด๋, ๋ฌธ์ฅ์ ๊ธธ์ด๋ ๋ณดํต 100 ๋จ์ด ๋ฏธ๋ง)
→ ๋ค๋ฅธ ๋ถ์ผ์ ๊ฐํํ์ต๋ณด๋ค ํจ์ฌ ์ฝ๋ค๋ ์ด์ ์ ๊ฐ๋๋ค.
ex) DeepMind์ AlphaGo, Starcraft์ ๊ฒฝ์ฐ, ํ๋์ ์ํผ์๋๊ฐ ๋๋๊ธฐ๊น์ง ๋งค์ฐ ๊ธด ์๊ฐ์ด ๋ ๋ค.
→ ์ํผ์๋ ๋ด์์ ์ ํ๋ ํ๋๋ค์ด ์ ์ฑ ์ ์ ๋ฐ์ดํธํ๋ ค๋ฉด ๋งค์ฐ ๊ธด ์ํผ์๋๊ฐ ๋๋๊ธฐ๋ฅผ ๊ธฐ๋ค๋ ค์ผ ํ๋ค.
→ ๋ฟ๋ง ์๋๋ผ, 10๋ถ ์ ์ ์ ํํ๋ ํ๋์ด ํด๋น ๊ฒ์์ ์นํจ์ ์ผ๋ง๋ ํฐ ์ํฅ์ ๋ฏธ์ณค๋์ง ์์๋ด๊ธฐ๋ ๋งค์ฐ ์ด๋ ค์ด ์ผ์ด ๋ ๊ฒ์ด๋ค.
โ๏ธ์ด๋ ์์ฐ์ด ์์ฑ ๋ถ์ผ๊ฐ ๋ค๋ฅธ ๋ถ์ผ์ ๋นํด ์ํผ์๋๊ฐ ์งง๋ค๋ ๊ฒ์ ๋งค์ฐ ํฐ ์ด์ ์ผ๋ก ์์ฉํ์ฌ ์ ์ฑ ๋ง์ ํจ์ฌ ๋ ์ฝ๊ฒ ํ๋ จ ์ํฌ ์ ์๋ค.
5. ๋์ , ๋ฌธ์ฅ ๋จ์์ ์ํผ์๋๋ฅผ ๊ฐ์ง๋ ๊ฐํํ์ต์์๋ ๋ณดํต ์ํผ์๋ ์ค๊ฐ์ ๋ณด์์ ์ป๊ธฐ ์ด๋ ต๋ค.
ex) ๋ฒ์ญ์ ๊ฒฝ์ฐ, ๊ฐ time-step๋ง๋ค ๋จ์ด๋ฅผ ์ ํํ ๋ ์ฆ๊ฐ์ ์ธ ๋ณด์์ ์ป์ง ๋ชปํ๊ณ , ๋ฒ์ญ์ด ๋ชจ๋ ๋๋ ์ดํ ์์ฑ๋ ๋ฌธ์ฅ๊ณผ ์ ๋ต ๋ฌธ์ฅ์ ๋น๊ตํ์ฌ BLEU ์ ์๋ฅผ ๋์ ๋ณด์์ผ๋ก ์ฌ์ฉํ๋ค.
๋ง์ฐฌ๊ฐ์ง๋ก ์ํผ์๋๊ฐ ๋งค์ฐ ๊ธธ๋ค๋ฉด ์ด๊ฒ์ ๋งค์ฐ ํฐ ๋ฌธ์ ๊ฐ ๋์๊ฒ ์ง๋ง, ๋คํํ๋ ๋ฌธ์ฅ ๋จ์์ ์ํผ์๋์์๋ ํฐ ๋ฌธ์ ๊ฐ ๋์ง ์์ต๋๋ค
4.2 RL ์ ์ฉ์ ์ฅ์
Teacher Forcing์ ์ด์ฉํ ๋ฌธ์ ํด๊ฒฐ
seq2seq๊ฐ์ AR์์ฑ์ ๋ชจ๋ธํ๋ จ ์, teacher forcing ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ค.
์ด ๋ฐฉ๋ฒ์ train๊ณผ inference๋ฐฉ์์ ์ฐจ์ด๊ฐ ๋ฐ์, ์ค์ ์ถ๋ก ๋ฐฉ์๊ณผ ๋ค๋ฅด๊ฒ ๋ฌธ์ ๋ฅผ ํ๋ จํด์ผํ๋ค.
โ๏ธํ์ง๋ง, RL์ ํตํด ์ค์ ์ถ๋ก ํํ์ ๊ฐ์ด sampling์ผ๋ก ๋ชจ๋ธ์ ํ์ต
→ "train๊ณผ inference"์ ์ฐจ์ด๊ฐ ์์ด์ก๋ค.
๋ ์ ํํ ๋ชฉ์ ํจ์์ ์ฌ์ฉ
BLEU๋ PPL์ ๋นํด ๋ ๋์ ๋ฒ์ญํ์ง์ ๋ฐ์ํ๋ค. [Gain_NLP_07]
๋ค๋ง, ์ด๋ฐ metric๋ค์ ๋ฏธ๋ถ์ ํตํด ํํํ๊ธฐ ๋ถ๊ฐ๋ฅํ ๊ฒฝ์ฐ๊ฐ ๋๋ถ๋ถ์ด์ด์ ์ ๊ฒฝ๋ง ํ๋ จ์ ์ฌ์ฉ์ด ์ด๋ ต๋ค.
โ๏ธํ์ง๋ง, RL์ Policy Gradient๋ฅผ ์์ฉํด ๋ณด์ํจ์์ ๋ํด ๋ฏธ๋ถ์ ๊ณ์ฐํ ํ์๊ฐ ์์ด์ง๋ฉด์ ์ ํํ Metric ์ฌ์ฉ์ด ๊ฐ๋ฅํ๋ค.
5. RL์ ํ์ฉํ Supervised Learning
BLEU๋ฅผ ํ๋ จ๊ณผ์ ์ ๋ชฉ์ ํจ์๋ก ์ฌ์ฉํ๋ค๋ฉด ๋ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์ ์์ํ ๋ฐ...
๋ง์ฐฌ๊ฐ์ง๋ก ๋ค๋ฅธ NLG๋ฌธ์ ์ ๋ํด์๋ ๋น์ทํ ์ ๊ทผ์ ํ ์ ์์ผ๋ฉด ์ข์ํ ๋ฐ...
5.1 MRT (Minimum Risk Training)
์์ ๋ฐ๋์์ ์ถ๋ฐํ์ฌ ์ํ์ต์ํ ํ๋ จ[MRT; Minimum Risk Training]๋ ผ๋ฌธ์ด ์ ์๋์๋ค.
๋น์ ์ ์๋ Policy Gradient๋ฅผ ์ง์ ์ ์ผ๋ก ์ฌ์ฉํ์ง ์์์ผ๋ ์ ์ฌํ ์์์ด ์ ๋๋์๋ค๋ ์ ์์ ๋งค์ฐ ์ธ์์ ์ด๋ค.
๊ธฐ์กด์ ์ต๋๊ฐ๋ฅ๋์ถ์ (MLE)๋ฐฉ์์ ์์๊ฐ์ ์์คํจ์๋ฅผ ์ฌ์ฉํด |S|๊ฐ์ ์ ์ถ๋ ฅ์ ๋ํ ์์ค๊ฐ์ ๊ตฌํ๊ณ , ์ด๋ฅผ ์ต์ํํ๋ θ๋ฅผ ์ฐพ๋ ๊ฒ์ด ๋ชฉํ์๋ค.
ํ์ง๋ง, ์ด ๋ ผ๋ฌธ์์๋ risk๋ฅผ ์๋์ ๊ฐ์ด ์ ์ํ๊ณ , ์ด๋ฅผ ์ต์ํํ๋ ํ์ต๋ฐฉ์์ธ MRT๋ฅผ ์ ์ํ๋ค.
์์ ์์์์ y(x(s))๋ search scape(ํ์๊ณต๊ฐ)์ ์ ์ฒด ์งํฉ์ด๋ค.
์ด๋ S๋ฒ์งธ x(s)๊ฐ ์ฃผ์ด์ก์ ๋ ๊ฐ๋ฅํ ์ ๋ต์งํฉ์ ์๋ฏธํ๋ค.
๋ํ, Δ(y, y(s))๋ ์ ๋ ฅ ํ๋ผ๋ฏธํฐ θ๊ฐ ์ฃผ์ด์ก์ ๋, samplingํ y์ ์ค์ ์ ๋ต y(s)์ ์ฐจ์ด(= error)๊ฐ์ ์๋ฏธํ๋ค.
์ฆ, ์ด ์์์ ๋ฐ๋ฅด๋ฉด risk R์ ์ฃผ์ด์ง ์ ๋ ฅ๊ณผ ํ์ฌ ํ๋ผ๋ฏธํฐ ์์์ ์ป์ y๋ฅผ ํตํด ํ์ฌ ๋ชจ๋ธํจ์๋ฅผ ๊ตฌํ๊ณ , ๋์์ ์ด๋ฅผ ์ฌ์ฉํด Risk์ ๊ธฐ๋๊ฐ์ ๊ตฌํ๋ค ๋ณผ ์ ์๋ค.
์ด๋ ๊ฒ ์ ์๋ Risk๋ฅผ ์ต์ํํ๋ ๊ฒ์ด ๋ชฉํ์ด๋ค.
๋ฐ๋๋ก risk๋์ ๋ณด์์ผ๋ก ์๊ฐํ๋ฉด, ๋ณด์์ ์ต๋ํํ๋ ๊ฒ์ด ๋ชฉํ์ด๋ค.
๊ฒฐ๊ตญ, risk๋ฅผ ์ต์ํํ ๋ ๊ฒฝ์ฌํ๊ฐ๋ฒ, ๋ณด์์ ์ต๋ํํ ๋๋ ๊ฒฝ์ฌ์์น๋ฒ์ ์ฌ์ฉํ๊ธฐ์
์์์ ๋ถํดํ๋ฉด ๊ฒฐ๊ตญ ์จ์ ํ ๋์ผํ ๋ด์ฉ์์ ์ ์ ์๋ค.
๋ฐ๋ผ์ ์ค์ ๊ตฌํ ์, Δ(y, y(s))์ฌ์ฉ์ ์ํ ๋ณด์ํจ์ BLEU์ -1์ ๊ณฑํด Riskํจ์๋ก ๋ง๋ ๋ค.
๋ค๋ง ์ฃผ์ด์ง ์ ๋ ฅ์ ๋ํด ๊ฐ๋ฅํ ์ ๋ต์ ๊ดํ ์ ์ฒด๊ณต๊ฐ์ ํ์ํ ์ ์๊ธฐ์
์ ์ฒดํ์๊ณต๊ฐ์ samplingํ sub-space์์ samplingํ๋ ๊ฒ์ ํํ๋ค.
๊ทธ ํ ์์ ์์์์ θ์ ๋ํด ๋ฏธ๋ถ์ ์ํํ๋ค.
์ด์ , ๋ฏธ๋ถ์ ํตํด ์ป์ MRT์ ์ต์ข ์์์ ํด์ํด๋ณด์.
์ต์ข ์ ์ผ๋ก๋ ์์์์ ๊ธฐ๋๊ฐ๋ถ๋ถ์ Monte Carlo Sampling์ ํตํด ์ ๊ฑฐํ ์ ์๋ค.
์๋์์์ Policy Gradient์ Reinforce์๊ณ ๋ฆฌ์ฆ ์์์ผ๋ก ์์ MRT์์๊ณผ ๋น๊ตํ์ฌ ์ฐธ๊ณ ํด๋ณด์.
[MRT์ Reinforce์๊ณ ๋ฆฌ์ฆ ์์]
MRT๋ ๊ฐํํ์ต์ผ๋ก์จ์ ์ ๊ทผ์ ์ ํํ์ง ์๊ณ , ์์์ ์ผ๋ก Policy Gradient์ Reinforce์๊ณ ๋ฆฌ์ฆ ์์์ ๋์ถํด๋ด์ด ์ฑ๋ฅ์ ๋์ด์ฌ๋ฆฐ๋ค๋ ์ ์์ ๋งค์ฐ ์ธ์๊น์ ๋ฐฉ๋ฒ์์ ์ ์ ์๋ค.
[Policy Gradient์ Reinforce์๊ณ ๋ฆฌ์ฆ ์์]
Pytorch NMT_with_MRT ์์ ์ฝ๋
๊ตฌํ๊ณผ์
1. ์ฃผ์ด์ง ์
๋ ฅ๋ฌธ์ฅ์๋ํด ์ ์ฑ
๐๋ฅผ ์ด์ฉํด ๋ฒ์ญ๋ฌธ์ฅ sampling
2. sampling๋ฌธ์ฅ๊ณผ ์ ๋ต๋ฌธ์ฅ์ฌ์ด BLEU๋ฅผ ๊ณ์ฐ, -1์ ๊ณฑํด Risk๋ก ๋ณํ
3. logํ๋ฅ ๋ถํฌ์ ์ฒด์ Risk๋ฅผ ๊ณฑํจ
4. ๊ฐ sample๊ณผ time-step๋ณ ๊ตฌํด์ง NLL๊ฐ์ ํฉ์ -1์ ๊ณฑํด์ค(=PLL)
5. ๋ก๊ทธํ๋ฅ ๊ฐ์ ํฉ์ ๐๋ก ๋ฏธ๋ถ์ ์ํ, BP๋ก ์ ๊ฒฝ๋ง ๐์ ์ฒด ๊ธฐ์ธ๊ธฐ๊ฐ ๊ตฌํด์ง
6. ์ด๋ฏธ Risk๋ฅผ ํ๋ฅ ๋ถํฌ์ ๊ณฑํ๊ธฐ์, ๋ฐ๋ก ์ด ๊ธฐ์ธ๊ธฐ๋ก BP๋ฅผ ์ํ, ์ต์ ํ
from nltk.translate.gleu_score import sentence_gleu
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
import numpy as np
import torch
from torch import optim
from torch.nn import functional as F
import torch.nn.utils as torch_utils
from ignite.engine import Engine
from ignite.engine import Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar
import simple_nmt.data_loader as data_loader
from simple_nmt.trainer import MaximumLikelihoodEstimationEngine
from simple_nmt.utils import get_grad_norm, get_parameter_norm
VERBOSE_SILENT = 0
VERBOSE_EPOCH_WISE = 1
VERBOSE_BATCH_WISE = 2
class MinimumRiskTrainingEngine(MaximumLikelihoodEstimationEngine):
@staticmethod
def _get_reward(y_hat, y, n_gram=6, method='gleu'):
# This method gets the reward based on the sampling result and reference sentence.
# For now, we uses GLEU in NLTK, but you can used your own well-defined reward function.
# In addition, GLEU is variation of BLEU, and it is more fit to reinforcement learning.
sf = SmoothingFunction()
score_func = {
'gleu': lambda ref, hyp: sentence_gleu([ref], hyp, max_len=n_gram),
'bleu1': lambda ref, hyp: sentence_bleu([ref], hyp,
weights=[1./n_gram] * n_gram,
smoothing_function=sf.method1),
'bleu2': lambda ref, hyp: sentence_bleu([ref], hyp,
weights=[1./n_gram] * n_gram,
smoothing_function=sf.method2),
'bleu4': lambda ref, hyp: sentence_bleu([ref], hyp,
weights=[1./n_gram] * n_gram,
smoothing_function=sf.method4),
}[method]
# Since we don't calculate reward score exactly as same as multi-bleu.perl,
# (especialy we do have different tokenization,) I recommend to set n_gram to 6.
# |y| = (batch_size, length1)
# |y_hat| = (batch_size, length2)
with torch.no_grad():
scores = []
for b in range(y.size(0)):
ref, hyp = [], []
for t in range(y.size(-1)):
ref += [str(int(y[b, t]))]
if y[b, t] == data_loader.EOS:
break
for t in range(y_hat.size(-1)):
hyp += [str(int(y_hat[b, t]))]
if y_hat[b, t] == data_loader.EOS:
break
# Below lines are slower than naive for loops in above.
# ref = y[b].masked_select(y[b] != data_loader.PAD).tolist()
# hyp = y_hat[b].masked_select(y_hat[b] != data_loader.PAD).tolist()
scores += [score_func(ref, hyp) * 100.]
scores = torch.FloatTensor(scores).to(y.device)
# |scores| = (batch_size)
return scores
@staticmethod
def _get_loss(y_hat, indice, reward=1):
# |indice| = (batch_size, length)
# |y_hat| = (batch_size, length, output_size)
# |reward| = (batch_size,)
batch_size = indice.size(0)
output_size = y_hat.size(-1)
'''
# Memory inefficient but more readable version
mask = indice == data_loader.PAD
# |mask| = (batch_size, length)
indice = F.one_hot(indice, num_classes=output_size).float()
# |indice| = (batch_size, length, output_size)
log_prob = (y_hat * indice).sum(dim=-1)
# |log_prob| = (batch_size, length)
log_prob.masked_fill_(mask, 0)
log_prob = log_prob.sum(dim=-1)
# |log_prob| = (batch_size, )
'''
# Memory efficient version
log_prob = -F.nll_loss(
y_hat.view(-1, output_size),
indice.view(-1),
ignore_index=data_loader.PAD,
reduction='none'
).view(batch_size, -1).sum(dim=-1)
loss = (log_prob * -reward).sum()
# Following two equations are eventually same.
# \theta = \theta - risk * \nabla_\theta \log{P}
# \theta = \theta - -reward * \nabla_\theta \log{P}
# where risk = -reward.
return loss
@staticmethod
def train(engine, mini_batch):
# You have to reset the gradients of all model parameters
# before to take another step in gradient descent.
engine.model.train()
if engine.state.iteration % engine.config.iteration_per_update == 1 or \
engine.config.iteration_per_update == 1:
if engine.state.iteration > 1:
engine.optimizer.zero_grad()
device = next(engine.model.parameters()).device
mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])
# Raw target variable has both BOS and EOS token.
# The output of sequence-to-sequence does not have BOS token.
# Thus, remove BOS token for reference.
x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
# |x| = (batch_size, length)
# |y| = (batch_size, length)
# Take sampling process because set False for is_greedy.
y_hat, indice = engine.model.search(
x,
is_greedy=False,
max_length=engine.config.max_length
)
with torch.no_grad():
# Based on the result of sampling, get reward.
actor_reward = MinimumRiskTrainingEngine._get_reward(
indice,
y,
n_gram=engine.config.rl_n_gram,
method=engine.config.rl_reward,
)
# |y_hat| = (batch_size, length, output_size)
# |indice| = (batch_size, length)
# |actor_reward| = (batch_size)
# Take samples as many as n_samples, and get average rewards for them.
# I figured out that n_samples = 1 would be enough.
baseline = []
for _ in range(engine.config.rl_n_samples):
_, sampled_indice = engine.model.search(
x,
is_greedy=False,
max_length=engine.config.max_length,
)
baseline += [
MinimumRiskTrainingEngine._get_reward(
sampled_indice,
y,
n_gram=engine.config.rl_n_gram,
method=engine.config.rl_reward,
)
]
baseline = torch.stack(baseline).mean(dim=0)
# |baseline| = (n_samples, batch_size) --> (batch_size)
# Now, we have relatively expected cumulative reward.
# Which score can be drawn from actor_reward subtracted by baseline.
reward = actor_reward - baseline
# |reward| = (batch_size)
# calculate gradients with back-propagation
loss = MinimumRiskTrainingEngine._get_loss(
y_hat,
indice,
reward=reward
)
backward_target = loss.div(y.size(0)).div(engine.config.iteration_per_update)
backward_target.backward()
p_norm = float(get_parameter_norm(engine.model.parameters()))
g_norm = float(get_grad_norm(engine.model.parameters()))
if engine.state.iteration % engine.config.iteration_per_update == 0 and \
engine.state.iteration > 0:
# In orther to avoid gradient exploding, we apply gradient clipping.
torch_utils.clip_grad_norm_(
engine.model.parameters(),
engine.config.max_grad_norm,
)
# Take a step of gradient descent.
engine.optimizer.step()
return {
'actor': float(actor_reward.mean()),
'baseline': float(baseline.mean()),
'reward': float(reward.mean()),
'|param|': p_norm if not np.isnan(p_norm) and not np.isinf(p_norm) else 0.,
'|g_param|': g_norm if not np.isnan(g_norm) and not np.isinf(g_norm) else 0.,
}
@staticmethod
def validate(engine, mini_batch):
engine.model.eval()
with torch.no_grad():
device = next(engine.model.parameters()).device
mini_batch.src = (mini_batch.src[0].to(device), mini_batch.src[1])
mini_batch.tgt = (mini_batch.tgt[0].to(device), mini_batch.tgt[1])
x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
# |x| = (batch_size, length)
# |y| = (batch_size, length)
# feed-forward
y_hat, indice = engine.model.search(
x,
is_greedy=True,
max_length=engine.config.max_length,
)
# |y_hat| = (batch_size, length, output_size)
# |indice| = (batch_size, length)
reward = MinimumRiskTrainingEngine._get_reward(
indice,
y,
n_gram=engine.config.rl_n_gram,
method=engine.config.rl_reward,
)
return {
'BLEU': float(reward.mean()),
}
@staticmethod
def attach(
train_engine,
validation_engine,
training_metric_names = ['actor', 'baseline', 'reward', '|param|', '|g_param|'],
validation_metric_names = ['BLEU', ],
verbose=VERBOSE_BATCH_WISE
):
# Attaching would be repaeted for serveral metrics.
# Thus, we can reduce the repeated codes by using this function.
def attach_running_average(engine, metric_name):
RunningAverage(output_transform=lambda x: x[metric_name]).attach(
engine,
metric_name,
)
for metric_name in training_metric_names:
attach_running_average(train_engine, metric_name)
if verbose >= VERBOSE_BATCH_WISE:
pbar = ProgressBar(bar_format=None, ncols=120)
pbar.attach(train_engine, training_metric_names)
if verbose >= VERBOSE_EPOCH_WISE:
@train_engine.on(Events.EPOCH_COMPLETED)
def print_train_logs(engine):
avg_p_norm = engine.state.metrics['|param|']
avg_g_norm = engine.state.metrics['|g_param|']
avg_reward = engine.state.metrics['actor']
print('Epoch {} - |param|={:.2e} |g_param|={:.2e} BLEU={:.2f}'.format(
engine.state.epoch,
avg_p_norm,
avg_g_norm,
avg_reward,
))
for metric_name in validation_metric_names:
attach_running_average(validation_engine, metric_name)
if verbose >= VERBOSE_BATCH_WISE:
pbar = ProgressBar(bar_format=None, ncols=120)
pbar.attach(validation_engine, validation_metric_names)
if verbose >= VERBOSE_EPOCH_WISE:
@validation_engine.on(Events.EPOCH_COMPLETED)
def print_valid_logs(engine):
avg_bleu = engine.state.metrics['BLEU']
print('Validation - BLEU={:.2f} best_BLEU={:.2f}'.format(
avg_bleu,
-engine.best_loss,
))
@staticmethod
def resume_training(engine, resume_epoch):
resume_epoch = max(1, resume_epoch - engine.config.n_epochs)
engine.state.iteration = (resume_epoch - 1) * len(engine.state.dataloader)
engine.state.epoch = (resume_epoch - 1)
@staticmethod
def check_best(engine):
loss = -float(engine.state.metrics['BLEU'])
if loss <= engine.best_loss:
engine.best_loss = loss
@staticmethod
def save_model(engine, train_engine, config, src_vocab, tgt_vocab):
avg_train_bleu = train_engine.state.metrics['actor']
avg_valid_bleu = engine.state.metrics['BLEU']
# Set a filename for model of last epoch.
# We need to put every information to filename, as much as possible.
model_fn = config.model_fn.split('.')
model_fn = model_fn[:-1] + ['mrt',
'%02d' % train_engine.state.epoch,
'%.2f-%.2f' % (avg_train_bleu,
avg_valid_bleu),
] + [model_fn[-1]]
model_fn = '.'.join(model_fn)
# Unlike other tasks, we need to save current model, not best model.
torch.save(
{
'model': engine.model.state_dict(),
'opt': train_engine.optimizer.state_dict(),
'config': config,
'src_vocab': src_vocab,
'tgt_vocab': tgt_vocab,
}, model_fn
)
6. RL์ ํ์ฉํ Unsupervised Learning
์ง๋ํ์ต๋ฐฉ์์ ๋์ ์ ํ๋๋ฅผ ์๋ํ๋ค. ๋ค๋ง, labeled data๊ฐ ํ์ํด dataํ๋ณด๋ cost๊ฐ ๋๋ค.
๋น์ง๋ํ์ต๋ฐฉ์์ dataํ๋ณด์ ๋ํ cost๊ฐ ๋ฎ๊ธฐ์ ๋ ์ข์ ๋์์ด ๋ ์ ์๋ค.
(๋ฌผ๋ก , ์ง๋ํ์ต์ ๋นํด ์ฑ๋ฅ์ด๋ ํจ์จ์ด ๋จ์ด์ง ๊ฐ๋ฅ์ฑ์ ๋์)
์ด๋ฐ ์ ์์ parallel corpus์ ๋นํด monolinugal corpus๋ฅผ ํ๋ณดํ๊ธฐ ์ฝ๋ค๋ NLP์ ํน์ฑ์, ์ข์ ๋์์ด ๋ ์ ์๋ค.
์๋์ parallel corpus์ ๋ค๋์ monolingual corpus๋ฅผ ๊ฒฐํฉ → ๋ ๋์ ์ฑ๋ฅ์ ํ๋ณดํ ์ ์์ ๊ฒ์ด๋ค.
6.1 Unsupervised๋ฅผ ํตํ NMT
์ด๋ฒ์ ์๊ฐํ ๋ ผ๋ฌธ์ ์ค์ง monolingual corpus๋ง์ ์ฌ์ฉํด ๋ฒ์ญ๊ธฐ๋ฅผ ์ ์ํ๋ ๋ฐฉ๋ฒ์ ์ ์ํ๋ค. [Guillaume Lampl;2018]
๋ฐ๋ผ์ ์ง์ ํ ๋น์ง๋ํ์ต์ ํตํ NMT๋ผ ๋ณผ ์ ์๋ค.
[ํต์ฌ idea]
โ ์ธ์ด์ ์๊ด์์ด ๊ฐ์ ์๋ฏธ์ ๋ฌธ์ฅ์ผ ๊ฒฝ์ฐ, Encoder๊ฐ ๊ฐ์ ๊ฐ์ผ๋ก embeddingํ ์ ์๋๋ก ํ๋ จํ๋ ๊ฒ.
์ด๋ฅผ ์ํด GAN์ด ๋์ ๋์๋ค!! →โ์ด? ๋ถ๋ช GAN์ NLP์์ ๋ชป์ด๋ค๊ณ ์์ ์๊ธฐํ๋ ๊ฒ ๊ฐ์๋ฐ...??
โ๏ธencoder์ ์ถ๋ ฅ๊ฐ์ด ์ฐ์์ ๊ฐ์ด๊ธฐ์ GAN์ ์ ์ฉํ ์ ์์๋ค.
[Encoder]
์ธ์ด์ ์๊ด์์ด ๋์ผํ ๋ด์ฉ์ ๋ฌธ์ฅ์ ๋ํด ๊ฐ์ ๊ฐ์ ๋ฒกํฐ๋ก encodingํ๋๋ก ํ๋ จ
โ ์ด๋ฅผ ์ํด ํ๋ณ์ D(encoding๋ ๋ฌธ์ฅ์ ์ธ์ด๋ฅผ ๋ง์ถ๋ ์ญํ )๊ฐ ํ์ํ๊ณ
โ D๋ฅผ ์์ด๋๋ก Encoder๋ ํ์ต๋๋ค.
[Decoder]
encoder์ ์ถ๋ ฅ๊ฐ์ ๊ฐ๊ณ Decoder๋ฅผ ํตํด ๊ธฐ์กด ๋ฌธ์ฅ์ผ๋ก ์ ๋์์ค๋๋ก ํจ
์ฆ, Encoder์ Decoder๋ฅผ ์ธ์ด์ ๋ฐ๋ผ ๋ค๋ฅด๊ฒ ์ฌ์ฉํ์ง ์๊ณ ์ธ์ด์ ์๊ด์์ด 1๊ฐ์ฉ์ Encoder, Decoder๋ฅผ ์ฌ์ฉํ๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก ์์คํจ์๋ ์๋ 3๊ฐ์ง ๋ถ๋ถ์ผ๋ก ๊ตฌ์ฑ๋๋ค.
์์คํจ์์ 3๊ฐ์ง ๊ตฌ์ฑ
De-noising Auto-Encoder
seq2seq๋ ์ผ์ข ์ Auto-Encoder์ ์ผ์ข ์ด๋ผ ๋ณผ ์ ์๋ค.
AE๋ ๊ต์ฅํ ์ฌ์ด ๋ฌธ์ ์ ์ํ๋ค.
๋ฐ๋ผ์ AE์์ ๋จ์ํ ๋ณต์ฌ์์ ์ ์ง์ํ๋๋์ ,
noise๋ฅผ ์์ด์ค src๋ฌธ์ฅ์์ De-noising์ ํ๋ฉด์ ์ ๋ ฅ๊ฐ์ ์ถ๋ ฅ์์ ๋ณต์(reconstruction)ํ๋๋ก ํ๋ จํด์ผํ๋๋ฐ, ์ด๋ฅผ "De-noising AutoEncoder"๋ผ ๋ถ๋ฅด๋ฉฐ ์๋์ ๊ฐ์ด ํํ๋๋ค.
์ด ์์์์ x_hat์ ์ ๋ ฅ๋ฌธ์ฅ x๋ฅผ noise_model C๋ฅผ ํตํด noise๋ฅผ ๋ํ๊ณ ๊ฐ์ ์ธ์ด โ๋ก encoding๊ณผ decoding์ ์ํํ๋ ๊ฒ์ ์๋ฏธํ๋ค.
Δ(x_hat, x)๋ MRT์์์ ๊ฐ์ด ์๋ฌธ๊ณผ ๋ณต์๋ ๋ฌธ์ฅ๊ณผ์ ์ฐจ์ด๋ฅผ ์๋ฏธํ๋ค.
noise_model C(x)๋ ์์๋ก ๋ฌธ์ฅ ๋ด ๋จ์ด๋ค์ ๋๋กญํ๊ฑฐ๋ ์์๋ฅผ ์์ด์ฃผ๋ ์ผ์ ํ๋ค.
Cross Domain training
โCross Domainํ๋ จ์ด๋
์ฌ์ ๋ฒ์ญ์ ํตํด ์ฌ์ ํ๋ จํ ์ ์ฑ๋ฅ์ ๋ฒ์ญ๋ชจ๋ธ M์์ ์ธ์ด โ2์ noise๊ฐ ์ถ๊ฐ๋์ด
๋ฒ์ญ๋ ๋ฌธ์ฅ y๋ฅผ ๋ค์ ์ธ์ด โ1 src ๋ฌธ์ฅ์ผ๋ก ์์๋ณต๊ตฌํ๋ ์์ ์ ํ์ตํ๋ ๊ฒ์ด๋ค.
Adversarial Learning
Encoder๊ฐ ์ธ์ด์ ์๊ด์์ด ํญ์ ๊ฐ์ ๋ถํฌ๋ก latent space์ ๋ฌธ์ฅ๋ฒกํฐ๋ฅผ embeddingํ๋์ง ๊ฐ์ํ๋ ํ๋ณ์ D๊ฐ ์ถ๊ฐ๋์ด ์ ๋์ ํ์ต์ ์งํํ๋ค.
โ D๋ latent variable z์ ๊ธฐ์กด ์ธ์ด๋ฅผ ์์ธกํ๋๋ก ํ๋ จ๋๋ค.
โ xi , โi๋ ๊ฐ์ ์ธ์ด(language pair)๋ฅผ ์๋ฏธํ๋ค.
๋ฐ๋ผ์ GAN์ฒ๋ผ Encoder๋ ํ๋ณ์ D๋ฅผ ์์ผ ์ ์๋๋ก ํ๋ จ๋์ด์ผ ํ๋ค.
์ด๋, j = - (i - 1) ๊ฐ์ ๊ฐ๋๋ค.
์ต์ข ๋ชฉ์ ํจ์
์์ 3๊ฐ์ง ๋ชฉ์ ํจ์๋ฅผ ๊ฒฐํฉํ๋ฉด ์ต์ข ๋ชฉ์ ํจ์๋ฅผ ์ป์ ์ ์๋ค.
๊ฐ λ๋ฅผ ํตํด ์์คํจ์ ์์์ ๋น์จ์ ์กฐ์ , ์ต์ ์ parameter θ๋ฅผ ์ฐพ๋๋ค.
๋ ผ๋ฌธ์์๋ ์ค์ง monolingual corpus๋ง ์กด์ฌํ ๋, NMT๋ฅผ ๋ง๋๋ ๋ฐฉ์์ ๋ํด ๋ค๋ฃฌ๋ค.
parallel corpus๊ฐ ์๋ ์ํฉ์์๋ NMT๋ฅผ ๋ง๋ค ์ ์๋ฐ๋ ์ ์์ ๋งค์ฐ ๊ณ ๋ฌด์ ์ด๋ค.
๋ค๋ง, ์ด ๋ฐฉ๋ฒ ์์ฒด๋ง์ผ๋ก ์ค์ ํ๋์์ ํ์ฉํ๊ธฐ ์ด๋ ค์ด๋ฐ, ์ค๋ฌด์์๋ ๋ฒ์ญ๊ธฐ๊ตฌ์ถ ์ parallel corpus๊ฐ ์๋ ๊ฒฝ์ฐ๋ ๋๋ฌผ๊ณ , ์๋ค ํ๋๋ผ๋ monolingual corpus๋ง์ผ๋ก ๋ฒ์ญ๊ธฐ๋ฅผ ๊ตฌ์ถํด ๋ฎ์ ์ฑ๋ฅ์ ๋ฒ์ญ๊ธฐ๋ฅผ ํ๋ณดํ๊ธฐ ๋ณด๋จ, ๋น์ฉ์ ๋ค์ฌ parallel corpus๋ฅผ ์ง์ ๊ตฌ์ถํ๊ณ , parallel corpus์ ๋ค์์ monolingual corpus๋ฅผ ํฉ์ณ NMT๋ฅผ ๊ตฌ์ถํ๋ ๋ฐฉํฅ์ผ๋ก ์งํํ๋ ๊ฒ์ด ๋ซ๊ธฐ ๋๋ฌธ์ด๋ค.
๋ง์น๋ฉฐ...
์ด๋ฒ์๊ฐ์๋ Reinforcement Learning์ ๋ํด ์์๋ณด๊ณ , ์ด๋ฅผ ์ด์ฉํด ์์ฐ์ด ์์ฑ๋ฌธ์ (NLG)๋ฅผ ํด๊ฒฐํ๋ ๋ฐฉ๋ฒ์ ๋ค๋ฃจ์๋ค.
๋ค์ํ RL ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํด NLG๋ฌธ์ ์ ์ฑ๋ฅ์ ๋์ผ ์ ์๋๋ฐ, ํนํ Policy Gradient๋ฅผ ์ฌ์ฉํด NLG์ ์ ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํ๋ค.
Policy Gradient๋ฐฉ๋ฒ์ NLG์ ์ ์ฉํด ์ป๋ ์ด์ ์ ํฌ๊ฒ 2๊ฐ์ง์ธ๋ค.
โ teacher-forcing(AR์์ฑ์ผ๋ก ์ธํด ์ค์ ์ถ๋ก ๋ฐฉ์๊ณผ ๋ค๋ฅด๊ฒ ํ๋ จ)๋ฐฉ๋ฒ์์ ํํผํด
์ค์ ์ถ๋ก ๋ฐฉ์๊ณผ ๊ฐ์ sampling์ ํตํด ๋ฌธ์ฅ์์ฑ๋ฅ๋ ฅ ํฅ์
โก ๋ ์ ํํ ๋ชฉ์ ํจ์๋ฅผ ํ๋ จ์ด ๊ฐ๋ฅํ๋ค.
- ๊ธฐ์กด PPL: ๋ฒ์ญํ์ง, NLGํ์ง์ ์ ํํ ๋ฐ์X
- ๋ฐ๋ผ์ BLEU ๋ฑ์ metric์ ์ฌ์ฉ
- ํ์ง๋ง BLEU ๋ฑ์ metric์ ๋ฏธ๋ถ์ ํ ์ ์์.
∴ PPL๊ณผ ๋์ผํ Cross-Entropy๋ฅผ ํ์ฉํด ์ ๊ฒฝ๋ง์ ํ๋ จํด์ผ๋ง ํ๋ค.
๋ค๋ง, Policy Gradient ๋ํ ๋จ์ ์ด ์กด์ฌํ๋ค.
โ sampling๊ธฐ๋ฐ ํ๋ จ์ด๊ธฐ์ ๋ง์ iteration์ด ํ์.
๋ฐ๋ผ์ Cost๊ฐ ๋์ ๋ ๋นํจ์จ์ ํ์ต์ด ์งํ๋๋ค.
โก ๋ณด์ํจ์๋ ๋ฐฉํฅ์ด ์๋ ์ค์นผ๋ผ๊ฐ์ ๋ฐํํ๋ค.
๋ฐ๋ผ์ ๋ณด์ํจ์๋ฅผ ์ต๋ํํ๋ ๋ฐฉํฅ์ ์ ํํ ์ ์ ์๋ค.
์ด๋ ๊ธฐ์กด MLE๋ฐฉ์์์ ์์คํจ์๋ฅผ ์ ๊ฒฝ๋ง ํ๋ผ๋ฏธํฐ θ์ ๋ํด
๋ฏธ๋ถํด ์ป์ ๊ธฐ์ธ๊ธฐ๋ก ์์คํจ์์์ฒด๋ฅผ ์ต์ํํ๋ ๋ฐฉํฅ์ผ๋ก updateํ๋ ๊ฒ๊ณผ ์ฐจ์ด๊ฐ ์กด์ฌํ๋ค.
๊ฒฐ๊ตญ, ์ด ๋ํ ๊ธฐ์กด MLE๋ฐฉ์๋ณด๋ค ํจ์ฌ ๋นํจ์จ์ ํ์ต์ผ๋ก ์ด์ด์ง๊ฒ ๋๋ ๊ฒ์ด๋ค.