gru_layer
GRU layer ๋?
GRU (Gated Recurrent Unit) ๋ ์ด์ด๋ ๋ฐ๋ณต ์ ๊ฒฝ๋ง (Recurrent Neural Network, RNN)์ ํ ์ข ๋ฅ๋ก, ๊ธด ์ํ์ค๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ์ ์ฌ์ฉ๋ฉ๋๋ค.
GRU๋ ๊ธฐ๋ณธ์ ์ผ๋ก LSTM (Long Short-Term Memory)๊ณผ ์ ์ฌํ ์์ด๋์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ๊ณ ์์ต๋๋ค. LSTM๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก, GRU๋ RNN ๊ณ์ด์ ๋ ์ด์ด๋ก์ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค. ํ์ง๋ง LSTM๊ณผ๋ ๋ฌ๋ฆฌ, GRU๋ ๊ฒ์ดํธ ๋ฉ์ปค๋์ฆ์ ์ฌ์ฉํ์ฌ ๊ธฐ์ต์ ๋ณดํธํ๊ณ , ์ด์ ์ํ์์ ์ ๋ณด๋ฅผ ๊ฐ์ ธ์ค๋ ๋ฐฉ๋ฒ์ ๊ฐ๋จํํ์ฌ ๋ ์ ์ ๊ณ์ฐ์ผ๋ก ์ฅ๊ธฐ์ ์ธ ์ํ๋ฅผ ์ ์งํ ์ ์๋๋ก ํฉ๋๋ค.
GRU๋ LSTM๋ณด๋ค ๋ ๊ฐ๋จํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๊ณ ์์ผ๋ฉฐ, ๋ ์ ์ ํ๋ผ๋ฏธํฐ๋ฅผ ํ์๋ก ํฉ๋๋ค. GRU๋ LSTM๋ณด๋ค ํ์ต ์๋๊ฐ ๋ ๋น ๋ฅด๊ณ , ์์ ๋ฐ์ดํฐ์ ์์ ๋ ์ผ๋ฐ์ ์ธ ๋ชจ๋ธ์ ๋ง๋ค์ด๋ด๋ ๊ฒฝํฅ์ด ์์ต๋๋ค.
GRU ๋ ์ด์ด๋ 2๊ฐ์ ๊ฒ์ดํธ๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ์ต์ ์กฐ์ ํฉ๋๋ค. ์ฒซ ๋ฒ์งธ ๊ฒ์ดํธ๋ "์ ๋ฐ์ดํธ ๊ฒ์ดํธ"๋ผ๊ณ ๋ถ๋ฆฌ๋ฉฐ, ํ์ฌ ์ ๋ ฅ๊ณผ ์ด์ ์ํ๋ฅผ ๊ฒฐํฉํ์ฌ ์๋ก์ด ์ํ๋ฅผ ์์ฑํฉ๋๋ค. ๋ ๋ฒ์งธ ๊ฒ์ดํธ๋ "์ฌ์ค์ ๊ฒ์ดํธ"๋ผ๊ณ ๋ถ๋ฆฌ๋ฉฐ, ์ด์ ์ํ์ ์ผ๋ถ๋ฅผ ๋ฒ๋ฆฌ๊ณ ์๋ก์ด ์ํ๋ฅผ ๋ง๋ญ๋๋ค. GRU ๋ ์ด์ด๋ ์ด๋ฌํ ๊ฒ์ดํธ๋ค์ ์ฌ์ฉํ์ฌ ์ ๋ ฅ ์ํ์ค์ ์ด์ ์ํ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ๋ค์, ์๋ก์ด ์ํ๋ฅผ ์ถ๋ ฅํฉ๋๋ค.
GRU ๋ ์ด์ด๋ ์ฃผ๋ก ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃจ๋ ์์ฐ์ด ์ฒ๋ฆฌ(NLP) ๋ถ์ผ์์ ์ฌ์ฉ๋ฉ๋๋ค. GRU ๋ ์ด์ด๋ฅผ ์ ์ฉํ ๋ชจ๋ธ์ ํ ์คํธ ์์ฑ, ๋ฒ์ญ, ๊ฐ์ฑ ๋ถ์ ๋ฑ ๋ค์ํ ํ์คํฌ์์ ์ข์ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค.
increment_layer
ํจ์ ์ด๋ฆ: increment_layer
์ ๋ ฅ:
layer *l: ์ ๋ฐ์ดํธํ ๋ ์ด์ด
int steps: ์ด๋ํ ์คํ ์
๋์:
layer ๊ตฌ์กฐ์ฒด ํฌ์ธํฐ์ธ l์ output, delta, x, x_norm์ steps๋งํผ ์ด๋ํ ํฌ์ธํฐ๋ฅผ ํ ๋นํ๋ค.
GPU ํ๊ฒฝ์์๋ l์ output_gpu, delta_gpu, x_gpu, x_norm_gpu์ steps๋งํผ ์ด๋ํ ํฌ์ธํฐ๋ฅผ ํ ๋นํ๋ค.
์ค๋ช :
ํด๋น ํจ์๋ ๋ ์ด์ด์ ํฌ์ธํฐ๋ฅผ steps๋งํผ ์ด๋์์ผ ์ ๋ฐ์ดํธํ๋ ํจ์์ด๋ค.
ํฌ์ธํฐ๋ฅผ ์ด๋์์ผ์ ์ด์ ์ ๊ฐ์ ์ฐธ์กฐํ์ง ์๊ณ ์๋ก์ด ๊ฐ์ ์ฐธ์กฐํ ์ ์๋๋ก ํ๋ค.
GPU ํ๊ฒฝ์์๋ GPU ๋ฉ๋ชจ๋ฆฌ ์์ ํฌ์ธํฐ๋ฅผ ์ด๋์ํจ๋ค.
forward_gru_layer
ํจ์ ์ด๋ฆ: forward_gru_layer
์ ๋ ฅ:
layer l: GRU ๋ ์ด์ด์ ์ ๋ณด์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ด๊ณ ์๋ layer ๊ตฌ์กฐ์ฒด
network net: ๋คํธ์ํฌ์ ์ ๋ณด์ ๋งค๊ฐ๋ณ์๋ฅผ ๋ด๊ณ ์๋ network ๊ตฌ์กฐ์ฒด
๋์:
์ ๋ ฅ ๋ฐ์ดํฐ์ GRU ๋ ์ด์ด๋ฅผ ํตํด ์๋ฐฉํฅ ์ ํ(forward propagation)๋ฅผ ์ํํ๋ ํจ์๋ก, ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ GRU ๋ ์ด์ด๋ฅผ ํตํด ์ฒ๋ฆฌํ์ฌ ์ถ๋ ฅ ๊ฐ์ ๊ณ์ฐํ๊ณ , ๊ทธ ๊ฐ์ ๋ค์ ๋ ์ด์ด์ ์ ๋ ฅ์ผ๋ก ๋๊ฒจ์ค.
์ด๋, backward propagation์ ์ํด ํ์ํ ์ค๊ฐ๊ฐ๋ค์ ์ ์ฅํด ๋์.
์ค๋ช :
GRU ๋ ์ด์ด์ ๋งค๊ฐ๋ณ์๋ค ์ค์์ uz, ur, uh๋ ์ด์ ์ํ(previous state)๋ก๋ถํฐ์ ์ ๋ ฅ(input)์ ์ฒ๋ฆฌํ๋ ๊ฐ์ค์น(weight) ๋งค๊ฐ๋ณ์์ด๊ณ , wz, wr, wh๋ ํ์ฌ ์ ๋ ฅ(input)์ ์ฒ๋ฆฌํ๋ ๊ฐ์ค์น ๋งค๊ฐ๋ณ์์.
GRU ๋ ์ด์ด๋ ์๊ณ์ด(sequence) ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํ RNN์ ํ ์ข ๋ฅ๋ก, ์ด์ ์์ ์ ์ํ(previous state)๋ฅผ ์ฌ์ฌ์ฉํ๋ ๋ ์ด์ด์.
forward_connected_layer ํจ์๋ฅผ ํตํด ๊ฐ์ค์น์ ์ ๋ ฅ์ ๊ณฑํ ๊ฐ๊ณผ bias๋ฅผ ๋ํ ๊ฐ์ ๊ณ์ฐํ์ฌ ํ์ฑํ ํจ์(Logistic ๋๋ Tanh)๋ฅผ ์ ์ฉํจ.
uz, ur, uh ๋ ์ด์ด์์ ๋์จ ์ถ๋ ฅ๊ฐ๊ณผ wz, wr, wh ๋ ์ด์ด์์ ๋์จ ์ถ๋ ฅ๊ฐ์ ์ด์ฉํ์ฌ z์ r ๊ฐ์ ๊ณ์ฐํจ.
z๊ฐ์ ์ด์ ์ํ์ ํ์ฌ ์ ๋ ฅ์ ์กฐํฉํ ํ ๋ก์ง์คํฑ ํจ์๋ฅผ ์ ์ฉํ์ฌ ๊ณ์ฐํจ.
r๊ฐ์ z๊ฐ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ์ด์ ์ํ์ ํ์ฌ ์ ๋ ฅ์ ์กฐํฉํ ํ ๋ก์ง์คํฑ ํจ์๋ฅผ ์ ์ฉํ์ฌ ๊ณ์ฐํจ.
h๊ฐ์ z๊ฐ๊ณผ ์ด์ ์ํ๋ฅผ ์ด์ฉํ์ฌ ์๋ก์ด ์ํ๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํ ๊ฒ์ดํธ(gate)๋ฅผ ๊ณ์ฐํจ.
๊ณ์ฐ๋ h๊ฐ์ Tanh ๋๋ Logistic ํจ์๋ฅผ ์ ์ฉํ์ฌ ์ถ๋ ฅ๊ฐ(output)์ ๊ณ์ฐํจ.
GRU ๋ ์ด์ด๋ ์ฌ๋ฌ ์์ (time step)์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฏ๋ก, steps ๋งํผ ๋ฐ๋ณต์ ์ผ๋ก forward_connected_layer ํจ์๋ฅผ ํธ์ถํ์ฌ ์ค๊ฐ๊ฐ๋ค์ ๊ณ์ฐํจ.
backward_gru_layer
ํจ์ ์ด๋ฆ: backward_gru_layer
์ ๋ ฅ:
layer l
network net (๋ ๋ค ๊ตฌ์กฐ์ฒด)
๋์:
GRU (๊ฒ์ดํธ ์ํ ์ ๋) ๋ ์ด์ด์ ์ญ์ ํ(backpropagation)๋ฅผ ๊ณ์ฐํ๊ณ ์ด์ ๋ ์ด์ด์๊ฒ ์ค์ฐจ ์ ํธ(error signal)๋ฅผ ์ ๋ฌํฉ๋๋ค.
์ด๋ฅผ ์ํด ์ ๋ ฅ ์ ํธ์ ๊ฐ์ค์น(weight)์ ๋ํ ๋ฏธ๋ถ(gradient)์ ๊ณ์ฐํฉ๋๋ค.
์ค๋ช :
l: GRU ๋ ์ด์ด์ ๊ตฌ์กฐ์ฒด๋ก, ์ ๋ ฅ ์ ํธ์ ๊ฐ์ค์น, ์ถ๋ ฅ๊ณผ ๊ฐ์ ๋ค์ํ ์ ๋ณด๋ฅผ ๋ด๊ณ ์์ต๋๋ค.
net: ์ ๊ฒฝ๋ง ๊ตฌ์กฐ์ฒด๋ก, ์ญ์ ํ ์์ ์ด์ ๋ ์ด์ด๋ก ์ค์ฐจ ์ ํธ๋ฅผ ์ ๋ฌํ๊ธฐ ์ํด ์ฌ์ฉ๋ฉ๋๋ค.
์ด ํจ์๋ ๋น ์ํ๋ก ๋จ๊ฒจ๋ ๊ฒ์ด ์๋๋ผ, ๊ตฌํ ๋ด์ฉ์ด ์๋ ๊ฒ์ ๋๋ค. ํจ์๋ฅผ ํธ์ถํ ๋ ์ค์ ๋ก ๊ณ์ฐ์ด ์ด๋ฃจ์ด์ง๋๋ค.
update_gru_layer
ํจ์ ์ด๋ฆ: update_gru_layer
์ ๋ ฅ:
layer l: GRU ๋ ์ด์ด ๊ตฌ์กฐ์ฒด
update_args a: ์ ๋ฐ์ดํธ ์ธ์ ๊ตฌ์กฐ์ฒด
๋์:
GRU ๋ ์ด์ด์ ๊ฐ๊ฐ์ ์ฐ๊ฒฐ๋ ๋ ์ด์ด(ur, uz, uh, wr, wz, wh)๋ค์ ๊ฐ์ค์น(weight)์ bias๋ฅผ ์ ๋ฐ์ดํธํ๋ ํจ์
์ค๋ช :
์ ๋ ฅ์ผ๋ก ์ฃผ์ด์ง GRU ๋ ์ด์ด ๊ตฌ์กฐ์ฒด l์ ์ฐ๊ฒฐ๋ ๋ ์ด์ด(ur, uz, uh, wr, wz, wh)๋ค์ ๊ฐ์ค์น์ bias๋ฅผ ์ ๋ฐ์ดํธํ๋ ํจ์์ด๋ค.
์ด๋ฅผ ์ํด update_connected_layer() ํจ์๋ฅผ ๊ฐ ๋ ์ด์ด์ ๋ํด ํธ์ถํ์ฌ ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํ๋ค.
make_gru_layer
ํจ์ ์ด๋ฆ: make_gru_layer
์ ๋ ฅ:
int batch: ๋ฐฐ์น ํฌ๊ธฐ
int inputs: ์ ๋ ฅ์ ํฌ๊ธฐ
int outputs: ์ถ๋ ฅ์ ํฌ๊ธฐ
int steps: ์๊ฐ ์คํ ์ ์
int batch_normalize: ๋ฐฐ์น ์ ๊ทํ ์ฌ์ฉ ์ฌ๋ถ
int adam: Adam ์ตํฐ๋ง์ด์ ์ฌ์ฉ ์ฌ๋ถ
๋์:
GRU ๋ ์ด์ด๋ฅผ ์์ฑํ๊ณ ์ด๊ธฐํํ๋ ํจ์์ด๋ค. GRU ๋ ์ด์ด๋ uz, wr, uh, wh ๋ฑ์ ์ฐ๊ฒฐ ๋ ์ด์ด๋ก ๊ตฌ์ฑ๋์ด ์๋ค.
์ค๋ช :
์ ๋ ฅ๊ฐ์ผ๋ก ๋ฐ์ batch ๊ฐ์ steps๋ก ๋๋์ด์ ธ์ ์ฌ์ฉ๋๋ค.
๋ ์ด์ด์ ํ์ ์ GRU๋ก ์ค์ ๋๋ค.
uz, wz, ur, wr, uh, wh ๋ฑ์ ์ฐ๊ฒฐ ๋ ์ด์ด๊ฐ ์์ฑ๋๊ณ ์ด๊ธฐํ๋๋ค.
์ถ๋ ฅ๊ฐ, delta, state, prev_state, forgot_state, forgot_delta, r_cpu, z_cpu, h_cpu ๋ฑ์ ๊ฐ๋ค์ด ์ด๊ธฐํ๋๋ค.
forward, backward, update ํจ์๊ฐ ์ค์ ๋๋ค.
์ด๊ธฐํ๋ GRU ๋ ์ด์ด๊ฐ ๋ฐํ๋๋ค.
Last updated
Was this helpful?