batchnorm_layer
Last updated
Was this helpful?
Last updated
Was this helpful?
Paper :
Gradient Vanishing, Gradient Exploding ๋ฌธ์ ์
internal covariate shift
: weight์ ๋ณํ๊ฐ ์ค์ฒฉ๋์ด ๊ฐ์ค๋๋ ํฌ๊ธฐ ๋ณํ๊ฐ ํฌ๋ค๋ ๋ฌธ์ ์
careful initialization
: Difficult
small learning rate
: Slow
์์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํ ํ๋์ ๋ฐฉ๋ฒ ์ ๋๋ค.
๋ณดํต internal covariate shift
๋ฅผ ์ค์ด๊ธฐ ์ํ ๋ํ์ ์ธ ๋ฐฉ๋ฒ์ ๊ฐ layer์ ์
๋ ฅ์ whitening์ ํ๋ ๊ฒ ์
๋๋ค. ์ฌ๊ธฐ์์ whitening ์ด๋ ํ๊ท 0 ๋ถ์ฐ 1๋ก ๋ฐ๊พธ์ด ์ฃผ๋ ๊ฒ(์ ๊ทํ)์ ๋งํฉ๋๋ค. ํ์ง๋ง ์ด๋ฌํ ์ฐ์ฐ์ ๋ฌธ์ ๊ฐ ์์ต๋๋ค.
bias์ ์ํฅ์ด ๋ฌด์ ๋ฉ๋๋ค.
๋ง์ฝ ์ฐ์ฐ์ ํ ๋ค์ ์ ๊ทํํ๊ธฐ ์ํด์ ํ๊ท ์ ๋นผ์ฃผ๋ ๊ฒฝ์ฐ bias ์ ์ํฅ์ด ์ฌ๋ผ์ง๊ฒ ๋ฉ๋๋ค.(bias๋ ๊ณ ์ ์ค์นผ๋ผ ๊ฐ์ด๊ธฐ ๋๋ฌธ์ ํ๊ท ์ ๊ตฌํด๋ ๊ฐ์ ๊ฐ์ด ๋์ต๋๋ค.)
๋น์ ํ์ฑ์ด ์์ด์ง ์ ์์ต๋๋ค.
๋ง์ฝ sigmoid๋ฅผ ํต๊ณผํ๋ ๊ฒฝ์ฐ ๋๋ถ๋ถ์ ์ ๋ ฅ๊ฐ์ sigmoid์ ์ค๊ฐ ๋ถ๋ถ์ ์ํฉ๋๋ค. sigmoid์์ ์ค๊ฐ์ ์ ํ์ด๊ธฐ ๋๋ฌธ์ ๋น์ ํ์ฑ์ด ์ฌ๋ผ์ง ์ ์๋ค๋ ๊ฒ์ ๋๋ค.
์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด batch Normalization
์ด ๋์์ต๋๋ค.
: mini-batch์ ํฌ๊ธฐ
: mean
: std
: scale
: shifts
๋ ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ ์ ๋๋ค. ์ด๊ฒ์ด ๋น์ ํ์ฑ์ ์ํ์ํค๊ธฐ ์ํ ํ๋ผ๋ฏธํฐ ์ ๋๋ค.
๋ฐฐ์น ์ ๊ทํ๋ ํ์ต ํ๋ ๊ฒฝ์ฐ์๋ ๋ฏธ๋ ๋ฐฐ์น์ ํ๊ท ๊ณผ ๋ถ์ฐ์ ๊ตฌํ ์ ์์ง๋ง ์ถ๋ก ์ ํ๋ ๊ฒฝ์ฐ๋ ๋ฏธ๋ ๋ฐฐ์น๊ฐ ์๊ธฐ ๋๋ฌธ์ ํ์ต ํ๋ ๋์ ๊ณ์ฐ ๋ ์ด๋ ํ๊ท
์ ์ฌ์ฉ ํฉ๋๋ค.
์ด๋ ํ๊ท : ๊ฐ ๋ฏธ๋ ๋ฐฐ์น ํ๊ท ์ ํ๊ท
์ด๋ ๋ถ์ฐ : ๊ฐ ๋ฏธ๋ ๋ฐฐ์น ๋ถ์ฐ์ ํ๊ท * m/(m-1) [Besselโs Correction]
CNN์ ๊ฒฝ์ฐ bias์ ์ญํ ์ ๊ฐ ๋์ ํ๊ธฐ ๋๋ฌธ์ bias๋ฅผ ์ ๊ฑฐํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ปจ๋ณผ๋ฃจ์ ์ฐ์ฐ์ ํตํด ์ถ๋ ฅ๋๋ ํน์ง ๋งต์ผ๋ก ๊ฐ ์ฑ๋๋ง๋ค ํ๊ท ๊ณผ ๋ถ์ฐ์ ๊ณ์ฐํ๊ณ ๋ฅผ ๋ง๋ญ๋๋ค. ์ฆ, ์ฑ๋์ ๊ฐ์ ๋งํผ ๊ฐ ์๊ฒจ๋ฉ๋๋ค.
internal covariate shift ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ค.
learning rate๋ฅผ ํฌ๊ฒ ํด๋ ๋๋ค.
์ ์คํ๊ฒ ์ด๊ธฐ๊ฐ์ ์ ํ ํ์๊ฐ ์๋ค.
dropout์ ๋์ฒด ํ ์ ์๋ค.
ํจ์ ์ด๋ฆ: forward_batchnorm_layer
์ ๋ ฅ:
l: layer ๊ตฌ์กฐ์ฒด
net: network ๊ตฌ์กฐ์ฒด
๋์:
Batch normalization ๋ ์ด์ด๋ฅผ ์ํํฉ๋๋ค.
์ค๋ช :
์ด ํจ์๋ ์ ๋ ฅ ๋ฐ์ดํฐ์ ๋ํด Batch normalization์ ์ํํฉ๋๋ค. ์ ๋ ฅ์ผ๋ก๋ layer ๊ตฌ์กฐ์ฒด์ network ๊ตฌ์กฐ์ฒด๊ฐ ํ์ํฉ๋๋ค.
ํจ์๋ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ๋ณต์ฌํ ํ, ํ์ต ๋ชจ๋์ ์ถ๋ก ๋ชจ๋๋ฅผ ๊ตฌ๋ถํ์ฌ ์ฒ๋ฆฌํฉ๋๋ค. ํ์ต ๋ชจ๋์์๋ ํ์ฌ ๋ฐฐ์น์ ๋ํ ํ๊ท ๊ณผ ๋ถ์ฐ์ ๊ณ์ฐํ๊ณ , ์ด๋ฅผ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํํฉ๋๋ค. ๊ทธ ๋ค์, ์ ๊ทํ๋ ๋ฐ์ดํฐ์ ์ค์ผ์ผ๊ณผ ๋ฐ์ด์ด์ค๋ฅผ ์ ์ฉํฉ๋๋ค. ์ค์ผ์ผ๊ณผ ๋ฐ์ด์ด์ค๋ layer ๊ตฌ์กฐ์ฒด ๋ด์ scales ๋ฐ biases ํ๋์์ ๊ฐ์ ธ์ต๋๋ค.
๋ฐ๋ฉด, ์ถ๋ก ๋ชจ๋์์๋ ๋ฐฐ์น์ ๋ํ ์ด๋ ํ๊ท ๊ณผ ์ด๋ ๋ถ์ฐ์ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํํฉ๋๋ค. ์ด๋ ํ๊ท ๊ณผ ์ด๋ ๋ถ์ฐ์ layer ๊ตฌ์กฐ์ฒด ๋ด์ rolling_mean ๋ฐ rolling_variance ํ๋์์ ๊ฐ์ ธ์ต๋๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก, Batch normalization์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํํ์ฌ ๋ชจ๋ธ ํ์ต์ ์์ ํํ๊ณ , ํจ์ฌ ๋น ๋ฅด๊ฒ ์๋ ดํ๋๋ก ๋์์ค๋๋ค.
ํจ์ ์ด๋ฆ: backward_batchnorm_layer
์ ๋ ฅ:
l: layer ๊ตฌ์กฐ์ฒด
net: network ๊ตฌ์กฐ์ฒด
๋์:
net.train์ด false์ด๋ฉด ํ์ฌ ์ธต์ rolling mean๊ณผ rolling variance๋ก l.mean๊ณผ l.variance๋ฅผ ๋์ฒดํ๋ค.
bias ์ ๋ฐ์ดํธ์ scale ์ ๋ฐ์ดํธ๋ฅผ ์ํํ๋ค.
delta๋ฅผ scale๋ก ๊ณฑํ๊ณ bias๋ฅผ ๋ํด์ค๋ค.
delta์ ๋ํ mean delta๋ฅผ ๊ณ์ฐํ๋ค.
delta์ ๋ํ variance delta๋ฅผ ๊ณ์ฐํ๋ค.
delta๋ฅผ ์ ๊ทํํ๋ค.
๋ง์ฝ l.type์ด BATCHNORM์ด๋ฉด delta๋ฅผ net.delta๋ก ๋ณต์ฌํ๋ค.
์ค๋ช :
์ด ํจ์๋ ๋ฐฐ์น ์ ๊ทํ ์ธต์ ์ญ์ ํ(backpropagation)๋ฅผ ์ํํ๋ค. ๋ฐฐ์น ์ ๊ทํ ์ธต์ ์ญ์ ํ๋ ์์ ํ(forward propagation)์๋ ๋ค๋ฅด๊ฒ ์ฌ๋ฌ ๋จ๊ณ๋ก ๊ตฌ์ฑ๋์ด ์์ด ๋ณต์กํ๋ค. ์ด ํจ์๋ ์ด๋ฌํ ๋จ๊ณ๋ค์ ์ํํ์ฌ ์ญ์ ํ๋ฅผ ๊ตฌํํ๋ค.
์ฐ์ , net.train์ด false์ธ ๊ฒฝ์ฐ ํ์ฌ ์ธต์ rolling mean๊ณผ rolling variance๋ก l.mean๊ณผ l.variance๋ฅผ ๋์ฒดํ๋ค. rolling mean๊ณผ rolling variance๋ ํ์ฌ mini-batch ์ด์ ์ ๋ชจ๋ ๋ฐ์ดํฐ์ ์์ ๊ณ์ฐ๋ ํ๊ท ๊ณผ ๋ถ์ฐ์ ์ ์ฅํ๊ณ ์๋ค. ๋ฐ๋ผ์ ์ด์ ๋ฐ์ดํฐ์ ์ ํต๊ณ๋์ ์ฌ์ฉํ์ฌ ํ์ฌ mini-batch์ ์ ๊ทํ๋ฅผ ์ํํ๋ค.
๋ค์์ผ๋ก, bias ์ ๋ฐ์ดํธ์ scale ์ ๋ฐ์ดํธ๋ฅผ ์ํํ๋ค. ์ด์ ์ธต์์ ์ ๋ฐ์ดํธํ bias์ scale์ ์ฌ์ฉํ์ฌ ํ์ฌ ์ธต์ bias์ scale์ ์ ๋ฐ์ดํธํ๋ค.
๊ทธ ํ, delta๋ฅผ scale๋ก ๊ณฑํ๊ณ bias๋ฅผ ๋ํด์ค๋ค. ์ด๋ ์์ ํ์์ delta๋ฅผ ์ ๊ทํํ๊ธฐ ์ ์ ์ํํ scale๊ณผ bias์ ์ฐ์ฐ์ ์ญ์ ํํ๋ ๊ฒ์ด๋ค.
๊ทธ ๋ค์, delta์ ๋ํ mean delta๋ฅผ ๊ณ์ฐํ๋ค. ์ด๋ ์์ ํ์์ ์ ๊ทํ๋ ์ ๋ ฅ ๊ฐ์ ๋ํ delta๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ์ฌ์ฉํ mean์ ์ญ์ ํํ๋ ๊ฒ์ด๋ค.
delta์ ๋ํ variance delta๋ฅผ ๊ณ์ฐํ ํ, ์ด๋ฅผ ์ฌ์ฉํ์ฌ delta๋ฅผ ์ ๊ทํํ๋ค. ์ด๋ ์์ ํ์์ ์ ๊ทํ๋ ์ ๋ ฅ ๊ฐ์ ๋ํ delta๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ์ฌ์ฉํ variance๋ฅผ ์ญ์ ํํ๋ ๊ฒ์ด๋ค.
๋ง์ง๋ง์ผ๋ก, ๋ง์ฝ l.type์ด BATCHNORM์ธ ๊ฒฝ์ฐ delta๋ฅผ net.delta๋ก ๋ณต์ฌํ๋ค. ์ด๋ ์ด์ ์ธต์ delta์ ๋ํ ์ญ์ ํ๋ฅผ ์ํ ์์ ์ด๋ค.
ํจ์ ์ด๋ฆ: make_batchnorm_layer
์ ๋ ฅ:
batch: ๋ฐฐ์น ํฌ๊ธฐ
w: ์ด๋ฏธ์ง ๋๋น
h: ์ด๋ฏธ์ง ๋์ด
c: ์ฑ๋ ์
๋์:
๋ฐฐ์น ์ ๊ทํ ๋ ์ด์ด๋ฅผ ์์ฑํ๊ณ ์ด๊ธฐํํฉ๋๋ค.
๋ฐฐ์น ์ ๊ทํ์ ์ค์ผ์ผ๊ณผ ๋ฐ์ด์ด์ค ๊ฐ์ 1๊ณผ 0์ผ๋ก ์ด๊ธฐํํ๊ณ , ํ๊ท ๊ณผ ๋ถ์ฐ, ๊ทธ๋ฆฌ๊ณ ๋กค๋ง ํ๊ท ๊ณผ ๋กค๋ง ๋ถ์ฐ ๊ฐ์ 0์ผ๋ก ์ด๊ธฐํํฉ๋๋ค.
์ ๋ ฅ๊ณผ ์ถ๋ ฅ์ ํฌ๊ธฐ๋ฅผ ์ค์ ํ๊ณ , forward_batchnorm_layer์ backward_batchnorm_layer ํจ์๋ฅผ ์ค์ ํฉ๋๋ค.
์ค๋ช :
์ด ํจ์๋ ๋ฐฐ์น ์ ๊ทํ ๋ ์ด์ด๋ฅผ ์์ฑํ๊ณ ์ด๊ธฐํํฉ๋๋ค.
๋ฐฐ์น ์ ๊ทํ ๋ ์ด์ด๋ ์ธ๊ณต ์ ๊ฒฝ๋ง์์ ์ ๋ ฅ๊ฐ์ ์ ๊ทํํ๋ ๋ ์ด์ด๋ก, ์ ๋ ฅ๊ฐ์ ๋ถํฌ๋ฅผ ์์ ํํ์ฌ ํ์ต์ ๋์ฑ ์์ ์ ์ผ๋ก ๋ง๋ญ๋๋ค. ์ด ํจ์์์๋ ๋ฐฐ์น ์ ๊ทํ ๋ ์ด์ด๋ฅผ ์์ฑํ๊ณ ํ์ํ ๋ณ์๋ค์ ์ด๊ธฐํํฉ๋๋ค.
์ ๋ ฅ๊ฐ๊ณผ ์ถ๋ ฅ๊ฐ์ ํฌ๊ธฐ๋ ์ธ์๋ก ๋ฐ์ ๊ฐ์ ๋ฐ๋ผ ์ค์ ํ๋ฉฐ, ์ค์ผ์ผ, ๋ฐ์ด์ด์ค, ํ๊ท , ๋ถ์ฐ ๋ฑ์ ๋ณ์๋ ์ด๊ธฐํํฉ๋๋ค. ์ด ํจ์์์๋ ์ด๊ธฐ ์ค์ผ์ผ ๊ฐ์ 1๋ก, ๋ฐ์ด์ด์ค ๊ฐ์ 0์ผ๋ก ์ค์ ํฉ๋๋ค.
ํ๊ท ๊ณผ ๋ถ์ฐ, ๋กค๋ง ํ๊ท ๊ณผ ๋กค๋ง ๋ถ์ฐ์ ๋ชจ๋ 0์ผ๋ก ์ด๊ธฐํ๋ฉ๋๋ค.
์ด ํจ์์์๋ forward_batchnorm_layer์ backward_batchnorm_layer ํจ์๋ฅผ ์ค์ ํ๋ฉฐ, ์ด ํจ์๋ฅผ ํตํด ์์ฑ๋ ๋ฐฐ์น ์ ๊ทํ ๋ ์ด์ด๋ ์ธ๊ณต ์ ๊ฒฝ๋ง ๋ชจ๋ธ์ ๊ตฌ์ฑ ์์๋ก ํ์ฉ๋ฉ๋๋ค.
ํจ์ ์ด๋ฆ: backward_scale_cpu
์ ๋ ฅ:
x_norm: normalization๋ ์ ๋ ฅ๊ฐ์ ๊ฐ๋ฆฌํค๋ ํฌ์ธํฐ(float ๋ฐฐ์ด)
delta: ์ถ๋ ฅ๊ฐ์ ๋ํ ์์ค์ ๋ฏธ๋ถ๊ฐ์ ๊ฐ๋ฆฌํค๋ ํฌ์ธํฐ(float ๋ฐฐ์ด)
batch: ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ(int)
n: ํํฐ ์(int)
size: ํํฐ ํฌ๊ธฐ(int)
scale_updates: ์ค์ผ์ผ ๋งค๊ฐ๋ณ์์ ์ ๋ฐ์ดํธ ๊ฐ์ ์ ์ฅํ ํฌ์ธํฐ(float ๋ฐฐ์ด)
๋์:
์ ๋ ฅ๊ฐ์ normalizationํ ํ, ์ถ๋ ฅ๊ฐ์ ๋ํ ์์ค์ ๋ฏธ๋ถ๊ฐ๊ณผ ๊ณฑํ ๊ฒฐ๊ณผ๋ฅผ ๋ฏธ๋๋ฐฐ์น์ ํํฐ, ํฌ๊ธฐ๋ณ๋ก ํฉํ์ฌ ์ค์ผ์ผ ๋งค๊ฐ๋ณ์์ ์ ๋ฐ์ดํธ ๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
์ค๋ช :
Convolutional Neural Network์์ Batch Normalization ๊ณ์ธต์์ ์ฌ์ฉ๋๋ ํจ์ ์ค ํ๋๋ก, ์ค์ผ์ผ ๋งค๊ฐ๋ณ์์ ์ ๋ฐ์ดํธ ๊ฐ์ ๊ณ์ฐํ๋ ํจ์์ ๋๋ค.
์ค์ผ์ผ ๋งค๊ฐ๋ณ์๋ ์ ๊ทํ๋ ์ ๋ ฅ๊ฐ์ ๋ํด ๊ณฑํด์ง๋ ๋งค๊ฐ๋ณ์๋ก, ์ ๋ฐ์ดํธ๋ ์ด ๋งค๊ฐ๋ณ์๊ฐ ์์ค์ ์ค์ด๋ ๋ฐฉํฅ์ผ๋ก ์กฐ์ ๋ฉ๋๋ค.
์ด ํจ์๋ backward propagation ๋จ๊ณ์์ ํธ์ถ๋๋ฉฐ, ์ถ๋ ฅ๊ฐ์ ๋ํ ์์ค์ ๋ฏธ๋ถ๊ฐ๊ณผ normalization๋ ์ ๋ ฅ๊ฐ์ ๊ณฑ์ ๋ฏธ๋๋ฐฐ์น์ ํํฐ, ํฌ๊ธฐ๋ณ๋ก ํฉํ์ฌ ์ค์ผ์ผ ๋งค๊ฐ๋ณ์์ ์ ๋ฐ์ดํธ ๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
ํจ์ ์ด๋ฆ: mean_delta_cpu
์ ๋ ฅ:
delta: ์ด์ ์ธต์ ๋ธํ ๊ฐ (float ํํ์ 1์ฐจ์ ๋ฐฐ์ด)
variance: ํ์ฌ ์ธต์ ๋ถ์ฐ ๊ฐ (float ํํ์ 1์ฐจ์ ๋ฐฐ์ด)
batch: ๋ฐฐ์น ์ฌ์ด์ฆ (int)
filters: ํํฐ ๊ฐ์ (int)
spatial: ๊ณต๊ฐ ์ฐจ์์ ํฌ๊ธฐ (int)
mean_delta: ํ์ฌ ์ธต์ ํ๊ท ๋ธํ ๊ฐ (float ํํ์ 1์ฐจ์ ๋ฐฐ์ด)
๋์:
ํ์ฌ ์ธต์ ํ๊ท ๋ธํ ๊ฐ์ ๊ณ์ฐํ๋ ํจ์
๊ฐ ํํฐ๋ณ๋ก ๋ธํ ๊ฐ์ ํฉ์ ๊ตฌํ๊ณ , ๋ถ์ฐ์ ์ ๊ณฑ๊ทผ ๊ฐ์ผ๋ก ๋๋์ด ํ๊ท ๋ธํ ๊ฐ์ ๊ตฌํจ
์ค๋ช :
mean_delta_cpu ํจ์๋ Batch Normalization์ ํ์ต ๊ณผ์ ์ค ํ์ฌ ์ธต์ ํ๊ท ๋ธํ ๊ฐ์ ๊ณ์ฐํ๋ ํจ์์ด๋ค.
๊ฐ ํํฐ๋ณ๋ก ๋ธํ ๊ฐ์ ํฉ์ ๊ตฌํ ํ, ํด๋น ํํฐ์ ๋ถ์ฐ ๊ฐ์ ์ ๊ณฑ๊ทผ์ผ๋ก ๋๋์ด ํ๊ท ๋ธํ ๊ฐ์ ๊ณ์ฐํ๋ค.
์ด๋ ๋ถ์ฐ ๊ฐ์ 0์ผ๋ก ๋๋๋ ๊ฒ์ ๋ฐฉ์งํ๊ธฐ ์ํด ์์ ์์๊ฐ(.00001f)์ ๋ํด์ค๋ค.
๊ณ์ฐ๋ ํ๊ท ๋ธํ ๊ฐ์ mean_delta ๋ฐฐ์ด์ ์ ์ฅ๋๋ค.
ํจ์ ์ด๋ฆ: variance_delta_cpu
์ ๋ ฅ:
x: ํ์ฌ ์ธต์ ์ ๋ ฅ
delta: ํ์ฌ ์ธต์ ๋ธํ
mean: ํ์ฌ ์ธต์ ํ๊ท
variance: ํ์ฌ ์ธต์ ๋ถ์ฐ
batch: ๋ฐฐ์น ํฌ๊ธฐ
filters: ํํฐ ๊ฐ์
spatial: ๊ณต๊ฐ ํฌ๊ธฐ
variance_delta: ๋ถ์ฐ ๋ธํ
๋์:
ํ์ฌ ์ธต์ ๋ถ์ฐ ๋ธํ๋ฅผ ๊ณ์ฐํ๋ ํจ์์ ๋๋ค.
์ ๋ ฅ์ผ๋ก ํ์ฌ ์ธต์ ์ ๋ ฅ(x), ๋ธํ(delta), ํ๊ท (mean), ๋ถ์ฐ(variance), ๋ฐฐ์น ํฌ๊ธฐ(batch), ํํฐ ๊ฐ์(filters), ๊ณต๊ฐ ํฌ๊ธฐ(spatial)๊ฐ ์ฃผ์ด์ง๋๋ค.
๊ฐ ํํฐ๋ณ๋ก ๋ถ์ฐ ๋ธํ๋ฅผ ๊ณ์ฐํ๋ฉฐ, ์ด๋ฅผ ์ํด ๊ฐ ๋ฐฐ์น์์ ํ์ฌ ํํฐ์ ๊ณต๊ฐ ์์น์ ๋ฐ๋ฅธ ์ธ๋ฑ์ค๋ฅผ ๊ณ์ฐํฉ๋๋ค.
๋ถ์ฐ ๋ธํ๋ delta์ (x-mean)์ ๊ณฑ์ ํฉ์ผ๋ก ๊ณ์ฐ๋๋ฉฐ, ์ด ๊ฐ์ -(variance+0.00001f)^(3/2)๋ฅผ ๊ณฑํ ๊ฐ์ ๋ฐ๋๊ฐ์ ์ ์ฅํฉ๋๋ค.
์ค๋ช :
๋ฐฐ์น ์ ๊ทํ(batch normalization)์์๋ ๊ฐ ์ธต์์ ์ ๋ ฅ(x)์ ๋ํ ํ๊ท (mean)๊ณผ ๋ถ์ฐ(variance)์ ๊ณ์ฐํฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ ๋ถ์ฐ ๋ธํ(variance delta)๋ ์ด์ ์ธต์ ์ถ๋ ฅ๊ฐ๊ณผ ํ์ฌ ์ธต์ ๋ธํ๋ฅผ ์ด์ฉํ์ฌ ๊ณ์ฐ๋ฉ๋๋ค.
์ด ํจ์๋ ๋ถ์ฐ ๋ธํ๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํ ํจ์ ์ค ํ๋๋ก, ํ์ฌ ์ธต์ ์ ๋ ฅ, ๋ธํ, ํ๊ท , ๋ถ์ฐ์ ์ด์ฉํ์ฌ ๋ถ์ฐ ๋ธํ๋ฅผ ๊ณ์ฐํฉ๋๋ค.
ํจ์ ์ด๋ฆ: normalize_delta_cpu
์ ๋ ฅ:
float *x: ์ ๋ ฅ๊ฐ ํฌ์ธํฐ
float *mean: ํ๊ท ๊ฐ ํฌ์ธํฐ
float *variance: ๋ถ์ฐ๊ฐ ํฌ์ธํฐ
float *mean_delta: ํ๊ท ๊ฐ ๋ณํ๋ ํฌ์ธํฐ
float *variance_delta: ๋ถ์ฐ๊ฐ ๋ณํ๋ ํฌ์ธํฐ
int batch: ๋ฐฐ์น ํฌ๊ธฐ
int filters: ํํฐ ๊ฐ์
int spatial: ๊ณต๊ฐ ํฌ๊ธฐ
float *delta: ๋ธํ๊ฐ ํฌ์ธํฐ
๋์:
์ ๋ ฅ๊ฐ x, ํ๊ท ๊ฐ mean, ๋ถ์ฐ๊ฐ variance, ํ๊ท ๊ฐ ๋ณํ๋ mean_delta, ๋ถ์ฐ๊ฐ ๋ณํ๋ variance_delta, ๋ฐฐ์น ํฌ๊ธฐ batch, ํํฐ ๊ฐ์ filters, ๊ณต๊ฐ ํฌ๊ธฐ spatial, ๋ธํ๊ฐ delta๋ฅผ ๋ฐ์์ ๋ธํ๊ฐ delta๋ฅผ ์ ๊ทํ(normalize)ํ๋ค.
์ ๊ทํ๋ฅผ ํ๊ธฐ ์ํด ๋ธํ๊ฐ delta๋ฅผ ํ๊ท (mean)๊ณผ ๋ถ์ฐ(variance)์ ์ด์ฉํ์ฌ ํ์คํ(standardize)ํ๋ค.
๋ํ, ํ๊ท ๊ฐ ๋ณํ๋ mean_delta, ๋ถ์ฐ๊ฐ ๋ณํ๋ variance_delta๋ฅผ ์ด์ฉํ์ฌ ํ๊ท ๊ณผ ๋ถ์ฐ์ ๋ณํ๋์ ์ถ๊ฐ๋ก ๋ฐ์ํ๋ค.
์ค๋ช :
์ ๋ ฅ๊ฐ๊ณผ ๋ธํ๊ฐ์ ๋ชจ๋ (batch * filters * spatial) ํฌ๊ธฐ์ 1์ฐจ์ ๋ฐฐ์ด๋ก ํํ๋๋ค.
์ด ํจ์๋ CPU์์ ๋์ํ๋ฉฐ, GPU์์ ๋์ํ๋ ๋ฒ์ ๋ ์กด์ฌํ๋ค.
์ผ๋ฐ์ ์ผ๋ก ๋ฅ๋ฌ๋์์ ์ ๋ ฅ๊ฐ์ ์ ๊ทํํ๋ ๊ฒ์ ํ์ต์ ์์ ์ฑ๊ณผ ์ฑ๋ฅ ํฅ์์ ๋๋ชจํ๊ธฐ ์ํ ๋ฐฉ๋ฒ ์ค ํ๋์ด๋ค.
์ ๋ ฅ๊ฐ x๋ฅผ ํ๊ท mean๊ณผ ๋ถ์ฐ variance๋ฅผ ์ด์ฉํ์ฌ ํ์คํํ ํ, ๋ธํ๊ฐ delta์ ๋ค์ ๊ณฑํด์ค์ผ๋ก์จ ์ ๋ ฅ๊ฐ์ ์ ๊ทํํ๋ค.
ํ๊ท ๊ฐ ๋ณํ๋ mean_delta์ ๋ถ์ฐ๊ฐ ๋ณํ๋ variance_delta๋ ์ด์ ๋ฐฐ์น์์์ ๊ฐ๋ค์ ๊ณ ๋ คํ์ฌ ์๋ก์ด ๋ฐฐ์น์์์ ํ๊ท ๊ณผ ๋ถ์ฐ์ด ์ด๋ป๊ฒ ๋ณํํ๋์ง ์ถ์ ํ๊ธฐ ์ํ ๊ฐ์ด๋ค.