๐ถ ์ด๋ก (Abstract)
- ResNet์ ๋ ์ค๋๋ ฅ ์๋ ์ ํ๋์ ๋ฉ์ง ์๋ ด๋์์ ๋ณด์ฌ์ฃผ๋ ๋งค์ฐ ์ฌ์ธต์ ์ธ ์ํคํ ์ฒ๊ตฐ์ผ๋ก ๋ถ์ํ๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ "Identity Mapping"์ "Skip Connection" ๋ฐ "After-addition Activation"๋ก ์ฌ์ฉํ ๋
์์ ํ/ํ์ ํ signal์ด ํ ๋ธ๋ก์์ ๋ค๋ฅธ ๋ธ๋ก์ผ๋ก ์ง์ ์ ํ๋ ์ ์์์ ์ ์ํ๋ residual building block ์ดํ์ propagation๊ณต์์ ๋ถ์ํ๋ค.
์ผ๋ จ์ ์ ๊ฑฐ(ablation)์คํ์ ์ด๋ฐ identity mapping์ ์ค์์ฑ์ ๋ท๋ฐ์นจํ๋ค.
์ด๋ ์๋ก์ด residual unit์ ์ ์ํ๋๋ก ๋๊ธฐ๋ถ์ฌํ์ฌ ํ๋ จ์ ๋ ์ฝ๊ฒ ํ๊ณ ์ผ๋ฐํ๋ฅผ ๊ฐ์ ํ๋ค.
CIFAR-10(4.62% ์ค๋ฅ) ๋ฐ CIFAR-100์ 1001์ธต ResNet๊ณผ ImageNet์ 200-layer ResNet์ ์ฌ์ฉํ์ฌ ๊ฐ์ ๋ ๊ฒฐ๊ณผ๋ฅผ ์ ์ผ๋ฉฐ ์ด์ ๋ํ ์ฝ๋๋ https://github.com/KaimingHe/resnet-1k-layers ์์ ํ์ธํ ์ ์๋ค.
1. ์๋ก (Introduction)
2. Analysis of Deep Residual Networks
• Discussions
3. On the Importance of Identity Skip Connections
3.1 Experiments on Skip Connections
• Constant scaling
• Exclusive gating
์ด๋ Highway Network๋ ผ๋ฌธ์ ์ง์นจ์ ๋ฐ๋ผ(people.idsia.ch/~rupesh/very_deep_learning/)
• Shortcut-only gating
• 1x1 convolutional shortcut
• Dropout Shortcut
3.2 Discussions
4. On the Usage of Activation Functions
4.1 Experiments on Activation
์ด Section์์๋ ResNet-110๊ณผ 164์ธต์ Bottleneck Architecture(ResNet-164๋ผ๊ณ ๋ ํจ)์ผ๋ก ์คํํ๋ค.
Bottleneck Residual Unit์ ๋ค์๊ณผ ๊ฐ์ด ๊ตฌ์ฑ๋๋ค.
์ฐจ์์ถ์๋ฅผ ์ํ 1×1 ๋ฐ 3×3 layer
์ฐจ์๋ณต์์ ์ํ 1×1 layer
์ด๋ [ResNet๋ ผ๋ฌธ]์์์ ์ค๊ณ๋ฐฉ์์ฒ๋ผ ๊ณ์ฐ๋ณต์ก๋๋ 2๊ฐ์ 3×3 Residual Unit๊ณผ ์ ์ฌํ๋ค. (์์ธํ ๋ด์ฉ์ ๋ถ๋ก์ ๊ธฐ์ฌ)
๋ํ, ๊ธฐ์กด์ ResNet-164๋ CIFAR-10์์ 5.93%์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์๋ค.(ํ2)
•BN after addition
• ReLU before addition
• Post-activation or Pre-activation ?
4.2 Analysis
•Ease of optimization
- ์ด ํจ๊ณผ๋ ResNet-1001์ ํ๋ จํ ๋ ๋งค์ฐ ๋๋๋ฌ์ง๋ค. (๊ทธ๋ฆผ 1์ ๊ณก์ .)
[ResNet๋ ผ๋ฌธ]์ ๊ธฐ์กด ์ค๊ณ๋ฐฉ์์ผ๋ก ํ๋ จ์ ์์ํ๋ฉด training Loss๊ฐ ๋งค์ฐ ๋๋ฆฌ๊ฒ ๊ฐ์ํ๋ค.
f = ReLU๊ฐ ์์ ๊ฐ์ ๊ฒฝ์ฐ, ์ ํธ์ ์ํฅ์ ๋ฏธ์น๋๋ฐ ์ด๋ Residual Unit์ด ๋ง์ผ๋ฉด ์ด ํจ๊ณผ๊ฐ ๋๋๋ฌ์ง๋ค.
์ฆ, Eqn.(3)(๋ฐ๋ผ์ Eqn.(5))์ ์ข์ ๊ทผ์ฌ์น๊ฐ ์๋๊ฒ ๋๋ค.
๋ฐ๋ฉด, f๊ฐ identity mapping์ธ ๊ฒฝ์ฐ, ์ ํธ๋ ์์์ ๋ Unit ์ฌ์ด์ ์ง์ ์ ํ๋ ์ ์๋ค.
1001์ธต์ด๋ ๋๋ ์ ๊ฒฝ๋ง์ training Loss๊ฐ์ ๋งค์ฐ ๋น ๋ฅด๊ฒ ๊ฐ์์ํจ๋ค(๊ทธ๋ฆผ 1).
๋ํ ์ฐ๋ฆฌ๊ฐ ์กฐ์ฌํ ๋ชจ๋ ๋ชจ๋ธ ์ค ๊ฐ์ฅ ๋ฎ์ Loss๋ฅผ ๋ฌ์ฑํ์ฌ ์ต์ ํ์ ์ฑ๊ณต์ ๋ณด์ฌ์ค๋ค.
- ๋ํ ResNet์ด ๋ ์ ์ ์ธต์ ๊ฐ์ง ๋ f = ReLU์ ์ํฅ์ด ์ฌ๊ฐํ์ง ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ๋ค(์: ๊ทธ๋ฆผ 6(์ค๋ฅธ์ชฝ)).
ํ๋ จ ์ด๋ฐ, training๊ณก์ ์ด ์กฐ๊ธ ํ๋ค์ด ๋ณด์ด์ง๋ง ๊ณง ์ข์ ์ํ๋ก ๋๋ค.๊ทธ๋ฌ๋ ๋จ์ (truncation)์ 1000๊ฐ์ ๋ ์ด์ด๊ฐ ์์ ๋ ๋ ๋น๋ฒํ ์ผ์ด๋๋ค.
• Reducing Overfitting
์ ์๋ "pre-activation" unit์ ์ฌ์ฉํ๋ ๊ฒ์ด Regualarizatoin์ ๋ฏธ์น๋ ๋ ๋ค๋ฅธ ์ํฅ์ ๊ทธ๋ฆผ 6(์ค๋ฅธ์ชฝ)๊ณผ ๊ฐ๋ค."pre-activation" ๋ฒ์ ์ ์๋ ด ์ training Loss๊ฐ์ด ์ฝ๊ฐ ๋ ๋์ง๋ง "test Error"๋ ๋ ๋ฎ๋ค.
์ด ํ์์ CIFAR-10๊ณผ 100 ๋ชจ๋์์ ResNet-110, ResNet-110(1-layer) ๋ฐ ResNet-164์์ ๊ด์ฐฐ๋๋ค.
์ด๋, ์ฐ๋ฆฌ๋ ์ด๊ฒ์ด BN์ "regularization" ํจ๊ณผ์ ์ํด ๋ฐ์ํ ๊ฒ์ผ๋ก ์ถ์ ๋๋ค.์๋ Residual Unit(๊ทธ๋ฆผ 4(a)์์ BN์ด ์ ํธ๋ฅผ ์ ๊ทํ(normalize)ํ์ง๋ง, ์ด๋ ๊ณง shortcut์ ์ถ๊ฐ๋๋ฏ๋ก ๋ณํฉ๋ ์ ํธ๋ ์ ๊ทํ(normalize)๋์ง ์์ต๋๋ค.
์ด ์ ๊ทํ๋์ง ์์ ์ ํธ๋ ๊ทธ๋ค์ weight-layer์ ์ ๋ ฅ๊ฐ์ผ๋ก ์ฌ์ฉ๋๋ค.
๋์กฐ์ ์ผ๋ก, ์ฐ๋ฆฌ์ "pre-activation"๋ฒ์ ์์, ๋ชจ๋ weight-layer์ ์ ๋ ฅ๊ฐ์ด ์ ๊ทํ๋์๋ค.
5. Results
• Comparisons on CIFAR-10/100
- ํ 4๋ CIFAR-10/100์ ๋ํ ์ต์ฒจ๋จ ๋ฐฉ๋ฒ์ ๋น๊ตํ์ฌ ๊ฒฝ์๋ ฅ ์๋ ๊ฒฐ๊ณผ๋ฅผ ์ป๋๋ค.
์ฐ๋ฆฐ ์ ๊ฒฝ๋ง ํญ์ด๋ filterํฌ๊ธฐ๋ฅผ ํน๋ณํ ์กฐ์ ํ๊ฑฐ๋ ์์ dataset์ ๋งค์ฐ ํจ๊ณผ์ ์ธ ์ ๊ทํ ๊ธฐ์ (์: Dropout)์ ์ฌ์ฉํ์ง ์๋๋ค๋ ์ ์ ์ฃผ๋ชฉํ๋ค.
์ฐ๋ฆฌ๋ ๋จ์ํ์ง๋ง ํ์์ ์ธ ๊ฐ๋ ์ ํตํด ๋ ๊น์ด ๋ค์ด๊ฐ ์ด๋ฌํ ๊ฒฐ๊ณผ๋ฅผ ์ป์ด์ ๊น์ด์ ํ๊ณ๋ฅผ ๋ฐ์ด๋ด๋ ์ ์ฌ๋ ฅ์ ๋ณด์ฌ์ค๋ค.
•Comparisons on ImageNet
- ๋ค์์ผ๋ก 1000-class ImageNet dataset์ ๋ํ ์คํ ๊ฒฐ๊ณผ์ด๋ค.
ResNet-101์ ์ฌ์ฉํด ImageNet์ ๊ทธ๋ฆผ 2์ 3์์ ์ฐ๊ตฌํ skip connection์ ์ฌ์ฉํ์ฌ ์๋น ์คํ์ ์ํํ์ ๋, ์ ์ฌํ ์ต์ ํ ์ด๋ ค์์ ๊ด์ฐฐํ๋ค.
์ด๋ฌํ non-idntity shortcut network์ training์ค๋ฅ๋ ์ฒซ learning rate(๊ทธ๋ฆผ 3๊ณผ ์ ์ฌ)์์ ๊ธฐ์กด ResNet๋ณด๋ค ๋ถ๋ช ํ ๋์ผ๋ฉฐ, ์์์ด ์ ํ๋์ด ํ๋ จ์ ์ค๋จํ๊ธฐ๋ก ๊ฒฐ์ ํ๋ค.
๊ทธ๋ฌ๋ ์ฐ๋ฆฌ๋ ImageNet์์ ResNet-101์ "BN after addition" ๋ฒ์ (๊ทธ๋ฆผ 4(b)์ ๋ง์ณค๊ณ ๋ ๋์ training Loss์ validation error์ค๋ฅ๋ฅผ ๊ด์ฐฐํ๋ค.
์ด ๋ชจ๋ธ์ ๋จ์ผ ํฌ๋กญ(224×224)์ validation error๋ 24.6%/7.5%์ด๋ฉฐ ๊ธฐ์กดResNet-101์ 23.6%/7.1%์ด๋ค.
์ด๋ ๊ทธ๋ฆผ 6(์ผ์ชฝ)์ CIFAR ๊ฒฐ๊ณผ์ ์ผ์นํฉ๋๋ค.
- ํ 5๋ ๋ชจ๋ ์ฒ์๋ถํฐ ํ๋ จ๋ ResNet-152์ ResNet-200์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋ค.
์ฐ๋ฆฌ๋ ๊ธฐ์กด ResNet ๋ ผ๋ฌธ์ด ๋ ์งง์ ์ธก๋ฉด s∈ [256, 480]์ ๊ฐ๋ scale-jittering์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ จ์์ผฐ๊ธฐ ๋๋ฌธ์ s = 256 ([ResNet๋ ผ๋ฌธ]์์์ ๊ฐ์ด)์์ 224×224crop์ test๋ negative์ชฝ์ผ๋ก ํธํฅ๋์ด ์์๋ค.
๋์ , ๋ชจ๋ ๊ธฐ์กด ๋ฐ ResNets์ ๋ํด s = 320์์ ๋จ์ผ 320x320 crop์ testํ๋ค.
ResNets๋ ๋ ์์ ๊ฒฐ๊ณผ๋ฌผ์ ๋ํด ํ๋ จ๋ฐ์์ง๋ง, ResNets๋ ์ค๊ณ์ Fully-Convolution์ด๊ธฐ ๋๋ฌธ์ ๋ ํฐ ๊ฒฐ๊ณผ๋ฌผ์์ ์ฝ๊ฒ ํ ์คํธํ ์ ์๋ค.
์ด ํฌ๊ธฐ๋ Inception v3์์ ์ฌ์ฉํ 299×299์ ๊ฐ๊น๊ธฐ ๋๋ฌธ์ ๋ณด๋ค ๊ณต์ ํ ๋น๊ต๊ฐ ๊ฐ๋ฅํ๋ค.
- ๊ธฐ์กด ResNet-152๋ 320x320 crop์์ top-1 error๊ฐ 21.3%์ด๋ฉฐ, "pre-activation"์ 21.1%์ด๋ค.
ResNet-152์์๋ ์ด ๋ชจ๋ธ์ด ์ฌ๊ฐํ ์ผ๋ฐํ(generalization) ์ด๋ ค์์ ๋ณด์ด์ง ์์๊ธฐ ๋๋ฌธ์ ์ด๋์ด ํฌ์ง ์๋ค.
๊ทธ๋ฌ๋ ๊ธฐ์กด ResNet-200์ ์ค๋ฅ์จ์ 21.8%๋ก ๊ธฐ์กด ResNet-152๋ณด๋ค ๋๋ค.
๊ทธ๋ฌ๋ ๊ธฐ์กด ResNet-200์ ResNet-152๋ณด๋ค training error๊ฐ ๋ฎ์๋ฐ, ์ด๋ overfitting์ผ๋ก ์ด๋ ค์์ ๊ฒช๊ณ ์์์ ์์ฌํ๋ค.
"pre-activation" ResNet-200์ ์ค๋ฅ์จ์ 20.7%๋ก ๊ธฐ์กด ResNet-200๋ณด๋ค 1.1% ๋ฎ๊ณ ResNet-152์ ๋ ๋ฒ์ ๋ณด๋ค ๋ฎ๋ค. GoogLeNet๊ณผ InceptionV3์ scale ๋ฐ ์ข ํก(aspect)์ ๋น์จ์ ํ๋๋ฅผ ์ฌ์ฉํ ๋, ResNet-200์ Inception v3๋ณด๋ค ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ธ๋ค(ํ 5).
์ฐ๋ฆฌ์ ์ฐ๊ตฌ๊ฐ ์งํ๋ ๋์ ๋์์, Inception-ResNet-v2 ๋ชจ๋ธ์ 19.9%/4.9%์ single crop ๊ฒฐ๊ณผ๋ฅผ ๋ฌ์ฑํ์๋ค.
• Computational Cost
- ์ฐ๋ฆฌ ๋ชจ๋ธ์ ๊ณ์ฐ ๋ณต์ก๋๋ ๊น์ด์ ๋ฐ๋ผ ์ ํ์ ์ด๋ค(๋ฐ๋ผ์ 1001-layer net์ 100-layer net๋ณด๋ค 10๋ฐฐ ๋ณต์กํ๋ค).
CIFAR์์ ResNet-1001์ 2๊ฐ์ GPU์์ ํ๋ จํ๋ ๋ฐ ์ฝ 27์๊ฐ์ด ๊ฑธ๋ฆฌ๊ณ ,
ImageNet์์ ResNet-200์ 8๊ฐ์ GPU์์ ํ๋ จํ๋ ๋ฐ ์ฝ 3์ฃผ๊ฐ ๊ฑธ๋ฆฐ๋ค(VGGNet๋ ผ๋ฌธ๊ณผ ๋๋ฑ).
6. Conclusions
์ด ๋ ผ๋ฌธ์ ResNet์ connection๋ฉ์ปค๋์ฆ ๋ค์์ ์๋ํ๋ ์ ํ ๊ณต์์ ์กฐ์ฌํ๋ค.
์ฐ๋ฆฌ์ ๊ฒฐ๊ณผ๋ฌผ์ identity shortcut connection ๋ฐ identity after-addition activation์ด ์ ๋ณด์ ์ ํ๋ฅผ ์ํํ๊ฒ ํ๊ธฐ ์ํด ํ์์ ์ด๋ผ๋ ๊ฒ์ ์์ํ๋ค.
์ด๋ฐ ๋ณ์ธํต์ ์คํ(Ablation Experimanet)์ ์ฐ๋ฆฌ์ ๊ฒฐ๊ณผ๋ฌผ๊ณผ ์ผ์นํ๋ ํ์์ ๋ณด์ฌ์ค๋ค.
์ฐ๋ฆฌ๋ ๋ํ ์ฝ๊ฒ ํ๋ จ๋๊ณ ์ ํ๋๋ฅผ ํฅ์์ํฌ ์ ์๋ 1000์ธต ์ฌ์ธต์ ๊ฒฝ๋ง์ ์ ์ํ๋ค
•Appendix: Implementation Details
๐ง ๋ ผ๋ฌธ ๊ฐ์_์ค์๊ฐ๋ ํต์ฌ ์์ฝ
"Identity Mappings in Deep Residual Networks"
Kaiming He, Xiangyu Zhang, Shaoqing Ren ๋ฐ Jian Sun์ด 2016๋ ์ ๋ฐํํ ์ฐ๊ตฌ ๋ ผ๋ฌธ์ผ๋ก ์ด ๋ ผ๋ฌธ์ ์ฌ์ธต ์ ๊ฒฝ๋ง์ ์ฑ๋ฅ ์ ํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ์๋ก์ด ์์ฐจ ๋คํธ์ํฌ ์ํคํ ์ฒ๋ฅผ ์ ์ํ๋ค.
[ํต์ฌ ๊ฐ๋ ]
1. ๊ธฐ์กด ResNet๊ณผ์ ์ฐจ์ด์
1. Shortcut Connections
์ด ๋ ผ๋ฌธ์ ๊ธฐ์กด์ ResNet์์ layer๊ฐ์ shortcut connection์์ "Identity Mapping"์ ์ฌ์ฉํ๋ค๋ ๊ฒ์ด๋ค.
- ๊ธฐ์กด ResNet: ๋ค์์ธต์ ์ถ๋ ฅ์ฐจ์๊ณผ ์ผ์นํ๋๋ก ์ ๋ ฅ์ ๋ณํํ๋ Residual Mapping์ ์ฌ์ฉ
- ResNet V2: transformation์ ์ฐํํ๊ณ ์ ๋ ฅ์ ๋ค์์ธต์ผ๋ก ์ง์ ์ ํํ๋ "Identity Mapping"์ ์ฌ์ฉ
2. Pre-activation
์ด ๋ ผ๋ฌธ์ ResNet์ ์์ฉํ ResNetV2๋ก ์ฌ์ ํ์ฑํ(pre-activation)์ ๋ํ ๊ฐ๋ ์ ๋์ ํ๋ค.
- BatchNormalization ๋ฐ ReLU๋ฅผ ๊ฐ conv.layer์ดํ๊ฐ ์๋, ์ด์ ์ ์ ์ฉํ๋ค.
- ์ด๋ฅผ ํตํดtraining performance๋ฅผ ๊ฐ์ ํ๊ณ ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์์์ overfitting์ ์ค์ฌ์ฃผ์๋ค.
[์ฅ์ โ _ Easy to Optimization]
- ์ด ํจ๊ณผ๋ ๊น์ ์ ๊ฒฝ๋ง(1001-layer ResNet)์ ํ์ต์ํฌ ๋ ๋ถ๋ช ํ๊ฒ ๋ํ๋๋ค.
๊ธฐ์กด ResNet์ Skip connetion์ ๊ฑฐ์ณ์ ์ ๋ ฅ๊ฐ๊ณผ ์ถ๋ ฅ๊ฐ์ด ๋ํด์ง๊ณ , ReLU ํจ์๋ฅผ ๊ฑฐ์น๋ค.
๋ํด์ง ๊ฐ์ด ์์์ด๋ฉด ReLU ํจ์๋ฅผ ๊ฑฐ์ณ์ 0์ด ๋๋๋ฐ, ์ด๋ ๋ง์ฝ, ์ธต์ด ๊น๋ค๋ฉด ์ด ์ฆ์์ ์ํฅ์ด ๋ ์ปค์ง๊ฒ ๋์ด ๋ ๋ง์ ๊ฐ์ด 0์ด ๋์ด ์ด๊ธฐ ํ์ต์์ ๋ถ์์ ์ฑ์ผ๋ก ์ธํ ์๋ ด์ด ๋์ง ์๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์๋ค.
์ค์ ๋ก ์๋ ํ์ต ๊ณก์ ์ ๋ณด๋ฉด ์ด๊ธฐ์ Loss๊ฐ ์๋ ด๋์ง ์๋ ๋ชจ์ต์ ๋ณผ ์ ์๋ค.
ํ์ง๋ง pre-activation ๊ตฌ์กฐ๋ ๋ํด์ง ๊ฐ์ด ReLU ํจ์๋ฅผ ๊ฑฐ์น์ง ์์, ์์ ๊ฐ๋ ๊ทธ๋๋ก ์ด์ฉํ ์ ์๊ฒ ๋๋ค.
์ค์ ๋ก ํ์ต ๊ณก์ ์ ์ดํด๋ณด๋ฉด ์ ์๋ ๊ตฌ์กฐ๊ฐ ์ด๊ธฐ ํ์ต์์ loss๋ฅผ ๋ ๋น ๋ฅด๊ฒ ๊ฐ์์ํด์ ๋ณผ ์ ์๋ค.
[์ฅ์ โก_ Reduce Overfitting]
- ์ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ์๋ ด์ง์ ์์ pre-activation ๊ตฌ์กฐ์ training loss๊ฐ original๋ณด๋ค ๋๋ค.
- ๋ฐ๋ฉด, test error๊ฐ ๋ฎ๋ค๋ ๊ฒ์ overfitting์ ๋ฐฉ์งํ๋ ํจ๊ณผ๊ฐ ์๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
- ์ด ๋ ผ๋ฌธ์์ ์ด ํจ๊ณผ์ ๋ํด Batch Normalization ํจ๊ณผ ๋๋ฌธ์ ๋ฐ์ํ๋ค๊ณ ์ถ์ธกํ๋๋ฐ, Original Residual unit์ BN์ ๊ฑฐ์น๊ณ ๊ฐ์ด shortcut์ ๋ํด์ง๋ฉฐ, ๋ํด์ง ๊ฐ์ ์ ๊ทํ๋์ง ์๋๋ค.
์ด ์ ๊ทํ๋์ง ์์ ๊ฐ์ด ๋ค์ conv. layer์ ์ ๋ ฅ๊ฐ์ผ๋ก ์ ๋ฌ๋๋ค.
Pre-activation Residual unit์ ๋ํด์ง ๊ฐ์ด BN์ ๊ฑฐ์ณ์ ์ ๊ทํ ๋ ๋ค์ convolution layer์ ์ ๋ ฅ๋์ overfitting์ ๋ฐฉ์งํ๋ค๊ณ ์ ์๋ ์ถ์ธกํ๋ค.
3. Recommendation
์ด ๋ ผ๋ฌธ์์๋ ResNet V2 ์ค๊ณ ๋ฐ ํ๋ จ์ ์ํด ์๊ฐ๋ ์ค์ฉ์ ์ธ ๊ถ์ฅ์ฌํญ์ ์ ์ํ๋๋ฐ, ์๋์ ๊ฐ๋ค.
โ Initialization
- ํ์คํธ์ฐจ = sqrt(2/n) ์ธ Gaussian๋ถํฌ๋ฅผ ์ฌ์ฉํด Conv.layer weight์ด๊ธฐํ๋ฅผ ๊ถ์ฅ
(์ด๋, n์ input channel์)
โก Batch Normalization
- pre-activation์ ์ด์ฉํ Batch Normalization์ ๊ถ์ฅํ๋ค.
- mini-batch์ statistics ์ํฅ์ ์ค์ด๊ธฐ ์ํด training/test์ค์๋ statistics์ด๋ํ๊ท ์ ์ฌ์ฉํ๋ค.
โข Learning Rate Schedule
- ์ด๊ธฐ ์๋ ด์ ๊ฐ์ํ๋ฅผ ์ํด warming up๊ตฌ๊ฐ์์ ์๋์ ์ผ๋ก ํฐ ํ์ต๋ฅ ์ฌ์ฉ
- ๋ฏธ์ธ์กฐ์ ์ ์ํด decay๊ตฌ๊ฐ์์ ๋ ์์ ํ์ต๋ฅ ์ฌ์ฉ
โฃ Weight Decay
- overfitting ๋ฐฉ์ง๋ฅผ ์ํด weight_decay = 1e-4 (0.0001)๋ฅผ ์ฌ์ฉ
โค Data Augmentation
- random cropping
- horizontal flipping
๐ง ๋ ผ๋ฌธ์ ์ฝ๊ณ Architecture ์์ฑ (with tensorflow)
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
def conv2d_bn(x, filters, kernel_size, strides=1, padding='same'):
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding)(x)
x = BatchNormalization()(x)
return x
def residual_block(x, filters, kernel_size, strides=1):
shortcut = x
x = conv2d_bn(x, filters, kernel_size, strides)
x = ReLU()(x)
x = conv2d_bn(x, filters, kernel_size, 1)
if x.shape != shortcut.shape:
shortcut = Conv2D(filters=filters,
kernel_size=1,
strides=strides,
padding='same')(shortcut)
x = Add()([x, shortcut])
x = ReLU()(x)
return x
def resnetv2(input_shape, num_classes, num_layers, use_bottleneck=False):
num_blocks = (num_layers - 2) // 9
filters = 16
inputs = Input(shape=input_shape)
x = conv2d_bn(inputs, filters, 3)
for i in range(num_blocks):
for j in range(3):
strides = 1
if j == 0:
strides = 2
x = residual_block(x, filters, 3, strides)
filters *= 2
x = GlobalAveragePooling2D()(x)
x = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=inputs, outputs=x)
return model
model = resnetv2(input_shape=(224,224,3), num_classes=200, num_layers=152, use_bottleneck=True)
model.summary()