lstm_layer
LSTM Layer๋?
LSTM์ Long Short Term Memory networks์ ์ฝ์์ ๋๋ค. RNN๊ณผ ๊ฐ์ด ์์ฐ์ด์ฒ๋ฆฌ, ์์ฑ์ฒ๋ฆฌ ๋ฑ Sequential ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๊ธฐ ์ํด ๋ง์ด ์ฌ์ฉ๋๋ layer์ ๋๋ค.
๊ธฐ์กด์ RNN์ ํ์ตํ๋ฉด์ ์ ์ ๊ณผ๊ฑฐ ์ ๋ณด๋ฅผ ์์ด๋ฒ๋ฆฌ๋(Gradient Vanishing) ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ณ ์ด๋ฌํ ์ฅ๊ธฐ์ ์ธ ์์กด์ฑ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ ์ค๊ณ๋ layer์ ๋๋ค.
LSTM์ ํต์ฌ์ ์ธ ์์๋ cell state
์
๋๋ค. LSTM์ cell state
๋ ๊ณต์ฅ์ ์ปจ๋ฒ ์ด์ด ๋ฒจํธ์ ๊ฐ์ผ๋ฉฐ ์ด๋ฌํ ์ปจ๋ฒ ์ด์ด ๋ฒจํธ์ gate
๋ฅผ ์ด์ฉํ์ฌ ๊ฐ์ ๊ณต๊ธํ์ฌ ์ ๋ณด๋ฅผ ์ถ๊ฐํ๊ฑฐ๋ ์ ๊ฑฐํด ๊ฐ๋๋ค.
gate
๋ ์ด 3๊ฐ์ง๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค.
forget gate๋ ์ ๋ณด๋ฅผ ์ผ๋ง๋ ์์ ๊ฒ์ธ์ง์ ๋ํด์ ์ฐ์ฐํ๋ gate์ ๋๋ค. sigmoid๋ฅผ ํตํด 0 ~ 1 ์ฌ์ด์ ๊ฐ์ด ์ถ๋ ฅ๋๋๋ฐ 1์ ๊ฐ๊น์ฐ๋ฉด ๊ธฐ์ตํ๋ผ๋ ์๋ฏธ๋ฅผ ํฌํจํ๊ณ 0์ ๊ฐ๊น์ฐ๋ฉด ์์ผ๋ผ๋ ์๋ฏธ๋ฅผ ํฌํจํฉ๋๋ค.
input gate๋ ์๋ก์ด ์ ๋ณด๋ฅผ ๊ณต๊ธํ๋ ์ฐ์ฐ์ ํ๋ gate์ ๋๋ค. sigmoid๋ฅผ ํตํด ์ด๋ค ์ ๋ ฅ๊ฐ์ ์ ๋ฐ์ดํธํด์ผ ํ ์ง ๊ฒฐ์ ํ๊ณ tanh๋ ์๋ก์ด ์ ๋ ฅ๊ฐ์ ๋ง๋ญ๋๋ค. ๋ ๊ฐ์ ๊ฐ์ ํฉ์ณ์ ์๋ก์ด ๊ฐ์ด ๊ธฐ์กด ๊ฐ์ ์ํฅ์ ์ฃผ๋ ๊ฐ์ ๋ง๋ค์ด ๋ ๋๋ค.
output gate๋ ์ด๋ค ์ถ๋ ฅ๊ฐ์ ๋ค์ state์ ๋ณด๋ด์ค์ง ๊ฒฐ์ ํ๋ gate์ ๋๋ค. sigmoid๋ฅผ ํตํด ์ด๋ค ๊ฐ์ ์ถ๋ ฅํด์ผ ํ ์ง ๊ฒฐ์ ํ๊ณ tanh๋ ์ ๋ฐ์ดํธ ๋ cell state์ ์ํฅ์ ๋งํด์ค๋๋ค.
๊ฒฐ๋ก ์ ์ผ๋ก ์์๋ forget gate๋ก ์์ด์ผํ ๋ถ๋ถ์ ์๊ณ input gate๋ก ์๋ก์ด ๊ฐ์ ์ถ๊ฐํ๋ฉฐ cell state๋ฅผ ์ ๋ฐ์ดํธ ํ๊ณ output gate๋ฅผ ํตํด ์ต์ข ์ ์ผ๋ก ์ถ๋ ฅํฉ๋๋ค.
increment_layer
ํจ์ ์ด๋ฆ: increment_layer
์ ๋ ฅ:
layer ํฌ์ธํฐ l: ๊ฐ์ด ์ฆ๊ฐ๋ ๋ ์ด์ด ๊ฐ์ฒด ํฌ์ธํฐ
int steps: ์ฆ๊ฐํ ์คํ ์
๋์:
l ๊ฐ์ฒด์ output, delta, x, x_norm ํฌ์ธํฐ๊ฐ ๊ฐ๋ฆฌํค๋ ๊ฐ์ steps * l->outputs * l->batch๋ฅผ ๋ํด ๊ฐ์ ์ฆ๊ฐ์ํด
์ค๋ช :
์ด ํจ์๋ ๋ด๋ด ๋คํธ์ํฌ์์ ์ญ์ ํ ์๊ณ ๋ฆฌ์ฆ์ ์ํํ๊ธฐ ์ํ LSTM ๋ ์ด์ด์์ ์ฌ์ฉ๋๋ ํจ์์ ๋๋ค.
LSTM ๋ ์ด์ด์์๋ ์ํ์ค์ ๊ฐ ํ์์คํ ์ ๋ํด forward์ backward ํจ์ค๋ฅผ ์ํํด์ผ ํฉ๋๋ค. increment_layer ํจ์๋ backward ํจ์ค๋ฅผ ์ํํ ๋ ์ด์ ์ํ์ค ํ์์คํ ์ ๋ํ ์ถ๋ ฅ, ๋ธํ, ์ ๋ ฅ ๋ฑ์ ํฌ์ธํฐ๋ฅผ ์ฆ๊ฐ์ํค๊ธฐ ์ํด ์ฌ์ฉ๋ฉ๋๋ค.
l ๊ฐ์ฒด์ output, delta, x, x_norm ํฌ์ธํฐ๋ ๋ชจ๋ ์ด์ ํ์์คํ ์ ๋ํ ๊ฐ์ ๊ฐ๋ฆฌํค๊ณ ์์ต๋๋ค. ๋ฐ๋ผ์ steps * l->outputs * l->batch ๋งํผ ๊ฐ์ ๋ํด์ฃผ๋ฉด ์ด์ ํ์์คํ ์ ๊ฐ์ ๋ํ ํฌ์ธํฐ๋ฅผ ์ฆ๊ฐ์ํค๋ ํจ๊ณผ๋ฅผ ์ป์ ์ ์์ต๋๋ค.
์ด ํจ์๋ static ํค์๋๋ฅผ ๊ฐ์ง๊ณ ์์ผ๋ฏ๋ก ๊ฐ์ ์์ค ํ์ผ ๋ด์์๋ง ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
forward_lstm_layer
ํจ์ ์ด๋ฆ: forward_lstm_layer
์ ๋ ฅ:
layer l: LSTM ๋ ์ด์ด์ ๊ตฌ์กฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ์ฅํ๋ ๊ตฌ์กฐ์ฒด
network state: ์ ๋ ฅ ๋ฐ์ดํฐ์ ๊ทธ ์ธ์ ๋คํธ์ํฌ ์ ๋ณด๋ฅผ ์ ์ฅํ๋ ๊ตฌ์กฐ์ฒด
๋์:
LSTM ๋ ์ด์ด์ forward propagation์ ์ํํ๋ ํจ์
ํ์ฌ ๋ ์ด์ด์ ํ๋ผ๋ฏธํฐ์ ์ด์ ์์ ์ ์ถ๋ ฅ ๊ฐ์ ์ด์ฉํ์ฌ ํ์ฌ ์์ ์ ์ถ๋ ฅ ๊ฐ์ ๊ณ์ฐ
์ ๋ ฅ ๋ฐ์ดํฐ๋ state.input์ ์ ์ฅ๋์ด ์์ผ๋ฉฐ, l.steps ๋ฒ ๋งํผ forward propagation์ ๋ฐ๋ณตํ์ฌ ์ถ๋ ฅ ๊ฐ์ ๊ณ์ฐ
๊ฐ ์ฐ์ฐ์ ๋ด๋ถ์ ์ผ๋ก connected layer์ forward ์ฐ์ฐ์ ์ด์ฉํ์ฌ ์ํ๋จ
์ค๋ช :
์ ๋ ฅ์ผ๋ก ์ฃผ์ด์ง LSTM ๋ ์ด์ด์ ๊ตฌ์กฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ด์ฉํ์ฌ forward propagation์ ์ํํ๊ณ , ํ์ฌ ์์ ์ ์ถ๋ ฅ ๊ฐ์ ๊ณ์ฐํ์ฌ l.output์ ์ ์ฅํจ
์ด์ ์์ ์ ์ถ๋ ฅ ๊ฐ๊ณผ ํ์ฌ ์์ ์ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํ์ฌ ํ์ฌ ์์ ์ ์ถ๋ ฅ ๊ฐ์ ๊ณ์ฐํจ
๊ฐ ๊ฒ์ดํธ(gate)์ ์ถ๋ ฅ ๊ฐ๊ณผ ์ ๋ ฅ ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํ์ฌ candidate ๊ฐ๊ณผ forget ๊ฐ์ ๊ณ์ฐํ๊ณ , cell ์ํ๋ฅผ ์ ๋ฐ์ดํธํ์ฌ ํ์ฌ ์์ ์ ์ถ๋ ฅ ๊ฐ์ ๊ณ์ฐํจ
forward propagation ๋์ค์๋ backpropagation์ ์ํ ๋ฏธ๋ถ ๊ฐ(delta)๋ค๋ ๊ณ์ฐ๋จ
state.train์ด true์ธ ๊ฒฝ์ฐ์๋ ํ์ฌ ์์ ์ ์ถ๋ ฅ ๊ฐ์ ๋ํ ์์ค ํจ์์ ๋ฏธ๋ถ ๊ฐ(l.delta)๋ ๊ณ์ฐ๋จ
backward_lstm_layer
ํจ์ ์ด๋ฆ: backward_lstm_layer
์ ๋ ฅ:
l: LSTM ๋ ์ด์ด ๋งค๊ฐ๋ณ์๋ฅผ ํฌํจํ๋ ๋ ์ด์ด ๊ฐ์ฒด
state: ํ์ฌ ๋คํธ์ํฌ ์ํ๋ฅผ ํฌํจํ๋ ๋คํธ์ํฌ ์ํ ๊ฐ์ฒด
๋์:
ํ์ฌ ์ํ์ค์์ ์ญ๋ฐฉํฅ LSTM ๋ ์ด์ด์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ๋ค.
s๋ผ๋ ์๋ก์ด ๋คํธ์ํฌ ๊ฐ์ฒด๋ฅผ ๋ชจ๋ ๊ฐ์ด 0์ผ๋ก ์ด๊ธฐํํ๊ณ , s์ train ํ๋๊ทธ๋ฅผ state.train ๊ฐ์ผ๋ก ์ค์ ํ๋ฉฐ, ์ ์ i๋ฅผ ์ด๊ธฐํํ๋ค.
์ ๋ฐฉ LSTM ๋ ์ด์ด์ ์ญ๋ฐฉํฅ LSTM ๋ ์ด์ด์ ๊ฐ์ค์น ํ๋ ฌ(๊ฐ๊ฐ wf, wi, wg, wo, uf, ui, ug, uo) 8๊ฐ์ ๋ ์ด์ด ๊ฐ์ฒด๋ฅผ ์ด๊ธฐํํ๋ค.
์ ๋ ฅ๊ณผ ๋ธํ ํฌ์ธํฐ๋ฅผ ํ์ฌ ์ํ์ค์ ๋ง์ง๋ง ํ์ ์คํ ์ ๊ฐ๋ฆฌํค๋๋ก ์ ๋ฐ์ดํธํ๋ค.
์ถ๋ ฅ, cell_cpu, delta ํฌ์ธํฐ๋ฅผ ํ์ฌ ์ํ์ค์ ๋ง์ง๋ง ํ์ ์คํ ์ ๊ฐ๋ฆฌํค๋๋ก ์ ๋ฐ์ดํธํ๋ค.
์ํ์ค๋ฅผ ์ญ์์ผ๋ก ๋ฐ๋ณตํ๋ฉฐ ๊ฐ ์๊ฐ ๋จ๊ณ์์ ๋ค์์ ์ํํ๋ค:
l.cell_cpu์ l.output์ ๊ฐ์ l.c_cpu์ l.h_cpu๋ก ๋ณต์ฌํ๋ค.
l.dh_cpu ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ๋ค.
l.f_cpu, l.i_cpu, l.g_cpu, l.o_cpu ๋ฐฐ์ด์ ๊ฐ์ ์ ๋ฐ์ดํธํ๋ค.
๋ก์ง์คํฑ ํจ์์ ํ์ดํผ๋ณผ๋ฆญ ํ์ ํธ ํจ์๋ฅผ l.f_cpu, l.i_cpu, l.g_cpu, l.o_cpu ๋ฐฐ์ด์ ์ ์ฉํ๋ค.
l.delta๋ฅผ l.temp3_cpu์ ๋ณต์ฌํ๊ณ , l.c_cpu์ ์ ์ฉ๋ ํ์ดํผ๋ณผ๋ฆญ ํ์ ํธ ํจ์์ ๊ธฐ์ธ๊ธฐ๋ฅผ l.temp2_cpu๋ก ๊ณ์ฐํ๋ค.
l.dc_cpu๋ฅผ l.dc_cpu์ l.temp2_cpu์ ํฉ์ผ๋ก ๊ณ์ฐํ๊ณ , l.temp_cpu์ l.c_cpu์ ๊ฐ์ ๋ณต์ฌํ ํ ํ์ดํผ๋ณผ๋ฆญ ํ์ ํธ ํจ์๋ฅผ ์ ์ฉํ๋ค.
๊ฒฐ๊ณผ ๋ฐฐ์ด์ l.temp3_cpu์ ๊ณฑํ์ฌ l.temp2_cpu๋ฅผ ์ป๋๋ค.
์ค๋ช :
LSTM ๋ ์ด์ด์ ์ญ๋ฐฉํฅ ํจ์๋ ์ญ์ ํ๋ฅผ ํตํด ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ๊ณ ์ต์ ํ์ ํ์ฉํ๋ ํจ์์ด๋ค.
์ด ํจ์๋ ํ์ฌ ์ํ์ค์์ ์ญ๋ฐฉํฅ LSTM ๋ ์ด์ด์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ์ฌ์ฉ๋๋ค.
ํจ์๋ ๋ ์ด์ด ๊ฐ์ฒด์ ๋คํธ์ํฌ ์ํ ๊ฐ์ฒด๋ฅผ ๋งค๊ฐ๋ณ์๋ก ๋ฐ์ผ๋ฉฐ, ์ํ์ค๋ฅผ ์ญ์์ผ๋ก ๋ฐ๋ณตํ๋ฉด์ ๊ฐ ์๊ฐ ๋จ๊ณ์์ ํ์ํ ๊ณ์ฐ์ ์ํํ๋ค.
update_lstm_layer
ํจ์ ์ด๋ฆ: update_lstm_layer
์ ๋ ฅ:
layer l: LSTM ๋ ์ด์ด์ ํ๋ผ๋ฏธํฐ๋ฅผ ํฌํจํ ๋ ์ด์ด ๊ฐ์ฒด
update_args a: ํ์ต๋ฅ ๋ฑ ์ ๋ฐ์ดํธ ์ธ์๋ฅผ ํฌํจํ ๊ตฌ์กฐ์ฒด
๋์:
์ด ํจ์๋ LSTM ๋ ์ด์ด์ ํ๋ผ๋ฏธํฐ์ธ 8๊ฐ์ ๊ฐ์ค์น ํ๋ ฌ์ ๋ํด update_connected_layer ํจ์๋ฅผ ํธ์ถํ์ฌ ์ ๋ฐ์ดํธ๋ฅผ ์ํํฉ๋๋ค.
update_connected_layer ํจ์๋ ์ ๋ ฅ๋ ์ฐ๊ฒฐ์ธต์ ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํ๊ณ ์ ํ ๋ ์ฌ์ฉ๋๋ ํจ์์ด๋ฉฐ, a์ ๋ด๊ธด ์ ๋ฐ์ดํธ ์ธ์๋ฅผ ๋ฐํ์ผ๋ก ๊ฐ ๊ฐ์ค์น์ ๋ํ gradient descent๋ฅผ ์ํํฉ๋๋ค.
์ค๋ช :
update_lstm_layer ํจ์๋ LSTM ๋ ์ด์ด ๊ฐ์ฒด l๊ณผ ์ ๋ฐ์ดํธ ์ธ์๋ฅผ ํฌํจํ ๊ตฌ์กฐ์ฒด a๋ฅผ ์ ๋ ฅ๋ฐ์ต๋๋ค.
์ด ํจ์๋ l.wf, l.wi, l.wg, l.wo, l.uf, l.ui, l.ug, l.uo์ ๊ฐ๊ฐ ์ ๊ทผํ์ฌ, update_connected_layer ํจ์๋ฅผ ํธ์ถํ์ฌ ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
์ด ํจ์๋ gradient descent ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ์ฌ, ์ฃผ์ด์ง ํ์ต๋ฅ ๊ณผ ์ ๋ฐ์ดํธ ์ธ์๋ฅผ ๋ฐํ์ผ๋ก ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค.
์ด ๊ณผ์ ์ ํ์ต์ ์งํํ๋ฉด์ ๋ฐ๋ณต์ ์ผ๋ก ์ํ๋๋ฉฐ, ๊ฐ์ค์น๋ฅผ ์ต์ ํํ์ฌ ๋ชจ๋ธ์ ์์ธก ์ฑ๋ฅ์ ํฅ์์ํต๋๋ค.
Last updated
Was this helpful?