shortcut_layer
shortcut layer ๋?
ResNet์์ ์ ์๋ skip connection๊ณผ ์ ์ฌํฉ๋๋ค.
์ ์ ์ถ๋ ฅ์ ์ ์ฅํ๊ณ ๊ทธ ํ์ layer์ ์ถ๋ ฅ๊ณผ ํฉ์น๋ ์์ ์์ ์ฌ์ฉ ๋ฉ๋๋ค.
shortcut.c
forward_shortcut_layer
void forward_shortcut_layer(const layer l, network net)
{
copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1); // network input -> layer output
shortcut_cpu(l.batch, l.w, l.h, l.c, net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.alpha, l.beta, l.output); // layer output += i-th layer output
activate_array(l.output, l.outputs*l.batch, l.activation);
}
ํจ์ ์ด๋ฆ: forward_shortcut_layer
์ ๋ ฅ:
const layer l: ํ์ฌ layer ์ ๋ณด
network net: ํ์ฌ network ์ ๋ณด
๋์:
ํ์ฌ layer์ ์ถ๋ ฅ๊ฐ์ผ๋ก ๋คํธ์ํฌ ์ ๋ ฅ๊ฐ์ ๋ณต์ฌ
ํ์ฌ layer์ ์ถ๋ ฅ๊ฐ์ shortcut ์ฐ๊ฒฐ๋ ์ด์ layer์ ์ถ๋ ฅ๊ฐ์ ๋ํด์ค
ํ์ฌ layer์ ์ถ๋ ฅ๊ฐ์ ํ์ฑํ ํจ์๋ฅผ ์ ์ฉ
์ค๋ช :
shortcut ์ฐ๊ฒฐ์ ํตํด ๋ค๋ฅธ layer์ ์ถ๋ ฅ๊ฐ์ ํ์ฌ layer์ ์ถ๋ ฅ๊ฐ์ ๋ํด์ค์ผ๋ก์จ, ๋คํธ์ํฌ์ ํ์ต ํจ์จ์ฑ์ ๋์ด๊ธฐ ์ํ ๋ ์ด์ด
forward_shortcut_layer ํจ์๋ ํด๋น layer์ forward propagation์ ์ํํ๋ฉฐ, ์ ๋ ฅ๊ฐ์ ํ์ฌ layer์ ์ถ๋ ฅ๊ฐ์ผ๋ก ๋ณต์ฌํ๊ณ shortcut ์ฐ๊ฒฐ๋ ์ด์ layer์ ์ถ๋ ฅ๊ฐ์ ๋ํด์ฃผ๋ฉฐ ํ์ฑํ ํจ์๋ฅผ ์ ์ฉํ๋ ์ญํ ์ ์ํํจ
backward_shortcut_layer
void backward_shortcut_layer(const layer l, network net)
{
gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta); // layer delta -> activation grad
axpy_cpu(l.outputs*l.batch, l.alpha, l.delta, 1, net.delta, 1); // network delta += alpha * layer delta
shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, 1, l.beta, net.layers[l.index].delta); // i-th layer delta += layer delta
}
ํจ์ ์ด๋ฆ: backward_shortcut_layer
์ ๋ ฅ:
const layer l: shortcut layer์ ์ ๋ณด๋ฅผ ๋ด๊ณ ์๋ ๊ตฌ์กฐ์ฒด
network net: ์ ๊ฒฝ๋ง์ ๊ตฌ์ฑํ๋ layer๋ค์ ์ ๋ณด๋ฅผ ๋ด๊ณ ์๋ ๊ตฌ์กฐ์ฒด
๋์:
layer output์ activation gradient๋ฅผ ๊ณ์ฐํ์ฌ layer delta์ ์ ์ฅํ๋ค.
network delta์ alpha๊ฐ๊ณผ layer delta๊ฐ์ ๊ณฑํ์ฌ ๋ํด์ค๋ค.
i-th layer delta์๋ beta๊ฐ๊ณผ layer delta๊ฐ์ ๊ณฑํ์ฌ ๋ํด์ค๋ค.
์ค๋ช :
Shortcut layer๋ ์ ๋ ฅ๊ฐ๊ณผ ์ด์ layer์ ์ถ๋ ฅ๊ฐ์ ๋ํ์ฌ ์ถ๋ ฅ๊ฐ์ ๋ง๋ค์ด๋ธ๋ค.
๋ฐ๋ผ์ forward pass์์๋ ์ด์ layer์ ์ถ๋ ฅ๊ฐ์ ํ์ฌ layer์ ์ ๋ ฅ๊ฐ๊ณผ ๋ํ์ฌ ์ถ๋ ฅ๊ฐ์ ๊ณ์ฐํ๊ฒ ๋๋ค.
Backward pass์์๋ ํ์ฌ layer์ ์ถ๋ ฅ๊ฐ์ ๋ํ activation gradient๋ฅผ ๊ณ์ฐํ๊ณ , ์ด์ layer์ delta๊ฐ์๋ ํ์ฌ layer์ delta๊ฐ์ ๋ํ์ฌ ์ ํํ๊ฒ ๋๋ค.
resize_shortcut_layer
void resize_shortcut_layer(layer *l, int w, int h)
{
assert(l->w == l->out_w);
assert(l->h == l->out_h);
l->w = l->out_w = w;
l->h = l->out_h = h;
l->outputs = w*h*l->out_c;
l->inputs = l->outputs;
l->delta = realloc(l->delta, l->outputs*l->batch*sizeof(float));
l->output = realloc(l->output, l->outputs*l->batch*sizeof(float));
}
ํจ์ ์ด๋ฆ: resize_shortcut_layer
์ ๋ ฅ:
layer *l: ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ shortcut layer์ ํฌ์ธํฐ
int w: ์๋ก์ด ๋๋น
int h: ์๋ก์ด ๋์ด
๋์:
l์ w์ out_w๊ฐ ๊ฐ์์ผ ํจ์ ํ์ธ(assert)
l์ h์ out_h๊ฐ ๊ฐ์์ผ ํจ์ ํ์ธ(assert)
l์ w์ out_w๋ฅผ w๋ก ์ ๋ฐ์ดํธ
l์ h์ out_h๋ฅผ h๋ก ์ ๋ฐ์ดํธ
l์ outputs๋ฅผ w, h, out_c์ ๊ณฑ์ผ๋ก ์ ๋ฐ์ดํธ
l์ inputs๋ฅผ outputs์ ๊ฐ๊ฒ ์ ๋ฐ์ดํธ
l์ delta ๋ฉ๋ชจ๋ฆฌ๋ฅผ outputs * batch ํฌ๊ธฐ๋งํผ ์ฌํ ๋น
l์ output ๋ฉ๋ชจ๋ฆฌ๋ฅผ outputs * batch ํฌ๊ธฐ๋งํผ ์ฌํ ๋น
์ค๋ช :
์ด ํจ์๋ shortcut layer์ ํฌ๊ธฐ๋ฅผ ์กฐ์ ํ๋ ์ญํ ์ ํ๋ค.
shortcut layer๋ input๊ณผ output์ ํฌ๊ธฐ๊ฐ ๊ฐ์์ผ ํ๊ธฐ ๋๋ฌธ์ l์ w์ out_w, h์ out_h๊ฐ ๊ฐ์์ง ํ์ธํ๊ณ ๊ฐ์ง ์์ผ๋ฉด ์๋ฌ๋ฅผ ๋ฐ์์ํจ๋ค.
๊ทธ ํ w์ h๋ก ๊ฐ๊ฐ ํฌ๊ธฐ๋ฅผ ์กฐ์ ํด์ฃผ๊ณ , outputs์ inputs๋ฅผ ์ ๋ฐ์ดํธํ๋ค.
๋ง์ง๋ง์ผ๋ก, delta์ output ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋ก์ด outputs * batch ํฌ๊ธฐ๋ก ์ฌํ ๋นํ๋ค.
make_shortcut_layer
layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
{
fprintf(stderr, "res %3d %4d x%4d x%4d -> %4d x%4d x%4d\n",index, w2,h2,c2, w,h,c);
layer l = {0};
l.type = SHORTCUT;
l.batch = batch;
l.w = w2;
l.h = h2;
l.c = c2;
l.out_w = w;
l.out_h = h;
l.out_c = c;
l.outputs = w*h*c;
l.inputs = l.outputs;
l.index = index;
l.delta = calloc(l.outputs*batch, sizeof(float));
l.output = calloc(l.outputs*batch, sizeof(float));;
l.forward = forward_shortcut_layer;
l.backward = backward_shortcut_layer;
return l;
}
ํจ์ ์ด๋ฆ: make_shortcut_layer
์ ๋ ฅ:
batch: ๋ฐฐ์น ํฌ๊ธฐ
index: ๋ ์ด์ด ์ธ๋ฑ์ค
w: ์ ๋ ฅ ์ด๋ฏธ์ง ๊ฐ๋ก ํฌ๊ธฐ
h: ์ ๋ ฅ ์ด๋ฏธ์ง ์ธ๋ก ํฌ๊ธฐ
c: ์ ๋ ฅ ์ด๋ฏธ์ง ์ฑ๋ ์
w2: shortcut ์ฐ๊ฒฐ๋๋ ๋ ์ด์ด์ ๊ฐ๋ก ํฌ๊ธฐ
h2: shortcut ์ฐ๊ฒฐ๋๋ ๋ ์ด์ด์ ์ธ๋ก ํฌ๊ธฐ
c2: shortcut ์ฐ๊ฒฐ๋๋ ๋ ์ด์ด์ ์ฑ๋ ์
๋์:
shortcut ๋ ์ด์ด๋ฅผ ์์ฑํ๊ณ , ํ๋ ๊ฐ๋ค์ ์ด๊ธฐํํ๋ค.
์ค๋ช :
shortcut ๋ ์ด์ด๋ skip connection์ ๊ตฌํํ๋ ๋ฐ ์ฌ์ฉ๋๋ ๋ ์ด์ด์ด๋ค.
์ ๋ ฅ ์ด๋ฏธ์ง์ ํฌ๊ธฐ์ shortcut์ผ๋ก ์ฐ๊ฒฐ๋๋ ๋ ์ด์ด์ ์ถ๋ ฅ ํฌ๊ธฐ๊ฐ ๊ฐ์ ๊ฒฝ์ฐ์ ์ฌ์ฉ๋๋ค.
์ถ๋ ฅ ํฌ๊ธฐ๋ ์ ๋ ฅ ์ด๋ฏธ์ง์ ํฌ๊ธฐ์ ๊ฐ๊ณ , ์ ๋ ฅ ์ด๋ฏธ์ง์ shortcut์ผ๋ก ์ฐ๊ฒฐ๋๋ ๋ ์ด์ด์ ์ถ๋ ฅ์ ๋ํ ๊ฒฐ๊ณผ๊ฐ ์ถ๋ ฅ๊ฐ์ด ๋๋ค.
l.delta์ l.output์ ๋ชจ๋ ์ถ๋ ฅ๊ฐ์ ์ ์ฅํ๋ ๋ฐฐ์ด์ด๋ค.
l.forward์ l.backward๋ ํด๋น ๋ ์ด์ด์์์ ์์ ํ์ ์ญ์ ํ ์ฐ์ฐ์ ์ํํ๋ ํจ์ ํฌ์ธํฐ์ด๋ค.
fprintf ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ํ์ฌ shortcut ๋ ์ด์ด์ ์ ๋ณด๋ฅผ ์ถ๋ ฅํ๋ค.
Last updated
Was this helpful?