๐ถ ์ด๋ก (Abstract)
- ์ฌ์ธต์ ๊ฒฝ๋ง์ ํ์ต์ํค๊ธฐ ๋์ฑ ์ด๋ ต๋ค.
- "residual learning framework"๋ก ์ด์ ๊ณผ ๋ฌ๋ฆฌ ์๋นํ ๊น์ ์ ๊ฒฝ๋ง์ training์ ์ฝ๊ฒํ๋ ๋ฐฉ๋ฒ์ ์๊ฐํ๊ณ ์ ํ๋ค.
unreferencedํจ์ ๋์ layer input์ ์ฐธ์กฐํด "learning residual function"์ผ๋ก layer๋ฅผ ๋ช ์์ ์ผ๋ก ์ฌ๊ตฌ์ฑํ์๋ค.
์ด๋ฐ ์์ฐจ์ ๊ฒฝ๋ง(residual network)์ด ์ต์ ํํ๊ธฐ ๋์ฑ ์ฝ๊ณ ์๋นํ ์ฆ๊ฐ๋ ๊น์ด์์ ์ ํ๋๋ฅผ ์ป์ ์ ์๋ค๋ ๊ฒ์ ๋ณด์ฌ์ฃผ๋ ํฌ๊ด์ ์ธ(comprehensive) ๊ฒฝํ์ ์ฆ๊ฑฐ๋ฅผ ์ ๊ณตํ๋ค.
- ImageNet dataset์์ VGGNet๋ณด๋ค 8๋ฐฐ ๋ ๊น์ง๋ง ๋ณต์ก์ฑ์ ๋ฎ์ 152์ธต์ ์์ฐจ์ ๊ฒฝ๋ง์ ๋ํด ํ๊ฐํ๋ค.
์ด๋ฐ ์์ฐจ์ ๊ฒฝ๋ง์ ์์๋ธ์ ImageNet testset์์ 3.57%์ ์ค๋ฅ๋ฅผ ๋ฌ์ฑํ๋๋ฐ, ์ด๋ ILSVRC-2015์์ 1์๋ฅผ ์ฐจ์งํ๋ค.
๋ํ 100์ธต๊ณผ 1000์ธต์ ๊ฐ๋ CIFAR-10์ ๋ํ ๋ถ์๋ ์ ์ํ๋ค.
- ๊น์ด์ ๋ํ ํํ์ ๋ง์ ์๊ฐ์ ์ธ์ง์์ ์์ ๋งค์ฐ ์ค์ํ๋ค.
์ฐ๋ฆฌ์ ๋งค์ฐ๊น์๋ฌ์ฌ๋ COCO Object Detection dataset์์ ์๋์ ์ผ๋ก 28% ํฅ์๋ ๊ฒฐ๊ณผ๋ฅผ ์ป์๋ค.
์ฌ์ธต์์ฐจ์ ๊ฒฝ๋ง(Deep residual net)์ ILSVRC. &. COCO 2015 ๋ํ1์ ์ ์ถํ ์๋ฃ์ ๊ธฐ๋ฐ์ผ๋ก ImageNet๊ฐ์ง, ImageNet Localization, COCO ๊ฐ์ง, COCO segmentation ์์ ์์๋ 1์๋ฅผ ์ฐจ์งํ๋ค.
1. ์๋ก (Introduction)
• Deep CNN์ image classification์ ๋ํ๊ตฌ๋ก ์ด์ด์ก๋ค.
์ฌ์ธต์ ๊ฒฝ๋ง์ ๋น์ฐํ๊ฒ๋ ์ ·์ค·๊ณ ์์ค์ ํน์ง์ ํตํฉํ๊ณ , ๋ถ๋ฅ๊ธฐ๋ ๋ค์ธต์ ์ฒ์๋ถํฐ ๋๊น์ง ๋ถ๋ฅํ๋ฉฐ feature์ "level"์ ๊น์ด๊ฐ ๊น์ด์ง๋ฉด์ ์ธต์ด ์์ผ์๋ก ํ๋ถํด์ง๋ค. (ํ์ฌ ์ ๊ฒฝ๋ง์ ๊น์ด๋ ์์ฃผ ์ค์ํ๋ค๋ ๊ฒ์ด ์ ๋ก ์ด๋ค.)
- Depth์ ์ค์์ฑ์ ๋ํด ๋ค์๊ณผ ๊ฐ์ ์ง๋ฌธ์ด ๋ฐ์ํ๋ค: ๋ ๋ง์ ์ธต์ ์์ ๊ฒ ๋งํผ ์ ๊ฒฝ๋ง์ ํ์ต์ํค๊ธฐ ๋ ์ฌ์ธ๊น?
[Is learning better networks as easy as stacking more layers?]
- ์ด ์ง๋ฌธ์ ๋ต์ ์ํ ํฐ ์ฅ์ ๋ฌผ์ ์ ๋ช ๋์ gradient vanishing/exploding๋ฌธ์ ๋ก ์์๋ถํฐ ์๋ ด์ ๋ฐฉํดํ๋ ๊ฒ์ด๋ค.
๋ค๋ง, ์ด ๋ฌธ์ ๋ ์ด๊ธฐํ๋ฅผ ์ ๊ทํํ๊ฑฐ๋ ์ค๊ฐ์ ์ ๊ทํ์ธต์ ๋ฃ์ด ์์ญ๊ฐ(tens)์ ์ธต์ ์ ๊ฒฝ๋ง์ด ์ญ์ ํ๋ฅผ ํตํด SGD๋ฅผ ์ํ ์๋ ด์ ์์ํ๋ ๋ฐฉ๋ฒ๊ณผ ๊ฐ์ด ๋๋ถ๋ถ ๋ค๋ค์ก๋ค.
[Degradation Problem]
- ๋ ๊น์ ์ ๊ฒฝ๋ง์ด ์๋ ด์ ์์ํ ๋, ์ฑ๋ฅ์ ํ(degradation)๋ฌธ์ ๊ฐ ๋ ธ์ถ๋์๋ค : ์ ๊ฒฝ๋ง๊น์ด๊ฐ ์ฆ๊ฐํ๋ฉด ์ ํ๋๊ฐ ํฌํ์ํ๊ฐ ๋๊ณ , ๊ทธ ๋ค์์ ๋น ๋ฅด๊ฒ ์ ํ๋๋ค.
- ์ด ๋ฌธ์ ์ ์์์น ๋ชปํ ๋ฌธ์ ์ ์ ๋ฐ๋ก overfitting์ด ์ด ๋ฌธ์ ๋ฅผ ์ผ๊ธฐํ์ง ์๋๋ค๋ ์ ์ธ๋ฐ, ์ ์ ํ ์ฌ์ธต๋ชจ๋ธ์ ๋ ๋ง์ ์ธต์ ์ถ๊ฐํ๋ฉด (์ฐ๋ฆฌ์ ์ฐ๊ตฌ๊ฒฐ๊ณผ์ฒ๋ผ) ๋ ๋์ training error๊ฐ ๋ฐ์ํ๋ค.
์์ ๊ทธ๋ฆผ์ ๋ํ์ ์ธ ์์์ด๋ค.
- training ์ ํ๋์ ์ฑ๋ฅ์ ํ๋ ๋ชจ๋ ์์คํ ์ด optimize๋ฅผ ๋น์ทํ ์์ค์ผ๋ก ์ฝ๊ฒ ํ ์ ์๋ค๋ ๊ฒ์ ์์ฌํ๋ค.
์ด๋ฅผ ์ํด ๋ ์์ ๊ตฌ์กฐ์ ๋ ๋ง์ ์ธต์ ์ถ๊ฐํ๋ ๋ ๊น์ ๊ตฌ์กฐ๋ฅผ ๊ณ ๋ คํด๋ณด์.
[Shallow Architecture. vs. Deeper Architecture]
๋ ๊น์ ๋ชจ๋ธ์ ๋ํ ๊ตฌ์ฑ์ ์ํ ํด๊ฒฐ์ฑ ์ด ์กด์ฌํ๋ค
- ์ถ๊ฐ๋ layer๋ identity mapping์ด๋ค.
- ๋ค๋ฅธ์ธต์ ํ์ต๋ ์์๋ชจ๋ธ์์ ๋ณต์ฌ๋๋ค.
์ด๋ฐ ๊ตฌ์กฐ์ ํด๊ฒฐ์ฑ ์ ์ฌ์ธต๋ชจ๋ธ์ด ์์๋ชจ๋ธ๋ณด๋ค ๋ ๋์ training error๋ฅผ ์์ฑํ์ง ์์์ผ ํจ์ ๋ํ๋ธ๋ค.
ํ์ง๋ง, ์คํ์ ํ์ฌ์ ํด๊ฒฐ์ฑ ์ ์ด๋ฐ๊ตฌ์กฐ์ ํด๊ฒฐ์ฑ ๋ณด๋ค ๋น๊ต์ ์ข๊ฑฐ๋ ๋ ๋์ ํด๊ฒฐ์ฑ ์ ์ฐพ์ ์ ์์๋ค.
cf. [Identity. &. Identity Mapping]
ResNet๊ตฌ์กฐ์์ Residual Block์ identity(= identity mapping) ์ ๋งํ๋ค.
- ์ ๋ ฅ์์ ์ถ๋ ฅ์ผ๋ก์ ์ ์ฒด ๋งคํ์ ํ์ตํ๋ ๋์ ์ ๋ ฅ์ ์ฝ๊ฐ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ํ์ต์ ์งํํ๋ค.
- Identity Block์ Skip Connection์ด Identity mapping์ ์ํํ๋ residual block์ ์ผ์ข ์ด๋ค.
- ์ฆ, ๋ณํ์์ด ์ ๋ ฅ์ด block์ ์ถ๋ ฅ์ ์ง์ ์ถ๊ฐ๋๋ ๊ฒ์ผ๋ก Identity Block์ input์ ์ฐจ์์ ์ ์ง์ํจ๋ค.
- Identity Block์ Residual Networks์์ ์ ๋ ฅ์ ์ฐจ์์ ์ ์งํ๋ฉด์ ๋น์ ํ์ฑ์ ๋์ ํ๋ ๋ฐ ์ฌ์ฉ๋๋ค.
์ด๋ ์ ๊ฒฝ๋ง์ด ๋ ๋ณต์กํ ํน์ง์ ํ์ตํ๋๋ก ๋๊ณ ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์์ ๋ฐ์ํ ์ ์๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๋ค.
์ด ๋ ผ๋ฌธ์์, Degradation Problem์ "deep residual learning framework"๋ฅผ ์ด์ฉํด ๋ค๋ฃฐ ๊ฒ์ด๋ค.
- ๊ฐ ๋ช ๊ฐ์ ์ ์ธต(stacked layer)์ด ์ํ๋ ๊ธฐ๋ณธ์ ๋งตํ์ ์ง์ ๋ง์ถ๊ธฐ(fit)๋ฅผ ๋ฐ๋ผ๋ ๋์ , ์ด๋ฌํ layer๊ฐ ์์ฐจ๋งตํ์ ์ ํฉํ๋๋ก ๋ช ์์ ์ผ๋ก ํ์ฉํ๋ค.
- ์ฐ๋ฆฐ ์ฑ๋ฅ์ ํ๋ฌธ์ (Degradation Problem)๋ฅผ ๋ณด์ฌ์ฃผ๊ณ ์ฐ๋ฆฌ์ ํด๊ฒฐ๋ฒ์ ํ๊ฐํ๊ธฐ ์ํด ImageNet์ผ๋ก ํฌ๊ด์ ์ธ ์คํ์ ์งํํ์ฌ ๋ค์ 2๊ฐ์ง๋ฅผ ํ์ธํ์๋ค.
โ ๊ทน๋๋ก ๊น์ ์์ฐจ์ ๊ฒฝ๋ง์ ์ต์ ํ ํ๊ธฐ ์ฝ๋ค.
- ๋จ์ํ ์ธต๋ง ์๋ ์๋์ ์ผ๋ก "ํ๋ฒํ(plain)" ์ ๊ฒฝ๋ง์ ๊น์ด๊ฐ ์ฆ๊ฐํ๋ฉด ๋ ๋์ training error๋ฅผ ๋ณด์ฌ์ค๋ค.
โก ์ฐ๋ฆฌ์ ์ฌ์ธต์์ฐจ์ ๊ฒฝ๋ง์ ๊น์ด๊ฐ ํฌ๊ฒ ์ฆ๊ฐํ์ฌ ์ ํ๋๋ฅผ ์ฝ๊ฒ ๋์๋ค.
์ด๋ optimization์ ์ด๋ ค์๊ณผ ์ฐ๋ฆฌ์ ๋ฐฉ๋ฒ์ ํจ๊ณผ๊ฐ ํน์ dataset์ ์ ์ฌํ์ง ์์์ ์์ฌํ๋ค.
- ImageNet Classification dataset์์ ์ฐ๋ฆฌ๋ ๊ทน๋๋ก ์ฌ์ธต์ ์ธ ์์ฐจ์ ๊ฒฝ๋ง์์ํด ์ฐ์ํ ๊ฒฐ๊ณผ๋ฅผ ์ป์๋ค.
์ฐ๋ฆฌ์ "152-layer residual net"์ ImageNet์ ์ ์ถ๋ ๊ฐ์ฅ ๊น์ ์ ๊ฒฝ๋ง์ด์ง๋ง VGG๋ณด๋ค๋ ๋ณต์ก์ฑ์ด ๋ฎ๋ค.
์์๋ธ์ ImageNet testset์์ 3.57%์ top-5 error๋ฅผ ๊ธฐ๋กํ์ผ๋ฉฐ ILSVRC 2015 classification๋ํ์์ 1์๋ฅผ ์ฐจ์งํ๋ค.
๊ทน๋๋ก ๊น์์ ๊ฒฝ๋ง์ ์๋ก๋ค๋ฅธ ์ธ์์์ ์์๋ ์ฐ์ํ ์ผ๋ฐํ(generalization)์ฑ๋ฅ์ ๋ฐํํด ๋ค์๊ฐ์ ILSVRC ๋ฐ COCO 2015์์ 1์๋ฅผ ์ถ๊ฐ๋ก ๋ฌ์ฑํ๋ค : ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation in ILSVRC & COCO 2015 competitions
2. Related Work
• Residual Representations.
- image recognition์์, VLAD(Vector of Locally Aggregated Descriptors)๋ dictionary์ ๊ดํ์ฌ ์์ฐจ๋ฒกํฐ์ ์ํด encoding๋๋ ํํ์ด๋ฉฐ, Fisher Vector๋ VLAD์ ํ๋ฅ ๋ก ์ ์ธ ๋ฒ์ ์ ๊ณต์์ผ๋ก ๋ง๋ค์ด์ง๋ค.
๋๊ฐ์ง ๋ชจ๋ image ํ๋ณต ๋ฐ ๋ถ๋ฅ๋ฅผ ์ํ ๊ฐ๋ ฅํ shallow representation์ด๋ค.
Vector ์์ํ(quantization)์ ๊ฒฝ์ฐ, ์์ฐจ๋ฒกํฐ์ ์ธ์ฝ๋ฉ์ด ๊ธฐ์กด๋ฒกํฐ ์ธ์ฝ๋ฉ๋ณด๋ค ๋ ํจ๊ณผ์ ์ผ๋ก ๋ํ๋ฌ๋ค.
cf) VLAD๋ feature encoding์ ์ํ ์ด๋ฏธ์ง์ฒ๋ฆฌ๊ธฐ์ ๋ก Local Image feature๋ฅผ ๊ณ ์ ๊ธธ์ด๋ฒกํฐํํ(fixed-length vector representation)์ผ๋ก ์ธ์ฝ๋ฉํ๋ ๊ฒ์ด๋ค.
ResNet์ ์ฌ์ฉํ๋ ์ผ๋ถ๊ณผ์ ์์ VLAD๋ ๋ถ๋ฅ๋ฅผ ์ํด ์ต์ข softmax์ธต์ ์ฌ์ฉํ๋ ๋์ , ResNet์ ์ค๊ฐ์ธต์์ ์ถ์ถํ ํน์ง์ ์ธ์ฝ๋ฉํ๋๋ฐ ์ฌ์ฉํ์ฌ ๋์ฑ ์ธ๋ถํ๋ ํน์งํํ์ด ๊ฐ๋ฅํ๋ค.
- ์ ์์ค์ ๋น์ ์์ ํธ๋ฏธ๋ถ๋ฐฉ์ ์(PDE.,Partial Differential Equations)์ ํด๊ฒฐํ๊ธฐ ์ํด ๋๋ฆฌ ์ฌ์ฉ๋๋ Mulit-Grid๋ฐฉ๋ฒ์ ์์คํ ์ ์ฌ๋ฌ์ฒ๋(multiple scale)์์ ํ์๋ฌธ์ (subproblem)๋ก ์ฌ๊ตฌ์ฑ(reformulate)ํ๋ค. ์ด๋, ๊ฐ ํ์๋ฌธ์ ๋ ๋ ๊ฑฐ์น ์ฒ๋์ ๋ฏธ์ธํ ์ฒ๋์ฌ์ด์์ ์์ฐจํด๊ฒฐ(residual solution)์ ๋ด๋นํ๋ค.
Multi-Grid์ ๋์์ ๊ณ์ธต์ ๊ธฐ๋ณธ ๋์ ๋๋ฆฌ์ ์กฐ๊ฑดํ(hierarchical basis pre-conditioning)์ด๋ค.
์ด๋ ๋ ์ฒ๋ ์ฌ์ด์ ์์ฐจ๋ฒกํฐ๋ฅผ ๋ํ๋ด๋ ๋ณ์์ ์์กดํ๋ฉฐ ์ด๋ฐ ํด๊ฒฐ๋ฒ์ ํด๊ฒฐ์ฑ ์ ์์ฐจํน์ฑ(residual nature)์ ๋ชจ๋ฅด๋ ๊ธฐ์กด์ ํด๊ฒฐ์ฑ ๋ณด๋ค ํจ์ฌ ๋น ๋ฅด๊ฒ ์๋ ดํ๋ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ผ๋ฉฐ ์ด๋ฐ ๋ฐฉ๋ฒ์ ์ข์ ์ฌ๊ตฌ์ฑ(reformulation)์ด๋ ์ ์ ์กฐ๊ฑด(preconditioning)์ด ์ต์ ํ๋ฅผ ๋จ์ํ ํ๋ค๋ ๊ฒ์ ์์ฌํ๋ค.
• Shorcut Connections
- Shorcut Connection์ ์ ๋ํ๋ ์ด๋ก ๋ฐ ์ค์ต์ ์ค๋ ์ฐ๊ตฌํด์๋ค.
MLP training์ ์ด๊ธฐ์ค์ต์ ์ ๊ฒฝ๋ง ์ ๋ ฅ์์ ์ถ๋ ฅ์ผ๋ก ์ฐ๊ฒฐ๋ ์ ํ์ธต(linear layer)๋ฅผ ์ถ๊ฐํ๋ ๊ฒ์ด๋ค.
๋ช๋ช์ ์ค๊ฐ์ธต์ด gradient vanishing/exploding์ ๋ค๋ฃจ๊ธฐ ์ํด ๋ณด์กฐ๋ถ๋ฅ๊ธฐ(auxiliary classifier)์ "์ง์ ์ฐ๊ฒฐ"๋๋ค.
GoogLeNet(https://chan4im.tistory.com/149)์์, "Inception"์ธต์ shortcut ๋ถ๊ธฐ์ ๋ช๊ฐ์ ๋ ๊น์ ๋ถ๊ธฐ๋ก ๊ตฌ์ฑ๋๋ค.
- ์ฐ๋ฆฌ์ ์ฐ๊ตฌ๊ฐ ์งํ๋ ๋ ๋์์, "highway networks"๋ gating function์ด ์๋ shortcut์ฐ๊ฒฐ์ ์ ์ํ์๋ค.
์ด gate๋ parameter๊ฐ ์๋ identity shortcut๊ณผ ๋ฌ๋ฆฌ data์ ์์กดํ๊ณ parameter๋ฅผ ๊ฐ๊ณ ์๋ค.
gate shortcut์ด "closed"(= 0์ ์ ๊ทผํ ์๋ก) "hightway networks"์ layer๋ non-residual function์ ๋ํ๋ธ๋ค.
๋์กฐ์ ์ผ๋ก, ์ฐ๋ฆฌ์ ๊ณต์์ ํญ์ ์์ฐจํจ์๋ฅผ ํ์ตํ๋ค.
์ฐ๋ฆฌ์ identity shortcut์ ๊ฒฐ์ฝ "closed"๋์ง ์๊ณ ํ์ตํด์ผํ ์ถ๊ฐ ์์ฐจํจ์์ ๋ชจ๋ ์ ๋ณด๊ฐ ํญ์ ํต๊ณผํ๋ค.
๊ฒ๋ค๊ฐ highway network๋ ๊น์ด๊ฐ ๊ทน๋๋ก ์ฆ๊ฐ(100๊ฐ ์ด์์ ์ธต)ํ์ฌ๋ ์ ํ๋์ ํฅ์์ ๋ณด์ฌ์ฃผ์ง ์์๋ค.
3. Deep Residual Learning
3.1. Residual Learning
3.2. Identity Mapping by Shortcuts
3.3. Network Architectures
•Plain Network
- plain ์ ๊ฒฝ๋ง์ ํ ๋๋ ์ฃผ๋ก VGGNet์ ์ฒ ํ์์ ์๊ฐ์ ๋ฐ์๋ค. (Fig.3, ์ผ์ชฝ)
- conv.layer๋ ๋๋ถ๋ถ 3x3 filter๋ฅผ ์ฌ์ฉํ๋ฉฐ, 2๊ฐ์ง์ ๊ฐ๋จํ ์ค๊ณ๋ฐฉ์์ ๋ฐ๋ฅธ๋ค.
โ ๋์ผํ ํฌ๊ธฐ์ ํน์ง๋งต์ถ๋ ฅ์ ๋ํ layer๋ ๋์ผํ ์์ filter๋ฅผ ๊ฐ๋๋ค.
โก ํน์ง๋งต ํฌ๊ธฐ๊ฐ 1/2(์ ๋ฐ์ผ๋ก ์ค)์ด๋ฉด, layer๋น ์๊ฐ๋ณต์ก์ฑ์ ์ ์งํด์ผ ํ๊ธฐ์ filter์๊ฐ 2๋ฐฐ๊ฐ ๋๋ค.
- ์ฐ๋ฆฐ stride=2์ธ conv.layer์ ์ํด ์ง์ downsampling์ ์ํํ๋ค.
- ์ ๊ฒฝ๋ง์ Global AveragePooling๊ณผ Softmax๊ฐ ์๋ 1000-way Fully-Connected๋ก ์ข ๋ฃ๋๋ค.
- weight-layer์ ์ด ๊ฐ์๋ ๊ทธ๋ฆผ3์ ์ค๊ฐ๊ณผ ๊ฐ์ 34๊ฐ์ด๋ค.
• Residual Network
- ์์ plain์ ๊ฒฝ๋ง์ ๊ธฐ๋ฐ์ผ๋ก, ์ฐ๋ฆฌ๋ ์ ๊ฒฝ๋ง์ counterpart residual ๋ฒ์ ์ผ๋ก ๋ฐ๊พธ๋ Shortcut Connection(๊ทธ๋ฆผ 3, ์ค๋ฅธ์ชฝ)์ ์ฝ์ ํ๋ค.
- Identity shortcuts(Eqn. (1))์ ์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ด ๋์ผํ ์น์์ผ ๋ ์ง์ ์ฌ์ฉํ ์ ์๋ค(๊ทธ๋ฆผ 3์ ์ค์ shortcuts).
- ์ฐจ์๊ฐ ์ฆ๊ฐํ๋ฉด(๊ทธ๋ฆผ 3์ ์ ์ shortcuts), ์ฐ๋ฆฌ๋ ๋ ๊ฐ์ง ์ต์ ์ ๊ณ ๋ คํ๋ค:
โ shortcut์ ์ฐจ์์ ๋๋ฆฌ๊ธฐ ์ํด ์ถ๊ฐ์ ์ผ๋ก 0๊ฐ์ ํญ๋ชฉ์ด ํจ๋ฉ๋ ์ํ์์ Identity mapping์ ๊ณ์ ์ํํฉ๋๋ค.
์ด๋, ์ถ๊ฐ์ ์ธ ๋งค๊ฐ ๋ณ์๋ฅผ ๋์ ํ์ง ์์ต๋๋ค;
โก Eqn. (2)์ Projection shortcut์ ์ฐจ์์ ์ผ์น์ํค๋ ๋ฐ ์ฌ์ฉ๋๋ค(1×1 convolutions๋ก ์ํ).
- ๋ ์ต์ ๋ชจ๋์์ shortcut์ ๋ ๊ฐ์ง ํฌ๊ธฐ์ ํน์ง๋งต ํต๊ณผ ์, stride=2๋ก ์ํ๋๋ค.
3.4. Implementation
- [AlexNet, VGGNet]์ ์คํ์ ๋ฐ๋ผ์ ๊ตฌํ์ ์งํํ์๋ค. [https://chan4im.tistory.com/145, https://chan4im.tistory.com/146]
- scale augmentation[VGGNet]์ ์ํด [256, 480]๋ก randomํ๊ฒ ์ํ๋ง, resize๋ฅผ ์งํํ์๋ค.
- 224x224 crop์ ๋๋คํ๊ฒ image์์ ์ํ๋ง๋๊ฑฐ๋ [AlexNet]์ฒ๋ผ ํฝ์ ๋น ํ๊ท ๊ฐ์ ์ฐจ๋ฅผ ์ด์ฉํ horizontal flip์ ์งํํ์์ผ๋ฉฐ ์ ์์ ์ธ color augmentation์ [Alexnet]๋ฐฉ์์ ์ด์ฉํ๋ค.
- Batch Normalization์ ์ฑํํ์ฌ BN๋ ผ๋ฌธ[https://chan4im.tistory.com/147]์์ ๋์จ ๊ฒ ์ฒ๋ผ Conv.layer ์งํ, activation์ด์ ์ ์ฌ์ฉ์ ํด์ฃผ์๋ค.
- ReLU๋ ผ๋ฌธ[https://chan4im.tistory.com/150]์์ ์ฒ๋ผ weight๋ฅผ ์ด๊ธฐํํ๊ณ ๋ชจ๋ ๊ธฐ๋ณธ/์์ฐจ์ ๊ฒฝ๋ง์ ์ฒ์๋ถํฐ trainingํ๋ค.
- mini-batch size๊ฐ 256์ธ SGD(weight decay=0.0001. &. momentum=0.9)๋ฅผ ์ฌ์ฉํ๋ฉฐ, learning rate๋ ์ด๊ธฐ๊ฐ์ด 0.1๋ก ์์ํด ํ์ต์ ์ฒดํ์ ์ฆ, error plateaus(SGD๋ Plateau์ ์ทจ์ฝํจ)๊ฐ ๋ฐ์ํ๋ฉด 10์ฉ ํ์ต๋ฅ ์ ๋๋์ด ์ค๋ค.
๋ชจ๋ธ์ ์ต๋ 60 x 10^4 iteration์ผ๋ก training๋๋ค.
- [Batch Normalization]๋ ผ๋ฌธ์ ๊ทผ๊ฑฐํ์ฌ Dropout์ ๋ฐฐ์ ํ๊ณ ์คํ์ ์งํํ๋ค.
- ์คํ์, ๋น๊ต๋ถ์์ ์ํด standard 10-crop testing์ ์ฑํํ๋ฉฐ[AlexNet], ์ต์์ ๊ฒฐ๊ณผ๋ฅผ ์ํด [VGG, Delving Deep into Rectifiers]๋ ผ๋ฌธ์ฒ๋ผ Fully-Convolutionalํํ๋ฅผ ์ฑํํ๋ค.
๋ํ, ์ฌ๋ฌ ์ฒ๋์์์ score๋ฅผ ํ๊ท (average)ํ๋๋ฐ, ์ด๋ image์ ํฌ๊ธฐ๋ ์งง์์ชฝ์ด {224, 256, 384, 480, 640}์ ์ค๋๋ก ์กฐ์ ๋๋ค.
4. Experiments
4.1. ImageNet Classification
• Plain Networks.
- ๋จผ์ 18์ธต, 34์ธต plain์ ๊ฒฝ๋ง์ ํ๊ฐํ๋ค. [34-layer plain net (Fig.3.์ค๊ฐ), ์์ธํ ๊ตฌ์กฐ๋ ์๋ ํ1์ ์ฐธ์กฐ.]
- ํ 2์ ๊ฒฐ๊ณผ๋ ๋ ๊น์ 34-layer plain net์ด ๋ ์์ 18-layer plain net๋ณด๋ค ๋ ๋์ val_Error๊ฐ์ ๊ฐ์์ ๋ณด์ธ๋ค.
์ด์ ๋ฅผ ๋ฐํ๊ธฐ ์ํด ๊ทธ๋ฆผ 4(์ผ์ชฝ)์์ training๊ณผ์ ์ค ๋ฐ์ํ training/validation error๋ฅผ ๋น๊ตํ๋ค
์ ๊ทธ๋ฆผ์์ ์ฐ๋ฆฌ๋ ์ฑ๋ฅ์ ํ๋ฌธ์ (Degradation Problem)์ ๋ฐ๊ฒฌํ๋ค.
18-layer plain net์ solution space๊ฐ 34-layer plain net์ ๋์ฒดํจ์๋ ๋ถ๊ตฌํ๊ณ 34-layer plain net์ ์ ์ฒด์ ์ธ training์ ์ฐจ์ ๊ฑธ์ณ ๋์ training error๋ฅผ ๊ฐ๋๋ค.
- ์ฐ๋ฆฐ ์ด๋ฐ ์ต์ ํ ์ด๋ ค์์ด gradient vanishing์ผ๋ก ์ธํ ๊ฐ๋ฅ์ฑ์ ๋ฎ๋ค๊ณ ์ฃผ์ฅํ๋ค.
์ด๋ฐ plain์ ๊ฒฝ๋ง์ ์์ ํ์ ํธ๊ฐ 0์ด ์๋ ๋ถ์ฐ๊ฐ์ ๊ฐ๋๋ก ๋ณด์ฅํ๋ Batch Normalization์ผ๋ก ํ๋ จ๋๋ค.
์ฐ๋ฆฐ ๋ํ ์ญ์ ํ๋ ๊ธฐ์ธ๊ธฐ๊ฐ BN๊ณผ ํจ๊ป healthy norm์ ๋ํ๋ด๋ ๊ฒ์ ํ์ธํ๋ค.
๊ทธ๋์ ์์ ํ์ ์ญ์ ํ์ ์ ํธ๋ค์ด ์ฌ๋ผ์ง์ง ์๋๋ค.
(์ค์ ๋ก 34-layer plain net์ ํ 3์์ ๋ณด์ด๋ฏ ์ฌ์ ํ ๊ฒฝ์๋ ฅ์๋ ์ ํ๋๋ฅผ ๋ฌ์ฑํ๊ธฐ์ solver๊ฐ ์ด๋์ ๋ ์๋ํ๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค.)
์ฐ๋ฆฐ deep plain์ ๊ฒฝ๋ง์ด ๊ธฐํ๊ธ์์ ์ผ๋ก ๋ฎ์ ์๋ ด๋ฅ ์ ๊ฐ์ง ์ ์์ด์ training error๊ฐ์์ ์ํฅ์ ๋ฏธ์น๋ค๊ณ ์ถ์ธกํ๋ค.
• Residual Networks.
๋ค์์ผ๋ก 18 ๋ฐ 34-layer residual nets (ResNets)์ ํ๊ฐํ๋ค.
๊ทผ๊ฐ์ด ๋๋ ๊ตฌ์กฐ๋ ์์ plain์ ๊ฒฝ๋ง๊ณผ ๋์ผํ๋ฉฐ ๊ทธ๋ฆผ 3(์ฐ์ธก)์ฒ๋ผ 3x3 filter์ ๊ฐ ์์ shortcut connection์ ์ถ๊ฐํ๋ค.
์ฒซ ๋น๊ต(ํ2์ ๊ทธ๋ฆผ4,์ค๋ฅธ์ชฝ)์์ ๋ชจ๋ Shortcut์ identity mapping์ ์ฌ์ฉํ๊ณ ์ฐจ์๋ฅผ ๋๋ฆฌ๊ธฐ ์ํด zero-padding์ ํ๊ธฐ์ [option A], plain์ ๋นํด ์ถ๊ฐ์ ์ธ parameter๊ฐ ์๋ค.
[ํ 2์ ๊ทธ๋ฆผ 4์์ ์ป์ 3๊ฐ์ง ์ฃผ์ ๊ด์ฐฐ๊ฒฐ๊ณผ]
โ ๋จผ์ , ์์ฐจํ์ต์ผ๋ก ์ํฉ์ด ์ญ์ ๋๋ค.
- 34-layer ResNet์ 18-layer ResNet๋ณด๋ค 2.8%๋ซ๋ค.
- ๋ ์ค์ํ ์ ์, 34-layer ResNet์ด ์๋นํ ๋ฎ์ training error๋ฅผ ๋ํ๋ด๊ธฐ์ validation dataset์ผ๋ก generalization, ์ฆ ์ผ๋ฐํ๊ฐ ๊ฐ๋ฅํ๋ค๋ ์ ์ธ๋ฐ, ์ด๋ ์ฑ๋ฅ์ ํ๋ฌธ์ ๊ฐ ์ด ์ค์ ์์ ์ ํด๊ฒฐ๋์๊ธฐ์ ๊น์ด์ฆ๊ฐ๋ฅผ ์ด์ฉํด ์ ํ์ฑ์ ์ด๋์ ์ป์ ์ ์๋ค๋ ์ ์ ์์ฌํ๋ค.
โก ๋์งธ๋ก, plain์ ๊ฒฝ๋ง๊ณผ ๋น๊ตํด 34-layer ResNet์ ํ 2์ฒ๋ผ top-1 error rate๋ฅผ 3.5% ๊ฐ์์ํค๋ฉฐ ์ฑ๊ณต์ ์ผ๋ก ๊ฐ์๋ training error๋ฅผ ๋ณด์ฌ์ค๋ค. (๊ทธ๋ฆผ 4์์ ์ค๋ฅธ์ชฝ vs. ์ผ์ชฝ)
์ด ๋น๊ต๋ ๊ทน๋๋ก ์ฌ์ธต์ ์ธ ์ ๊ฒฝ๋ง์ ๋ํ ์์ฐจํ์ต์ ํจ๊ณผ๋ฅผ ๊ฒ์ฆํด์ค๋ค.
โข ๋ง์ง๋ง์ผ๋ก, 18-layer plain/ResNet์ด ๋น๊ต์ ์ ํํ์ง๋ง (ํ 2), 18-layer ResNet์ ๋ ๋นจ๋ฆฌ ์๋ ดํ๋ค๋ ๊ฒ์ ์ฃผ๋ชฉํ์.(๊ทธ๋ฆผ 4์์ ์ค๋ฅธ์ชฝ vs. ์ผ์ชฝ)
- ์ฌ๊ธฐ์ 18-layer์ฒ๋ผ ์ ๊ฒฝ๋ง์ด "์ง๋์น๊ฒ ๊น์ง ์์" ๊ฒฝ์ฐ, ํ์ฌ์ SGD solver๋ ์ฌ์ ํ plain์ ๊ฒฝ๋ง์ ๋ํด ์ข์ ํด๊ฒฐ์ฑ ์ ์ฐพ๋๋ค.
- ์ด ๊ฒฝ์ฐ, ResNet์ ์ด๊ธฐ๋จ๊ณ์์ ๋ ๋น ๋ฅธ ์๋ ด์ ์ ๊ณตํด ์ต์ ํ๋ฅผ ์ํํด์ค๋ค.
• Identity. vs. Projection Shortcuts.
- parameter๊ฐ ์๋ identity shorcut์ training์ ๋์์ด ๋๋ค๋ ๊ฒ์ ๋ณด์ฌ์ฃผ์์ผ๋ฏ๋ก ๋ค์์ผ๋ก๋ projection shortcut(Eqn. (2))์ ๋ํด ์กฐ์ฌํ๋ค.
ํ 3์์๋ 3๊ฐ์ง ์ต์ ์ ๋น๊ตํ๋ค.
(A) zero-padding shortcut์ ์ฐจ์๋ฅผ ๋๋ฆฌ๊ธฐ ์ํด ์ฌ์ฉ๋๋ค.
- ์ด๋, ๋ชจ๋ shortcut์ parameter๊ฐ ์๋ค(ํ 2 ๋ฐ ๊ทธ๋ฆผ 4์ ์ค๋ฅธ์ชฝ)
(B) projection shortcut์ ์ฐจ์๋ฅผ ๋๋ฆฌ๊ธฐ ์ํด ์ฌ์ฉ๋๋ค.
- ์ด๋, ๋ค๋ฅธ shortcut์ identity์ด๋ค.
(C) ๋ชจ๋ shortcut์ projection์ด๋ค.
- ํ 3์ 3๊ฐ์ง ์ต์ ๋ชจ๋ plain๋ณด๋ค ๋ซ๋ค๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค. ์ด๋, B๊ฐ A๋ณด๋ค ์ฝ๊ฐ ๋ ๋ซ๊ณ C๊ฐ B๋ณด๋ค ์ฝ๊ฐ ๋ซ๋ค. (C > B > A)
์ฐ๋ฆฐ ์ด์ ๋ํด A์ zero-padding์ฐจ์์ด ์ค์ ๋ก ์์ฐจํ์ต์ ๊ฐ์ง ์๊ธฐ ๋๋ฌธ์ด๋ฉฐ (B > A์ธ ์ด์ )
13๊ฐ์ ๋ง์ projection shortcut์ ์ํด ๋์ ๋ ์ถ๊ฐ์ ๋งค๊ฐ๋ณ์ ๋๋ฌธ์ด๋ผ ๋ณธ๋ค. (C > B์ธ ์ด์ )
๊ทธ๋ฌ๋ A/B/C์ฌ์ด ์์ ์ฐจ์ด๋ ์ฑ๋ฅ์ ํ๋ฌธ์ ํด๊ฒฐ์ ์ํด Projection Shortcut์ด ํ์์ ์ด์ง๋ ์๋ค๋ ๊ฒ์ ์์ฌํ๋ค.
∴ ๋ฉ๋ชจ๋ฆฌ/์๊ฐ๋ณต์ก์ฑ๊ณผ ๋ชจ๋ธํฌ๊ธฐ๋ฅผ ์ค์ด๊ธฐ ์ํด ์ด ๋ ผ๋ฌธ์ ๋๋จธ์ง ํํธ์์๋ C๋ฅผ ์ฌ์ฉํ์ง ์๋๋ค.
identity shortcut์ ์๋์ ์๊ฐ๋ Residual Architecture์ ๋ณต์ก์ฑ์ ์ฆ๊ฐ์ํค์ง ์๊ธฐ์ ๋์ฑ ์ค์ํ๋ค.
• Deeper Bottleneck Architectures.
[50-layer ResNet]
- ์ฐ๋ฆฌ๋ 34-layer net์ ๊ฐ 2์ธต block์ 3-layer bottleneck block์ผ๋ก ๋์ฒดํ์ฌ 50-layer ResNet(ํ 1)์ ์์ฑํ๋ค.
- ์ฐ๋ฆฌ๋ ์ฐจ์๋ฅผ ๋๋ฆฌ๊ธฐ ์ํด option B๋ฅผ ์ฌ์ฉํ๋ค.
- ์ด ๋ชจ๋ธ์๋ 38์ต ๊ฐ์ FLOPS๊ฐ ์์ต๋๋ค.
[101-layer ResNet. &. 152-layer ResNet]
- ์ฐ๋ฆฌ๋ ๋ ๋ง์ 3-layer blocks(ํ 1)์ ์ฌ์ฉํ์ฌ 101์ธต ๋ฐ 152์ธต ResNets๋ฅผ ๊ตฌ์ฑํ๋ค.
- ๋๋๊ฒ๋ ๊น์ด๊ฐ ํฌ๊ฒ ์ฆ๊ฐํ์ง๋ง 152์ธต ResNet(113์ต FLOP)์ ์ฌ์ ํ VGG-16/19(153์ต/196์ต FLOP)๋ณด๋ค ๋ณต์ก์ฑ์ด ๋ฎ๋ค.
- 50/101/152-layer ResNets๋ 34-layer๋ณด๋ค ์๋นํ ์ ํํ๋ค(ํ 3 ๋ฐ 4).
- ์ฐ๋ฆฌ๋ ์ฑ๋ฅ์ ํ๋ฅผ ๊ด์ฐฐํ์ง ์๊ธฐ์ ์๋นํ ์ฆ๊ฐ๋ ๊น์ด์์ ์๋นํ ์ ํ๋ ํฅ์์ ์ด๋ค๋ค.
- ๊น์ด์ ์ด์ ์ ๋ชจ๋ ํ๊ฐ ์งํ(evaluation metric)์์ ํ์ธํ ์ ์๋ค(ํ 3 ๋ฐ 4).
cf. FLOPs
์ปดํจํฐ์ ์ฑ๋ฅ์ ์์น๋ก ๋ํ๋ผ ๋ ์ฃผ๋ก ์ฌ์ฉ๋๋ ๋จ์์ด๋ค. ์ด๋น ๋ถ๋์์์ ์ฐ์ฐ์ด๋ผ๋ ์๋ฏธ๋ก ์ปดํจํฐ๊ฐ 1์ด๋์ ์ํํ ์ ์๋ ๋ถ๋์์์ ์ฐ์ฐ์ ํ์๋ฅผ ๊ธฐ์ค์ผ๋ก ์ผ๋๋ค.
•Comparision with State-of-the-art Method.
- ํ 4์์ ์ด์ ์ ์ต๊ณ ์ ๋จ์ผ ๋ชจ๋ธ ๊ฒฐ๊ณผ์ ๋น๊ตํ์๋๋ฐ, ์ฐ๋ฆฌ์ ๊ทผ-๋ณธ 34์ธต ResNets๋ ๋งค์ฐ ๊ฒฝ์๋ ฅ ์๋ ์ ํ๋๋ฅผ ๋ฌ์ฑํ๋ค.
- 152์ธต ResNet์ ๋จ์ผ ๋ชจ๋ธ์์ top-5 error rate์์ 4.49%๋ฅผ ๊ฐ๋๋ค.
- ์ด ๋จ์ผ ๋ชจ๋ธ ๊ฒฐ๊ณผ๋ ์ด์ ๊น์ง์ ๋ชจ๋ ์์๋ธ ๊ฒฐ๊ณผ๋ฅผ ๋ฅ๊ฐํ๋ค (ํ 5).
- ๊น์ด๊ฐ ๋ค๋ฅธ 6๊ฐ์ ๋ชจ๋ธ์ ๊ฒฐํฉํ์ฌ ์์๋ธ์ ํ์ฑํ์๋๋ฐ, testset์์ top-5 error๊ฐ 3.57% ์๋ค(ํ 5).
(์ ์ถ ๋น์, ์์๋ธ ๊ธฐ๋ฒ์ 152์ธต ๋ชจ๋ธ ๋ ๊ฐ๋ง ํฌํจํ์์ผ๋ฉฐ ILSVRC 2015์์ 1์๋ฅผ ์ฐจ์งํ๋ค.)
4.2. CIFAR-10 and Analysis
- ์ฐ๋ฆฌ๋ 10๊ฐ์ ํด๋์ค์์ 5๋ง๊ฐ์ traininset๊ณผ 1๋ง๊ฐ์ test image๋ก ๊ตฌ์ฑ๋ CIFAR-10 dataset์ ๋ํด ๋ ๋ง์ ์ฐ๊ตฌ๋ฅผ ์ํํ๋ค.
์ฐ๋ฆฌ๋ training set์ ๋ํด ํ๋ จ๋๊ณ testset์ ๋ํด ํ๊ฐ๋ ์คํ์ ๊ฒฐ๊ณผ๋ก ์ ์ํ๋ค.
์ฐ๋ฆฌ๋ ๊ทน๋๋ก ์ฌ์ธต์ ์ธ ์ ๊ฒฝ๋ง์ ๋์์ ์ด์ ์ ๋ง์ถ๋ค.
์ต์ฒจ๋จ ๊ฒฐ๊ณผ๋ฅผ ์ถ์งํ๋ ๊ฒ์๋ ์ด์ ์ ๋ง์ถ์ง ์๊ธฐ์ ๋ค์๊ณผ ๊ฐ์ด ๋จ์ํ ๊ตฌ์กฐ๋ฅผ ์๋์ ์ผ๋ก ์ฌ์ฉํ๋ค.
- plain/residual architecture๋ ๊ทธ๋ฆผ 3(๊ฐ์ด๋ฐ/์ค๋ฅธ์ชฝ)์ ํํ๋ฅผ ๋ฐ๋ฅธ๋ค.
์ ๊ฒฝ๋ง์ ์ ๋ ฅ์ ํฝ์ ๋น ํ๊ท ์ ์ฐจ(per-pixel mean subtracted)์ 32x32 ์ด๋ฏธ์ง์ด๋ค.
์ฒซ ๋ฒ์งธ ์ธต์ 3×3 convolution์ด๋ฉฐ
๊ทธ ํ ํฌ๊ธฐ๊ฐ ๊ฐ๊ฐ {32, 16, 8}์ธ ํน์ง๋งต์์ 3×3 convolution์ ๊ฐ์ง 6n layers์ ์ ์ธต ํ ๊ฐ ํน์ง๋งต ํฌ๊ธฐ์ ๋ํด 2n layer๋ฅผ ์ฌ์ฉํ๋ค. ์ด๋, filter์๋ ๊ฐ๊ฐ {16, 32, 64}๊ฐ์ด๋ค.
sub-sampling์ stride=2์ธ convolution์ ์ํด ์ํ๋๋ค.
์ ๊ฒฝ๋ง์ Global AveragePooling, 10-way FC.layer ๋ฐ softmax๋ก ์ข ๋ฃ๋๋ฉฐ ์ด 6n+2๊ฐ์ ์ ์ธต๋ weight ์ธต์ด ์๋ค.
๋ค์ ํ๋ Architecture๋ฅผ ์์ฝํ ๊ฒ์ด๋ค:
Shortcut Connection ์ฌ์ฉ์, ํ ์์ 3x3 layer(์ด 3n๊ฐ์ shortcut)์ ์ฐ๊ฒฐ๋๋ค.
์ด dataset์์ (option A์ ํฌํจํ)๋ชจ๋ ๊ฒฝ์ฐ์ identity shortcut์ ์ฌ์ฉํ๋ค.
๋ฐ๋ผ์ Residual Model์ Plain Model๊ณผ depth, width. &. parameter์๊ฐ ์ ํํ๊ฒ ๋์ผํ๋ค.
- weight decay=0.0001๊ณผ momentum=0.9์ ์ฌ์ฉํ๊ณ ๊ฐ์ค์น ์ด๊ธฐํ ๋ฐ Batch Normalization์ ์ฌ์ฉํ์ง๋ง Dropout์ ์ฌ์ฉํ์ง ์๋๋ค. (์ด๋, ๊ฐ์ค์น ์ด๊ธฐํ๋ https://chan4im.tistory.com/150๋ฅผ ๋ฐ๋ฅธ๋ค.)
์ด ๋ชจ๋ธ๋ค์ 2๊ฐ์ GPU์์ 128์ mini-batch๋ก ํ๋ จ๋๋ค.
learning rate=0.1์ ํ์ต ์๋๋ก ์์ํด 32000๊ณผ 48000 iteration์์ 10์ผ๋ก ๋๋๋ค.
64000๋ฒ์ iteration์์ 45k/5k๋ก ๋๋ train/validation์ ๊ฐ์ด ๊ฒฐ์ ๋๊ธฐ์ training์ ์ข ๋ฃํ๋ค.
training์ ์ํด [Supervised Net.,https://arxiv.org/abs/1409.5185]์ ์๊ฐ๋ ๊ฐ๋จํ Data Augmentation์ ๋ฐ๋ฅธ๋ค.
๊ฐ ๋ฉด์ 4๊ฐ์ pixel์ด padding๋๋ฉฐ
32x32 crop์ padding๋ ์ด๋ฏธ์ง ํน์ horizontal flip ์ค ๋ฌด์์๋ก ์ํ๋ง๋๋ค.
test๋ฅผ ์ํด ์๋ณธ 32x32 ์ด๋ฏธ์ง์ single view๋ง ํ๊ฐํฉ๋๋ค.
- 20, 32, 44 ๋ฐ 56 ์ธต ์ ๊ฒฝ๋ง์ผ๋ก ์ด์ด์ง๋ n = {3, 5, 7, 9}์ ๋น๊ตํ๋ค.
๊ทธ๋ฆผ 6(์ผ์ชฝ)์ Plain ์ ๊ฒฝ๋ง์ ์๋์ ๋ณด์ฌ์ค๋ค.
๊น์ Plain ์ ๊ฒฝ๋ง์ ๊น์ด๊ฐ ์ฆ๊ฐํ๋ฉด์ ์ด๋ ค์์ ๊ฒช๊ณ , ๋ ๊น์ด ๋ค์ด๊ฐ ๋ ๋ ๋์ train error๋ฅผ ๋ณด์ธ๋ค.
์ด๋ฐ ํ์์ ImageNet(๊ทธ๋ฆผ 4, ์ผ์ชฝ) ๋ฐ MNIST([42] ์ฐธ์กฐ)์์์ ํ์๊ณผ ์ ์ฌํ์ฌ ์ด๋ฌํ ์ต์ ํ์ ์ด๋ ค์์ด ๊ทผ๋ณธ์ ์ธ ๋ฌธ์ ์์ ์์ฌํ๋ค.
๊ทธ๋ฆผ 6(๊ฐ์ด๋ฐ)์ ResNets์ ๋์์ ๋ณด์ฌ์ค๋ค.
๋ํ ImageNet ์ฌ๋ก(๊ทธ๋ฆผ 4, ์ค๋ฅธ์ชฝ)์ ์ ์ฌํ๊ฒ ResNets๋ ์ต์ ํ ์ด๋ ค์์ ๊ทน๋ณตํ๊ณ ๊น์ด๊ฐ ์ฆ๊ฐํ ๋ ์ ํ๋๊ฐ ํฅ์๋จ์ ๋ณด์ฌ์ค๋ค.
- 110์ธต ResNet์ผ๋ก ์ด์ด์ง๋ n = 18์ ์ถ๊ฐ๋ก ํ๊ตฌํ๋ค.
์ด ๊ฒฝ์ฐ, ์ฐ๋ฆฌ๋ ์ด๊ธฐ์ learning rate=0.1์ด "์๋ ด์ ์์ํ๊ธฐ"์ ์ฝ๊ฐ ๋๋ฌด ํฌ๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ๋ค.
๋ฐ๋ผ์ training error๊ฐ 80%๋ฏธ๋ง์ด ๋ ๋๊น์ง ๋ฎ์ ํ์ต๋ฅ ์ธ 0.01๋ก ์ฌ์ training์ ์งํํ ํ(์ฝ 400 iteration) 0.1๋ก ๋ค์training์ ๊ณ์ํ๋ค.
๋๋จธ์ง training schedule์ ์ด์ ๊ณผ ๋์ผํ๋ฉฐ, ์ด 110์ธต ์ ๊ฒฝ๋ง์ ์ ์๋ ด๋๋ค(๊ทธ๋ฆผ 6, ์ค๊ฐ).
FitNet[https://arxiv.org/abs/1412.6550] ๋ฐ Highway[https://arxiv.org/abs/1505.00387](ํ 6)์ ๊ฐ์ ๊น๊ณ ๋ฐ ์์ ์ ๊ฒฝ๋ง๋ณด๋ค parameter์๊ฐ ์ ์ง๋ง, ์ต์ฒจ๋จ ๊ฒฐ๊ณผ(6.43%, ํ 6) ์ค ํ๋๋ฅผ ์ป์๋ค.
• Analysis of Layer Response.
- ๊ทธ๋ฆผ 7์ layer response์ ํ์ค ํธ์ฐจ(std)๋ฅผ ๋ณด์ฌ์ค๋ค.
์ด๋, response๋ BN ์ดํ ๋ฐ ๊ธฐํ ๋น์ ํ์ฑ(ReLU/์ถ๊ฐ) ์ด์ ์ ๊ฐ 3x3 layer์ ์ถ๋ ฅ์ด๋ค.
ResNets์ ๊ฒฝ์ฐ, ์ด ๋ถ์์ ์์ฐจ ํจ์์ response๊ฐ๋๋ฅผ ๋ํ๋ธ๋ค.
- ๊ทธ๋ฆผ 7์ ResNet์ด ์ผ๋ฐ์ ์ธ ์๋ต๋ณด๋ค ์ผ๋ฐ์ ์ผ๋ก ๋ ์์ ์๋ต์ ๊ฐ์ง๊ณ ์์์ ๋ณด์ฌ์ค๋ค.
์ด๋ฐ ๊ฒฐ๊ณผ๋ ์์ฐจ ํจ์๊ฐ ๋น์์ฐจ ํจ์๋ณด๋ค ์ผ๋ฐ์ ์ผ๋ก 0์ ๊ฐ๊น์ธ ์ ์๋ค๋ ๊ธฐ๋ณธ ๊ฐ์ (3.1์ )์ ๋ท๋ฐ์นจํ๋ค.
- ๋ํ ๊ทธ๋ฆผ 7์ ResNet-20, 56 ๋ฐ 110์ ๋น๊ต์์ ์ ์ฆ๋ ๋ฐ์ ๊ฐ์ด ๋ ๊น์ ResNet์ด ์๋ต์ ํฌ๊ธฐ๊ฐ ๋ ์๋ค๋ ๊ฒ์ ์ฃผ๋ชฉํ์.
์ฆ, ๋ ๋ง์ ์ธต์ด ์์ ๋, ResNets์ ๊ฐ๊ฐ์ ์ธต์ ์ ํธ๋ฅผ ๋ ์์ ํ๋ ๊ฒฝํฅ์ด ์๋ค.
• Exploring Over 1000 layers.
- ๊ณต๊ฒฉ์ ์ผ๋ก 1000๊ฐ ์ด์์ ์ธต์ ์๋ ๊น์ ๋ชจ๋ธ์ ํ๊ตฌํ๋ค.
์์์์ฒ๋ผ ํ๋ จ๋ 1202์ธต ์ ๊ฒฝ๋ง์ผ๋ก ์ด์ด์ง๋ n = 200์ ์ค์ ํ๋ค.
์ฐ๋ฆฌ์ ๋ฐฉ๋ฒ์ ์ต์ ํ ์ด๋ ค์์ ๋ณด์ด์ง ์๊ณ , ์ด 103์ธต ์ ๊ฒฝ๋ง์ training error < 0.1%๋ฅผ ๋ฌ์ฑํ ์ ์๋ค(๊ทธ๋ฆผ 6, ์ค๋ฅธ์ชฝ).
test error๋ ์ฌ์ ํ ์๋นํ ์ํธํฉ๋๋ค(7.93%, ํ 6).
- ๊ทธ๋ฌ๋ ๊ทธ๋ฌํ ๊ณต๊ฒฉ์ ์ฌ์ธต๋ชจ๋ธ์ ์ฌ์ ํ ๋ฏธํด๊ฒฐ ๋ฌธ์ ๊ฐ ์๋ค.
์ด 1202์ธต ์ ๊ฒฝ๋ง test ๊ฒฐ๊ณผ๋ ์ฐ๋ฆฌ์ 110์ธต ์ ๊ฒฝ๋ง ๊ฒฐ๊ณผ๋ณด๋ค ๋์์ง๋ง, ๋ ๋ค ๋น์ทํ training error๋ฅผ ๊ฐ๋๋ค.
์ฐ๋ฆฐ ์ด๋ฐ ํ์์ด "Overfitting"๋๋ฌธ์ด๋ผ ์ฃผ์ฅํ๋ค.
- 1202์ธต ์ ๊ฒฝ๋ง์ ์ด ์์ dataset์ ๋ํด์๋ ๋ถํ์ํ๊ฒ ํด ์ ์๋ค(19.4M).
์ด dataset์์ ์ต์์ ๊ฒฐ๊ณผ๋ฅผ ์ป๊ธฐ ์ํด maxout / dropout ๊ฐ์ ๊ฐ๋ ฅํ ์ ๊ทํ๊ฐ ์ ์ฉ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋, ์ต์ ํ์ ์ด๋ ค์์ ์ด์ ์ ๋ง์ถ์ง ์๊ธฐ์ maxout/dropout์ ์ฌ์ฉํ์ง ์๋๋ค.
์ค๊ณ์ ๋ฐ๋ผ ๊น๊ณ ์์ ๊ตฌ์กฐ๋ฅผ ํตํด ์ ๊ทํ๋ฅผ ์ ์ฉํ์ง๋ง ๋ ๊ฐ๋ ฅํ ์ ๊ทํ์ ๊ฒฐํฉํ๋ฉด ๊ฒฐ๊ณผ๊ฐ ๊ฐ์ ๋ ์ ์๋ค.
4.3. Object Detection on PASCAL and MS COCO
- ์ฐ๋ฆฌ์ ๋ฐฉ๋ฒ์ ๋ค๋ฅธ ์ธ์๊ณผ์ ์์๋ ์ผ๋ฐํ ์ฑ๋ฅ์ด ์ข๋ค.
- ํ 7๊ณผ 8์ PASCAL VOC 2007๊ณผ 2012 [5] ๋ฐ COCO์ ๋ํ Object Detection์ ๊ธฐ์ค๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋ค.
- ์ฐ๋ฆฌ๋ Detection๋ฐฉ๋ฒ์ Faster R-CNN์ ์ฌ์ฉํ๋ค.
์ด๋, VGG-16์ ResNet-101๋ก ๋์ฒดํ๋ ๊ฐ์ ์ฌํญ์ ๊ด์ฌ์ ๋๊ณ ์ฃผ๋ชฉํ๋ค.
๋ ๋ชจ๋ธ์ ์ฌ์ฉํ๋ Detection์ ๊ตฌํ(๋ถ๋ก ์ฐธ์กฐ)์ ๋์ผ., ๊ฒฐ๊ณผ๋ฌผ์ ๋ ๋์ ์ ๊ฒฝ๋ง์ ๊ท์๋๋ค.
- ๊ฐ์ฅ ์ฃผ๋ชฉํ ๋งํ ๊ฒ์ ๊น๋ค๋ก์ด COCO dataset์์ COCO์ ํ์ค metric์ธ (mAP@[.5, .95])์ด 6.0% ์ฆ๊ฐํ์ฌ 28%์ ์๋์ ๊ฐ์ ์ ์ป์๋๋ฐ, ์ด๋ ์ค์ง learned representation ๋๋ฌธ์ด๋ค.
- ์ฌ์ธต ResNet์ ๊ธฐ๋ฐ์ผ๋ก ILSVRC ๋ฐ COCO 2015 ๋ํ์์ ์ฌ๋ฌ ํธ๋์์ 1์๋ฅผ ์ฐจ์งํ์ต๋๋ค:
ImageNet Detection, ImageNet Localization, COCO Detection ๋ฐ COCO Segmentation.
- ์์ธํ ๋ด์ฉ์ ๋ถ๋ก์ ๊ธฐ์ฌ.
๐ถ ๋ถ๋ก (Appendix)
A. Object Detection. Baselines
• PASCAL VOC
-
• MS COCO
-
B. Object Detection Improvements
• MS COCO
-
• PASCAL VOC
-
• ImageNet Detection
-
C. ImageNet Localization
-
-
๐ง ๋ ผ๋ฌธ ๊ฐ์_์ค์๊ฐ๋ ํต์ฌ ์์ฝ
"Deep Residual Learning for Image Recognition"
์ฌ์ธต ์ ๊ฒฝ๋ง์์ gradient vanishing problem์ ํด๊ฒฐํ๋ ์ฌ์ธต ConvNet์ ์ํ ์๋ก์ด ๊ตฌ์กฐ๋ฅผ ์๊ฐํ๋ ์ฐ๊ตฌ ๋ ผ๋ฌธ์ผ๋ก ์ด ๋ ผ๋ฌธ์ ResNet์ด๋ผ๋ ์ฌ์ธต ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง์ ์ํ ์๋ก์ด ์ํคํ ์ฒ๋ฅผ ์ ์ํ๋ค.
[ํต์ฌ ๊ฐ๋ ]
1. Problem
- ์ด ๋ ผ๋ฌธ์ ๋คํธ์ํฌ ๊น์ด๊ฐ ์ฆ๊ฐํจ์ ๋ฐ๋ผ ์ ๊ฒฝ๋ง์ ์ ํ๋๊ฐ ํฌํ๋๊ฑฐ๋ ์ ํ๋ ์ ์๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ก ์ธํด ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์ ํ๋ จํ๋ ๊ฒ์ด ์ด๋ ต๋ค๋ ์ ์ ๊ฐ์กฐํ๋ค.
[Degradation Problem]
- ๋ ๊น์ ์ ๊ฒฝ๋ง์ด ์๋ ด์ ์์ํ ๋, ์ ๊ฒฝ๋ง๊น์ด๊ฐ ์ฆ๊ฐํ๋ฉด ์ ํ๋๊ฐ ํฌํ์ํ๊ฐ ๋๊ณ , ๊ทธ ๋ค์์ ๋น ๋ฅด๊ฒ ์ ํ๋๋ค.
- ์ด ๋ฌธ์ ์ ์์์น ๋ชปํ ๋ฌธ์ ์ ์ ๋ฐ๋ก overfitting์ด ์ด ๋ฌธ์ ๋ฅผ ์ผ๊ธฐํ์ง ์๋๋ค๋ ์ ์ธ๋ฐ, ์ ์ ํ ์ฌ์ธต๋ชจ๋ธ์ ๋ ๋ง์ ์ธต์ ์ถ๊ฐํ๋ฉด ๋ ๋์ training error๊ฐ ๋ฐ์ํ๋ค.
2. Solution
์ด ๋ ผ๋ฌธ์ ๊ธฐ๋ณธ ๋งตํ์ ์ง์ ํ์ตํ๋ ๋์ ์์ฐจ ํจ์(residual functions)๋ฅผ ํ์ตํ๋ ์์ฐจ ํ์ต(residual learning)์ด๋ผ๋ ์๋ก์ด ์ ๊ทผ ๋ฐฉ์์ ์ ์ํ๋ค.
์ด ์ ๊ทผ ๋ฐฉ์์ ์ ๊ฒฝ๋ง์ด identity mapping์ ํ์ตํ ์ ์๋๋ก skip connection์ ์ฌ์ฉํ๋ residual block์ ํตํด ๊ตฌํ๋๋ค.
cf. [Identity. &. Identity Mapping]
ResNet๊ตฌ์กฐ์์ Residual Block์ identity(= identity mapping) ์ ๋งํ๋ค.
- ์ ๋ ฅ์์ ์ถ๋ ฅ์ผ๋ก์ ์ ์ฒด ๋งคํ์ ํ์ตํ๋ ๋์ ์ ๋ ฅ์ ์ฝ๊ฐ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ํ์ต์ ์งํํ๋ค.
- Identity Block์ Skip Connection์ด Identity mapping์ ์ํํ๋ residual block์ ์ผ์ข ์ด๋ค.
- ์ฆ, ๋ณํ์์ด ์ ๋ ฅ์ด block์ ์ถ๋ ฅ์ ์ง์ ์ถ๊ฐ๋๋ ๊ฒ์ผ๋ก Identity Block์ input์ ์ฐจ์์ ์ ์ง์ํจ๋ค.
- Identity Block์ Residual Networks์์ ์ ๋ ฅ์ ์ฐจ์์ ์ ์งํ๋ฉด์ ๋น์ ํ์ฑ์ ๋์ ํ๋ ๋ฐ ์ฌ์ฉ๋๋ค.
์ด๋ ์ ๊ฒฝ๋ง์ด ๋ ๋ณต์กํ ํน์ง์ ํ์ตํ๋๋ก ๋๊ณ ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์์ ๋ฐ์ํ ์ ์๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๋ค.
[Shortcut Connection]
- skip connection์ด๋ผ๊ณ ๋ ๋ถ๋ฆฐ๋ค.
- ์ธต์ ์ ๋ ฅ๊ณผ ์ดํ ์ธต์ ์ถ๋ ฅ ์ฌ์ด์ ์ถ๊ฐํจ์ผ๋ก์จ ResNet์ ์ธต์ ์ฐํํ๊ณ ์ ๋ ฅ์ ์ดํ ๊ณ์ธต์ผ๋ก "์ง์ ์ ๋ฌ"ํ ์ ์์ต๋๋ค.
- ์ด๋ฅผ ํตํด ๋คํธ์ํฌ๋ ์์ฐจ ํจ์๋ฅผ ํ์ตํ ์ ์์ผ๋ฏ๋ก ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์ ํ๋ จ์ด ๊ฐ๋ฅํ๋ค.
- ๋ํ ์ฌ๋ฌ conv.layer์ ๊ธฐ์กด์ ์ ๋ ฅ์ ๋ธ๋ก์ ์ถ๋ ฅ์ ์ถ๊ฐํ์ฌ ์ ๊ฒฝ๋ง์ด identity mapping์ ํ์ตํ ์ ์๊ฒ ํด์ค๋ค.
์ด๋ฅผ ํตํด training์ ์ ํ์ฑ ๋ฐ ์๋๋ฅผ ํฅ์์ํฌ ์ ์๋ค.
[Residual learning]
์ด ๋ ผ๋ฌธ์ ์ํ๋ ๊ธฐ๋ณธ ๋งคํ์ ์ง์ ํ์ตํ๋ ๋์ ๋ ์ด์ด ์ ๋ ฅ์ ์ฐธ์กฐํ์ฌ ์์ฐจ ํจ์๋ฅผ ํ์ตํ๋ ์์ฐจ ํ์ต์ ๊ฐ๋ ์ ๋์ ํ๋ค.
์ด ์ ๊ทผ ๋ฐฉ์์ gradient vanishing์ ํผํ ์ ์์ด์ ์ฌ์ธต์ ๊ฒฝ๋งํ๋ จ์ ํ์ฉํ๋ค.
[Residual blocks]
ResNet์ ์ฌ๋ฌ conv.layer์ ๊ธฐ์กด์ ์ ๋ ฅ์ ๋ธ๋ก์ ์ถ๋ ฅ์ ์ถ๊ฐํ๋ skip connection์ผ๋ก ๊ตฌ์ฑ๋ ์์ฐจ ๋ธ๋ก(residual block)์ ์ฌ์ฉํ์ฌ ๊ตฌ์ถ๋๋ค.
์ด๋ฅผ ํตํด ์ ๊ฒฝ๋ง์ ์์ฐจ ํจ์๋ฅผ ํ์ตํ ์ ์์ผ๋ฏ๋ก ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์ ํ๋ จ์ด ๊ฐ๋ฅํฉ๋๋ค.
3. Bottleneck Architecture
- ResNet ๊ตฌ์กฐ๋ ๋งค์ฐ ๊น์ ์์ฌ ๋ธ๋ก ์คํ(์: 152๊ฐ ์ธต)์ผ๋ก ๊ตฌ์ฑ๋๋ฉฐ ์ ํ๋๋ฅผ ์ ์งํ๋ฉด์ ๊ณ์ฐ์ ์ค์ด๋ bottleneck์ค๊ณ๋ฅผ ์ฌ์ฉํ๋ค.
- ๋ ผ๋ฌธ์์๋ ์ธ ๊ฐ์ ์ธต์ผ๋ก ๊ตฌ์ฑ๋ ๋ณ๋ชฉ ๊ตฌ์กฐ๋ฅผ ์๊ฐํ์ผ๋ฉฐ ์ด ์ธ ์ธต์ Add๋ฅผ ์ด์ฉํด ๋ํ๋ค.
โ channel์๋ฅผ ์ค์ด๊ธฐ ์ํ 1x1 conv.layer
โก feature ํ์ต์ ์ํ 3x3 conv.layer
โข channel์๋ฅผ ๋ค์ ๋๋ฆฌ๊ธฐ ์ํ ๋ ๋ค๋ฅธ 1x1 conv.layer
cf. Pre-activation
์ด ๋ ผ๋ฌธ์ ์์ฉํ ResNetV2๋ ์ฌ์ ํ์ฑํ(pre-activation)์ ๋ํ ๊ฐ๋ ์ ๋์ ํ๋ค.
- BatchNormalization ๋ฐ ReLU๋ฅผ ๊ฐ conv.layer ์ดํ๊ฐ ์๋ ์ด์ ์ ์ ์ฉํ๋ค.
- ์ด๋ฅผ ํตํด training performance๋ฅผ ๊ฐ์ ํ๊ณ ๋งค์ฐ ๊น์ ์ ๊ฒฝ๋ง์์์ overfitting์ ์ค์ฌ์ฃผ์๋ค.
5. Results
- ResNet์ ํจ์ฌ ๋ ๊น์ ์ ๊ฒฝ๋ง์ผ๋ก ์ด์ ๋ฐฉ๋ฒ์ ๋ฅ๊ฐํด ImageNet dataset์ classification์์ ์์ ์ต์ฒจ๋จ ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ค.
- ๋ํ ResNet์ Objcet Detection ๋ฐ Semantic Segmentation๊ณผ ๊ฐ์ ๋ค์ํ dataset ๋ฐ ์์ ์์ ๋ ๋์ ์ผ๋ฐํ ์ฑ๋ฅ(generalization performance)์ ๋ฌ์ฑํ๋ค.
- ์ด๋, ์ด๊ธฐ์ learning rate=0.1์ด "์๋ ด์ ์์ํ๊ธฐ"์ ์ฝ๊ฐ ๋๋ฌด ํฌ๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์๊ธฐ์
training error๊ฐ 80%๋ฏธ๋ง์ด ๋ ๋๊น์ง ๋ฎ์ ํ์ต๋ฅ ์ธ 0.01๋ก ์ฌ์ training์ ์งํํ ํ(์ฝ 400 iteration) 0.1๋ก ๋ค์training์ ์งํํ์๋ค.
์ ๋ฐ์ ์ผ๋ก ResNet ๋ ผ๋ฌธ์ ๋ฅ๋ฌ๋๊ณผ ํนํ ์ปดํจํฐ ๋น์ ์์ญ์์ ๋ง์ ์ต์ ๋ชจ๋ธ์ ํต์ฌ ๊ตฌ์ฑ ์์๊ฐ ๋ ๋งค์ฐ ๊น์ ConvNet์ ํ๋ จ์ ๊ฐ๋ฅํ๊ฒ ํ๋ ์๋ก์ด ๊ตฌ์กฐ๋ฅผ ์ ์ํ๋ ํ์ ์ ์ธ ์ฑ๊ณผ๋ฅผ ์ด๋ฃฉํ์๋ค.
๐ง ๋ ผ๋ฌธ์ ์ฝ๊ณ Architecture ์์ฑ (with tensorflow)
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, ReLU, BatchNormalization, ZeroPadding2D, Activation, Add
def conv_bn_relu(x, filters, kernel_size, strides=1, padding='same'):
x = Conv2D(filters, kernel_size, strides=strides, padding=padding)(x)
x = BatchNormalization()(x)
x = ReLU()(x)
return x
def identity_block(x, filters):
shortcut = x
x = conv_bn_relu(x, filters, 1)
x = conv_bn_relu(x, filters, 3)
x = conv_bn_relu(x, filters * 4, 1)
x = Add()([x, shortcut])
x = ReLU()(x)
return x
def projection_block(x, filters, strides):
shortcut = conv_bn_relu(x, filters * 4, 1, strides)
x = conv_bn_relu(x, filters, 1, strides)
x = conv_bn_relu(x, filters, 3)
x = conv_bn_relu(x, filters * 4, 1)
x = Add()([x, shortcut])
x = ReLU()(x)
return x
def resnet(input_shape, num_classes, num_layers):
if num_layers == 50:
num_blocks = [3, 4, 6, 3]
elif num_layers == 101:
num_blocks = [3, 4, 23, 3]
elif num_layers == 152:
num_blocks = [3, 8, 36, 3]
conv2_x, conv3_x, conv4_x, conv5_x = num_blocks
inputs = Input(shape=input_shape)
x = ZeroPadding2D(padding=(3, 3))(inputs)
x = conv_bn_relu(x, 64, 7, strides=2)
x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
x = projection_block(x, 64, strides=1)
for _ in range(conv2_x - 1):
x = identity_block(x, 64)
x = projection_block(x, 128, strides=2)
for _ in range(conv3_x - 1):
x = identity_block(x, 128)
x = projection_block(x, 256, strides=2)
for _ in range(conv4_x - 1):
x = identity_block(x, 256)
x = projection_block(x, 512, strides=2)
for _ in range(conv5_x - 1):
x = identity_block(x, 512)
x = GlobalAveragePooling2D()(x)
outputs = Dense(num_classes, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)
return model
model = resnet(input_shape=(224,224,3), num_classes=200, num_layers=152)
model.summary()
Model: "ResNet152"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_24 (InputLayer) [(None, 224, 224, 3 0 []
)]
zero_padding2d_10 (ZeroPadding (None, 230, 230, 3) 0 ['input_24[0][0]']
2D)
conv2d_4159 (Conv2D) (None, 115, 115, 64 9472 ['zero_padding2d_10[0][0]']
)
batch_normalization_4158 (Batc (None, 115, 115, 64 256 ['conv2d_4159[0][0]']
hNormalization) )
re_lu_540 (ReLU) (None, 115, 115, 64 0 ['batch_normalization_4158[0][0]'
) ]
max_pooling2d_23 (MaxPooling2D (None, 58, 58, 64) 0 ['re_lu_540[0][0]']
)
conv2d_4161 (Conv2D) (None, 58, 58, 64) 4160 ['max_pooling2d_23[0][0]']
batch_normalization_4160 (Batc (None, 58, 58, 64) 256 ['conv2d_4161[0][0]']
hNormalization)
re_lu_542 (ReLU) (None, 58, 58, 64) 0 ['batch_normalization_4160[0][0]'
]
conv2d_4162 (Conv2D) (None, 58, 58, 64) 36928 ['re_lu_542[0][0]']
batch_normalization_4161 (Batc (None, 58, 58, 64) 256 ['conv2d_4162[0][0]']
hNormalization)
re_lu_543 (ReLU) (None, 58, 58, 64) 0 ['batch_normalization_4161[0][0]'
]
conv2d_4163 (Conv2D) (None, 58, 58, 256) 16640 ['re_lu_543[0][0]']
conv2d_4160 (Conv2D) (None, 58, 58, 256) 16640 ['max_pooling2d_23[0][0]']
batch_normalization_4162 (Batc (None, 58, 58, 256) 1024 ['conv2d_4163[0][0]']
hNormalization)
batch_normalization_4159 (Batc (None, 58, 58, 256) 1024 ['conv2d_4160[0][0]']
hNormalization)
re_lu_544 (ReLU) (None, 58, 58, 256) 0 ['batch_normalization_4162[0][0]'
]
re_lu_541 (ReLU) (None, 58, 58, 256) 0 ['batch_normalization_4159[0][0]'
]
add_1150 (Add) (None, 58, 58, 256) 0 ['re_lu_544[0][0]',
're_lu_541[0][0]']
re_lu_545 (ReLU) (None, 58, 58, 256) 0 ['add_1150[0][0]']
conv2d_4164 (Conv2D) (None, 58, 58, 64) 16448 ['re_lu_545[0][0]']
...
conv2d_4312 (Conv2D) (None, 8, 8, 512) 2359808 ['re_lu_741[0][0]']
batch_normalization_4311 (Batc (None, 8, 8, 512) 2048 ['conv2d_4312[0][0]']
hNormalization)
re_lu_742 (ReLU) (None, 8, 8, 512) 0 ['batch_normalization_4311[0][0]'
]
conv2d_4313 (Conv2D) (None, 8, 8, 2048) 1050624 ['re_lu_742[0][0]']
batch_normalization_4312 (Batc (None, 8, 8, 2048) 8192 ['conv2d_4313[0][0]']
hNormalization)
re_lu_743 (ReLU) (None, 8, 8, 2048) 0 ['batch_normalization_4312[0][0]'
]
add_1199 (Add) (None, 8, 8, 2048) 0 ['re_lu_743[0][0]',
're_lu_740[0][0]']
re_lu_744 (ReLU) (None, 8, 8, 2048) 0 ['add_1199[0][0]']
global_average_pooling2d_6 (Gl (None, 2048) 0 ['re_lu_744[0][0]']
obalAveragePooling2D)
dense_6 (Dense) (None, 200) 409800 ['global_average_pooling2d_6[0][0
]']
==================================================================================================
Total params: 58,780,744
Trainable params: 58,629,320
Non-trainable params: 151,424
__________________________________________________________________________________________________