๐ ๋ชฉ์ฐจ
1. DCGAN (Deep Convolutional GAN)
2. WGAN-GP (Wasserstein GAN-Gradient Penalty)
3. CGAN (Conditional GAN)
4. ์์ฝ
๐ง preview:
GAN์ ์์ฑ์์ ํ๋ณ์๋ผ๋ ๋ ๋ชจ๋๊ฐ์ ์ธ์์ด๋ค.
์์ฑ์: random noise๋ฅผ ๊ธฐ์กด dataset์์ samplingํ ๊ฒ์ฒ๋ผ ๋ณด์ด๋๋ก ๋ณํ
ํ๋ณ์: sample์ด ๊ธฐ์กด dataset์์์ธ์ง, ์์ฑ์์์๋์๋์ง ์์ธก.
1. DCGAN (Deep Convolutinal GAN)
DCGAN
2015๋ ์ ๋์จ ๋ ผ๋ฌธ ์ฐธ๊ณ .
๐ธ Generator
๋ชฉํ: ํ๋ณ์๊ฐ ํ๋ณ ๋ถ๊ฐ๋ฅํ img์์ฑ
input: ๋ค๋ณ๋ํ์ค์ ๊ท๋ถํฌ์์ ๋ฝ์ ๋ฒกํฐ
output: ์๋ณธ train data์ ์๋ img์ ๋์ผํ ํฌ๊ธฐ์ img
์ ์ค๋ช ์ด ๋ง์น VAE๊ฐ๋ค๋ฉด?
์ค์ ๋ก VAE์ Decoder์ ๋์ผํ ๋ชฉ์ ์ ์ํํ๋ค.
latent space์ ๋ฒกํฐ๋ฅผ ์กฐ์, ๊ธฐ์กด domain์์ img์ ๊ณ ์์ค ํน์ฑ์ ๋ฐ๊พธ๋ ๊ธฐ๋ฅ์ ์ ๊ณตํ๊ธฐ ๋๋ฌธ.
Gen = nn.Sequential( nn.ConvTranspose2d(input_latent, 512, 4, 1, 0, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True), nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), nn.Tanh() ) Gen = Gen.to(device) Gen.apply(initialize_weights)
๐ Discriminator
๋ชฉํ: img๊ฐ ์ง์ง์ธ์ง, ๊ฐ์ง์ธ์ง ์์ธก.
๋ง์ง๋ง Conv2D์ธต์์ Sigmoid๋ฅผ ์ด์ฉํด 0๊ณผ 1์ฌ์ด ์ซ์๋ก ์ถ๋ ฅ.
Dis = nn.Sequential( nn.Conv2d(3, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(64, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(128, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 512, 4, 2, 1, bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(512, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) Dis=Dis.to(device) Dis.apply(initialize_weights)โ
๐จ Train
batch img์์ฑ→ํ๋ณ์์ ํต๊ณผ→๊ฐ img์ ๋ํ ์ ์ get.
โ G_Loss: BCELoss (0: fake img / 1: real img)
โ D_Loss: BCELoss (0: fake img / 1: real img)
์ด๋, ํ๋ฒ์ ํ ์ ๊ฒฝ๋ง ๊ฐ์ค์น๋ง update๋๋๋ก ๋ ์ ๊ฒฝ๋ง์ ๋ฒ๊ฐ์ trainํด์ค์ผํจ.
criterion = nn.BCELoss() Gen_optimizer = torch.optim.Adam(Gen.parameters(), lr=0.0002, betas=(0.5, 0.999)) Dis_optimizer = torch.optim.Adam(Dis.parameters(), lr=0.0002, betas=(0.5, 0.999))โ
๊ธฐํ train์ฝ๋ ์ฐธ๊ณ : https://github.com/V2LLAIN/Vision_Generation/blob/main/Implicit_Density/DCGAN/train.py
์ด๋, DCGANํ๋ จ๊ณผ์ ์ด ๋ถ์์ ํ ์ ์๋ค. (โต ํ๋ณ์์ ์์ฑ์๊ฐ ์ฐ์๋ฅผ ์ฐจ์งํ๋ ค ์๋ก ๊ณ์ ๊ฒฝ์ํ๊ธฐ ๋๋ฌธ.)
์๊ฐ์ด ์ถฉ๋ถํ ์ง๋๋ฉด, ํ๋ณ์๊ฐ ์ฐ์ธํด์ง๋ ๊ฒฝํฅ์ด ์๋ค.
๋ค๋ง, ์ด์์ ์๋ ์์ฑ์๊ฐ ์ถฉ๋ถํ ๊ณ ํ์ง Img์์ฑ์ด ๊ฐ๋ฅํด์ ํฐ ๋ฌธ์ ๋ ๋์ง ์๋๋ค.
Label Smoothing
๋ํ, GAN์ random noise๋ฅผ ์กฐ๊ธ ์ถ๊ฐํ๋ฉด ์ ์ฉํ๋ฐ, train๊ณผ์ ์ ์์ ์ฑ ๊ฐ์ฑ ๋ฐ img์ ๋ช ๋๊ฐ ์ฆ๊ฐํ๋ค.
(๋ง์น Denoise Auto Encoder์ ๊ฐ์ ๋๋.)
GAN ํ๋ จ ํ & Trick
โ D >> G ์ธ ๊ฒฝ์ฐ.
ํ๋ณ์๊ฐ ๋๋ฌด ๊ฐํ๋ฉด Loss์ ํธ๊ฐ ๋๋ฌด ์ฝํด์ง๋ค.
์ด๋ก ์ธํด ์์ฑ์์์ ์๋ฏธ์๋ ํฅ์์ ๋๋ชจํ๊ธฐ ์ด๋ ค์์ง๋ค.
๋ฐ๋ผ์ ๋ค์๊ณผ ๊ฐ์ด ํ๋ณ์๋ฅผ ์ฝํํ ๋ฐฉ๋ฒ์ด ํ์ํ๋ค.
โ ํ๋ณ์์ Dropout rate ์ฆ๊ฐ. โ ํ๋ณ์์ LR ๊ฐ์. โ ํ๋ณ์์ Conv filter ์ ๊ฐ์. โ ํ๋ณ์ ํ๋ จ ์, Label์ Noise์ถ๊ฐ. (Label Smoothing) โ ํ๋ณ์ ํ๋ จ ์, ์ผ๋ถ img์ label์ random์ผ๋ก ๋ค์ง๋๋ค.โ
โ G >> D ์ธ ๊ฒฝ์ฐ.
mode collapse: ์์ฑ์๊ฐ ๊ฑฐ์ ๋์ผํ ๋ช๊ฐ์ img๋ก ํ๋ณ์๋ฅผ "์ฝ๊ฒ ์์ด๋ ๋ฐฉ๋ฒ"
mode: ํ๋ณ์๋ฅผ ํญ์ ์์ด๋ ํ๋์ sample.
์์ฑ์๋ ์ด๋ฐ mode๋ฅผ ์ฐพ์ผ๋ ค๋ ๊ฒฝํฅ์ด ์๊ณ ,
latent space์ ๋ชจ๋ point๋ฅผ ์ด img์ mapping๊ฐ๋ฅํ๋ค.
๋ํ, ์์คํจ์์ Gradient๊ฐ 0์ ๊ฐ๊น์ด๊ฐ์ผ๋ก ๋ถ๊ดด(collapse)ํ๊ธฐ์ ์ด์ํ์์ ๋ฒ์ด๋๊ธฐ ์ด๋ ค์์ง๋ค.
โ ์ ์ฉํ์ง ์์ Loss
์์ค์ด ์์์๋ก ์์ฑ๋ imgํ์ง์ด ๋ ์ข์ ๊ฒ์ด๋ผ ์๊ฐํ ์ ์๋ค.
ํ์ง๋ง ์์ฑ์๋ ํ์ฌ ํ๋ณ์์ ์ํด์๋ง ํ๊ฐ๋๋ค.
ํ๋ณ์๋ ๊ณ์ ํฅ์๋๊ธฐ์ train๊ณผ์ ์ ๋ค๋ฅธ์ง์ ์์ ํ๊ฐ๋ ์์ค์ ๋น๊ตํ ์ ์๋ค.
์ฆ, ํ๋ณLoss๋ ๊ฐ์ํ๊ณ , ์์ฑLoss๋ ์ฆ๊ฐํ๋ค.→ GAN train๊ณผ์ ๋ชจ๋ํฐ๋ง์ด ์ด๋ ค์ด ์ด์ .
2. WGAN-GP (Wasserstein GAN with Gradient Penalty)
GAN Loss
GAN์ ํ๋ณ์โ์์ฑ์ ํ๋ จ ์ ์ฌ์ฉํ BCE Loss๋ฅผ ์ดํด๋ณด์.
ํ๋ณ์ Dํ๋ จ: real_img์ ๋ํ ์์ธก pi=D(xi)์ target yi=1์ ๋น๊ต.
์์ฑ์ Gํ๋ จ: ์์ฑ_img์ ๋ํ ์์ธก pi=D(G(zi))์ target yi=0์ ๋น๊ต.
[GAN D_Loss ์ต๋ํ ์]:
[GAN G_Loss ์ต์ํ ์]:
Wesserstein Loss
[GAN Loss์์ ์ฐจ์ด์ ]:
โ 1๊ณผ 0๋์ , yi = 1, yi = -1์ ์ฌ์ฉ.
โ D์ ๋ง์ง๋ง์ธต์์ sigmoid์ ๊ฑฐ.
→ ์์ธก pi๊ฐ [0,1]๋ฒ์์ ๊ตญํ๋์ง ์๊ณ [-∞,∞] ๋ฒ์์ ์ด๋ค ์ซ์๋ ๋ ์ ์๊ฒํจ.
์์ ์ด์ ๋ค๋ก WGAN์ ํ๋ณ์๋ ๋ณดํต ๋นํ์(Critic)๋ผ ๋ถ๋ฅด๋ฉฐ, ํ๋ฅ ๋์ ์ ์"score"๋ฅผ ๋ฐํํ๋ค.
[Wesserstein Lossํจ์]:
WGAN์ critic D๋ฅผ ํ๋ จํ๊ธฐ์ํด
real_img์ ๋ํ ์์ธก(D(xi))๊ณผ ํ๊ฒ(= 1)์ ๋น๊ต.
์์ฑ_img์ ๋ํ ์์ธก(D(G(zi)))๊ณผ ํ๊ฒ(= -1)์ ๋น๊ต.
∴ ์์ค์ ๊ณ์ฐ
[WGAN Critic D_Loss ์ต์ํ]: real๊ณผ ์์ฑ๊ฐ์ ์์ธก์ฐจ์ด ์ต๋ํ.
[WGAN G_Loss ์ต์ํ]: Critic์์ ๊ฐ๋ฅํ ๋์ ์ ์๋ฅผ ๋ฐ๋ img์์ฑ.
(= Critic์ ์์ฌ real_img๋ผ ์๊ฐํ๊ฒ ๋ง๋๋ ๊ฒ.)
1-Lipshitz Continuous function
sigmoid๋ก [0,1]๋ฒ์์ ๊ตญํํ์ง ์๊ณ
Critic์ด [-∞,∞] ๋ฒ์์ ์ด๋ค ์ซ์๋ ๋ ์ ์๊ฒํ๋ค๋ ์ ์ Wessertein Loss๊ฐ ์ ํ์์ด ์์ฃผ ํฐ ๊ฐ์ผ ์ ์๋ค๋ ๊ฒ์ธ๋ฐ, ๋ณดํต ์ ๊ฒฝ๋ง์์ ํฐ ์๋ ํผํด์ผํ๋ค.
๊ทธ๋ ๊ธฐ์, "Critic์ ์ถ๊ฐ์ ์ธ ์ ์ฝ์ด ํ์"ํ๋ค.
ํนํ, Critic์ 1-Lipshitz ์ฐ์ํจ์์ฌ์ผ ํ๋๋ฐ, ์ด์๋ํด ์ดํด๋ณด์.
Critic์ ํ๋์ img๋ฅผ ํ๋์ ์์ธก์ผ๋ก ๋ณํํ๋ ํจ์ D์ด๋ค.
์์์ ๋ input_img x1, x2์ ๋ํด ๋ค์ ๋ถ๋ฑ์์ ๋ง์กฑํ๋ฉด, ์ด ํจ์๋ฅผ 1-Lipshitz๋ผ ํ๋ค:
|x1-x2| : ๋ img ํฝ์ ์ ํ๊ท ์ ์ธ ์ ๋๊ฐ ์ฐจ์ด
|D(x1) - D(x2)| : Critic ์์ธก๊ฐ์ ์ ๋๊ฐ ์ฐจ์ด
๊ธฐ๋ณธ์ ์ผ๋ก ๊ธฐ์ธ๊ธฐ์ ์ ๋๊ฐ์ด ์ด๋์์๋ ์ต๋ 1์ด์ด์ผํ๋ค
= ๋ img๊ฐ Critic์์ธก๋ณํ๋น์จ ์ ํ์ด ํ์ํ๋ค๋ ์๋ฏธ.
WGAN-GP
WGAN์ Critic์ ๊ฐ์ค์น๋ฅผ ์์ [-0.01, 0.01]๋ฒ์์ ๋์ด๋๋ก
train batch ์ดํ weight clipping์ผ๋ก Lipshitz์ ์ฝ์ ๋ถ๊ณผํ๋ค.
์ด๋, ํ์ต์๋๊ฐ ํฌ๊ฒ ๊ฐ์ํ๊ธฐ์ Lipshitz์ ์ฝ์ ์ํด ๋ค๋ฅธ ๋ฐฉ๋ฒ์ ์ ์ฉํ๋ค:
๋ฐ๋ก Wesserstein GAN-Gradient Penalty์ด๋ค.
[WGAN-GP]: Gradient Norm์ด 1์์ ๋ฒ์ด๋๋ฉด ๋ชจ๋ธ์ ๋ถ์ด์ต์ ์ฃผ๋ ๋ฐฉ์์ด๋ค.
[Gradient Penalty Loss]:
input_img์ ๋ํ ์์ธก์ Gradient Norm๊ณผ 1์ฌ์ด ์ฐจ์ด๋ฅผ ์ ๊ณฑํ ๊ฒ.
๋ชจ๋ธ์ ์์ฐ์ค๋ GPํญ์ ์ต์ํํ๋ ๊ฐ์ค์น๋ฅผ ์ฐพ์ผ๋ คํ๊ธฐ์ ์ด ๋ชจ๋ธ์ ๋ฆฝ์์ธ ์ ์ฝ์ ๋ฐ๋ฅด๊ฒ ํ๋ค.
Train๊ณผ์ ๋์ ๋ชจ๋ ๊ณณ์์ Gradient๊ณ์ฐ์ ํ๋ค๊ธฐ์ WGAN-GP๋ ์ผ๋ถ์ง์ ์์๋ง Gradient๋ฅผ ๊ณ์ฐํ๋ค.
์ด๋, real_img์ fake_img์ ๊ฐ์ interpolation img๋ฅผ ์ฌ์ฉํ๋ค.
from torch.autograd import Variable from torch.autograd import grad as torch_grad def gradient_penalty(self, real_data, generated_data): batch_size = real_data.size()[0] # Calculate interpolation alpha = torch.rand(batch_size, 1, 1, 1) alpha = alpha.expand_as(real_data) interpolated = alpha*real_data.data + (1 - alpha)*generated_data.data interpolated = Variable(interpolated, requires_grad=True) # Calculate probability of interpolated examples prob_interpolated = self.D(interpolated) # Calculate gradients of probabilities with respect to examples gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated, grad_outputs=torch.ones(prob_interpolated.size()).cuda() if self.use_cuda else torch.ones( prob_interpolated.size()), create_graph=True, retain_graph=True)[0] # Gradients have shape (B,C,W,H) # so flatten to easily take norm per example in batch gradients = gradients.view(batch_size, -1) self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().data[0]) # Derivatives of the gradient close to 0 can cause problems because of # the square root, so manually calculate norm and add epsilon gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1) + 1e-12) # Return gradient penalty return self.gp_weight * ((gradients_norm-1)**2).mean()
[WGAN-GP์์์ Batch Normalization]
BN์ ๊ฐ์ batch์์ img๊ฐ์ correlation์ ๋ง๋ ๋ค.
๊ทธ๋ ๊ธฐ์ gradient penalty loss์ ํจ๊ณผ๊ฐ ๋จ์ด์ง์
WGAN-GP๋ Critic์์ BN์ ์ฌ์ฉํด์๋ ์๋๋ค.
3. CGAN (Conditional GAN)
prev.
์์ ์ค๋ช ํ ๋ชจ๋ธ๋ค์ "์ฃผ์ด์ง trainset์์ ์ฌ์ค์ ์ธ img๋ฅผ ์์ฑํ๋ GAN"์ด์๋ค.
ํ์ง๋ง, "์์ฑํ๋ ค๋ img์ ์ ํ์ ์ ์ดํ ์ ๋ ์์๋ค."
(ex. ์์ฑํ๋ ค๋ img์ ํ: ํฌ๊ฑฐ๋ ์์ ๋ฒฝ๋, ํ๋ฐ/๊ธ๋ฐ ๋ฑ๋ฑ)
latent space์์ randomํ ํ๋์ point sampling์ ๊ฐ๋ฅํ๋ค.
latent variable์ ์ ํํ๋ฉด ์ด๋ค ์ข ๋ฅ์ img๊ฐ ์์ฑ๋ ์ง ์ฝ๊ฒ ํ์ ๊ฐ๋ฅํ๋ค.
CGAN
[GAN v.s CGAN]:
CGAN์ GAN๊ณผ ๋ฌ๋ฆฌ "label๊ณผ ๊ด๋ จ๋ ์ถ๊ฐ์ ๋ณด๋ฅผ ์์ฑ์์ critic์ ์ ๋ฌํ๋ค๋ ์ "์ด๋ค.
โ ์์ฑ์: ์ด ์ ๋ณด๋ฅผ one-hot encoding vector๋ก latent space sample์ ๋จ์ํ ์ถ๊ฐ.
โ Critic: label ์ ๋ณด๋ฅผ RGB img์ ์ฑ๋์ ์ถ๊ฐ์ฑ๋๋ก ์ถ๊ฐ.
→ input img๊ฐ ๋์ผํ ํฌ๊ธฐ๊ฐ ๋ ๋ ๊น์ง one-hot encoding vector๋ฅผ ๋ฐ๋ณต.
[์ ์ผํ ๊ตฌ์กฐ ๋ณ๊ฒฝ์ฌํญ]:
label์ ๋ณด๋ฅผ G,D์ ๊ธฐ์กด ์ ๋ ฅ์ ์ฐ๊ฒฐํ๋ ๊ฒ.
class Generator(nn.Module): def __init__(self, generator_layer_size, z_size, img_size, class_num): super().__init__() self.z_size = z_size self.img_size = img_size self.label_emb = nn.Embedding(class_num, class_num) self.model = nn.Sequential( nn.Linear(self.z_size + class_num, generator_layer_size[0]), nn.LeakyReLU(0.2, inplace=True), nn.Linear(generator_layer_size[0], generator_layer_size[1]), nn.LeakyReLU(0.2, inplace=True), nn.Linear(generator_layer_size[1], generator_layer_size[2]), nn.LeakyReLU(0.2, inplace=True), nn.Linear(generator_layer_size[2], self.img_size * self.img_size), nn.Tanh() ) def forward(self, z, labels): # Reshape z z = z.view(-1, self.z_size) # One-hot vector to embedding vector c = self.label_emb(labels) # Concat image & label x = torch.cat([z, c], 1) # Generator out out = self.model(x) return out.view(-1, self.img_size, self.img_size)โ
class Discriminator(nn.Module): def __init__(self, discriminator_layer_size, img_size, class_num): super().__init__() self.label_emb = nn.Embedding(class_num, class_num) self.img_size = img_size self.model = nn.Sequential( nn.Linear(self.img_size * self.img_size + class_num, discriminator_layer_size[0]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), nn.Linear(discriminator_layer_size[0], discriminator_layer_size[1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), nn.Linear(discriminator_layer_size[1], discriminator_layer_size[2]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(0.3), nn.Linear(discriminator_layer_size[2], 1), nn.Sigmoid() ) def forward(self, x, labels): # Reshape fake image x = x.view(-1, self.img_size * self.img_size) # One-hot vector to embedding vector c = self.label_emb(labels) # Concat image & label x = torch.cat([x, c], 1) # Discriminator out out = self.model(x) return out.squeeze()โ
4. ์์ฝ
์์ Data๋ฅผ ์ง์ ์์ฑํ๋ ํ๋ฅ ์ ๊ณผ์ ์ผ๋ก ๋ฐ๋ํจ์๋ฅผ ์๋ฌต์ ์ผ๋ก ๋ชจ๋ธ๋ง๋ฐฉ์์ด๋ผ ํ์๋ค.
์ด๋ฒ์ ์๊ฐํ GAN์ ์ด 3๊ฐ์ง๊ฐ ์๋ค.(์ฐธ๊ณ )
โ DCGAN: mode collapse ๋ฐ gradient vanishing problem์กด์ฌ.
โก WGAN: DCGAN๋ฌธ์ ํด๊ฒฐ์ ์ํด ์์ ํ ์งํ.
WGAN-GP: ํ๋ จ๊ณผ์ ์ค 1-Lipshitz์กฐ๊ฑด(์์คํจ์์ Gradient Norm์ด 1์ด ๋๋๋ก ๋์ด๋น๊ธฐ๋ ํญ)์ ์ถ๊ฐ.
โข CGAN: ์์ฑ๋ ์ถ๋ ฅ Img์ ํ์ ์ ์ดํ๋๋ฐ ํ์ํ ์ถ๊ฐ์ ๋ณด๊ฐ ์ ๊ฒฝ๋ง์ ์ ๊ณต.
๋ค์์๋ sequential data modeling ์, ์ด์์ ์ธ AR Model์ ์์๋ณผ ๊ฒ์ด๋ค.
'Gain Study > Generation' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[G]Part 2-5. Energy-based Model (0) | 2024.01.29 |
---|---|
[G]Part 2-4. Normalizing Flows (2) | 2024.01.29 |
[G]Part 2-3. Auto Regressive Models (0) | 2024.01.26 |
[G]Part 2-1. VAE (2) | 2024.01.25 |
[G]Part 1. Intro. Generative Deep Learning (0) | 2024.01.25 |