shortcut_layer

shortcut layer ๋ž€?

ResNet์—์„œ ์ œ์•ˆ๋œ skip connection๊ณผ ์œ ์‚ฌํ•ฉ๋‹ˆ๋‹ค.

์ž ์‹œ ์ถœ๋ ฅ์„ ์ €์žฅํ•˜๊ณ  ๊ทธ ํ›„์— layer์˜ ์ถœ๋ ฅ๊ณผ ํ•ฉ์น˜๋Š” ์ž‘์—…์—์„œ ์‚ฌ์šฉ ๋ฉ๋‹ˆ๋‹ค.

shortcut.c

forward_shortcut_layer

void forward_shortcut_layer(const layer l, network net)
{
    copy_cpu(l.outputs*l.batch, net.input, 1, l.output, 1);                                                                  // network input -> layer output
    shortcut_cpu(l.batch, l.w, l.h, l.c, net.layers[l.index].output, l.out_w, l.out_h, l.out_c, l.alpha, l.beta, l.output);  // layer output += i-th layer output
    activate_array(l.output, l.outputs*l.batch, l.activation);
}

ํ•จ์ˆ˜ ์ด๋ฆ„: forward_shortcut_layer

์ž…๋ ฅ:

  • const layer l: ํ˜„์žฌ layer ์ •๋ณด

  • network net: ํ˜„์žฌ network ์ •๋ณด

๋™์ž‘:

  • ํ˜„์žฌ layer์˜ ์ถœ๋ ฅ๊ฐ’์œผ๋กœ ๋„คํŠธ์›Œํฌ ์ž…๋ ฅ๊ฐ’์„ ๋ณต์‚ฌ

  • ํ˜„์žฌ layer์˜ ์ถœ๋ ฅ๊ฐ’์— shortcut ์—ฐ๊ฒฐ๋œ ์ด์ „ layer์˜ ์ถœ๋ ฅ๊ฐ’์„ ๋”ํ•ด์คŒ

  • ํ˜„์žฌ layer์˜ ์ถœ๋ ฅ๊ฐ’์— ํ™œ์„ฑํ™” ํ•จ์ˆ˜๋ฅผ ์ ์šฉ

์„ค๋ช…:

  • shortcut ์—ฐ๊ฒฐ์„ ํ†ตํ•ด ๋‹ค๋ฅธ layer์˜ ์ถœ๋ ฅ๊ฐ’์„ ํ˜„์žฌ layer์˜ ์ถœ๋ ฅ๊ฐ’์— ๋”ํ•ด์คŒ์œผ๋กœ์จ, ๋„คํŠธ์›Œํฌ์˜ ํ•™์Šต ํšจ์œจ์„ฑ์„ ๋†’์ด๊ธฐ ์œ„ํ•œ ๋ ˆ์ด์–ด

  • forward_shortcut_layer ํ•จ์ˆ˜๋Š” ํ•ด๋‹น layer์˜ forward propagation์„ ์ˆ˜ํ–‰ํ•˜๋ฉฐ, ์ž…๋ ฅ๊ฐ’์„ ํ˜„์žฌ layer์˜ ์ถœ๋ ฅ๊ฐ’์œผ๋กœ ๋ณต์‚ฌํ•˜๊ณ  shortcut ์—ฐ๊ฒฐ๋œ ์ด์ „ layer์˜ ์ถœ๋ ฅ๊ฐ’์„ ๋”ํ•ด์ฃผ๋ฉฐ ํ™œ์„ฑํ™” ํ•จ์ˆ˜๋ฅผ ์ ์šฉํ•˜๋Š” ์—ญํ• ์„ ์ˆ˜ํ–‰ํ•จ

backward_shortcut_layer

void backward_shortcut_layer(const layer l, network net)
{
    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);                                                       // layer delta -> activation grad
    axpy_cpu(l.outputs*l.batch, l.alpha, l.delta, 1, net.delta, 1);                                                           // network delta += alpha * layer delta
    shortcut_cpu(l.batch, l.out_w, l.out_h, l.out_c, l.delta, l.w, l.h, l.c, 1, l.beta, net.layers[l.index].delta);           // i-th layer delta += layer delta
}

ํ•จ์ˆ˜ ์ด๋ฆ„: backward_shortcut_layer

์ž…๋ ฅ:

  • const layer l: shortcut layer์˜ ์ •๋ณด๋ฅผ ๋‹ด๊ณ  ์žˆ๋Š” ๊ตฌ์กฐ์ฒด

  • network net: ์‹ ๊ฒฝ๋ง์„ ๊ตฌ์„ฑํ•˜๋Š” layer๋“ค์˜ ์ •๋ณด๋ฅผ ๋‹ด๊ณ  ์žˆ๋Š” ๊ตฌ์กฐ์ฒด

๋™์ž‘:

  • layer output์˜ activation gradient๋ฅผ ๊ณ„์‚ฐํ•˜์—ฌ layer delta์— ์ €์žฅํ•œ๋‹ค.

  • network delta์— alpha๊ฐ’๊ณผ layer delta๊ฐ’์„ ๊ณฑํ•˜์—ฌ ๋”ํ•ด์ค€๋‹ค.

  • i-th layer delta์—๋Š” beta๊ฐ’๊ณผ layer delta๊ฐ’์„ ๊ณฑํ•˜์—ฌ ๋”ํ•ด์ค€๋‹ค.

์„ค๋ช…:

  • Shortcut layer๋Š” ์ž…๋ ฅ๊ฐ’๊ณผ ์ด์ „ layer์˜ ์ถœ๋ ฅ๊ฐ’์„ ๋”ํ•˜์—ฌ ์ถœ๋ ฅ๊ฐ’์„ ๋งŒ๋“ค์–ด๋‚ธ๋‹ค.

  • ๋”ฐ๋ผ์„œ forward pass์—์„œ๋Š” ์ด์ „ layer์˜ ์ถœ๋ ฅ๊ฐ’์„ ํ˜„์žฌ layer์˜ ์ž…๋ ฅ๊ฐ’๊ณผ ๋”ํ•˜์—ฌ ์ถœ๋ ฅ๊ฐ’์„ ๊ณ„์‚ฐํ•˜๊ฒŒ ๋œ๋‹ค.

  • Backward pass์—์„œ๋Š” ํ˜„์žฌ layer์˜ ์ถœ๋ ฅ๊ฐ’์— ๋Œ€ํ•œ activation gradient๋ฅผ ๊ณ„์‚ฐํ•˜๊ณ , ์ด์ „ layer์˜ delta๊ฐ’์—๋„ ํ˜„์žฌ layer์˜ delta๊ฐ’์„ ๋”ํ•˜์—ฌ ์ „ํŒŒํ•˜๊ฒŒ ๋œ๋‹ค.

resize_shortcut_layer

void resize_shortcut_layer(layer *l, int w, int h)
{
    assert(l->w == l->out_w);
    assert(l->h == l->out_h);
    l->w = l->out_w = w;
    l->h = l->out_h = h;
    l->outputs = w*h*l->out_c;
    l->inputs = l->outputs;
    l->delta =  realloc(l->delta, l->outputs*l->batch*sizeof(float));
    l->output = realloc(l->output, l->outputs*l->batch*sizeof(float));
}

ํ•จ์ˆ˜ ์ด๋ฆ„: resize_shortcut_layer

์ž…๋ ฅ:

  • layer *l: ํฌ๊ธฐ๋ฅผ ์กฐ์ •ํ•  shortcut layer์˜ ํฌ์ธํ„ฐ

  • int w: ์ƒˆ๋กœ์šด ๋„ˆ๋น„

  • int h: ์ƒˆ๋กœ์šด ๋†’์ด

๋™์ž‘:

  • l์˜ w์™€ out_w๊ฐ€ ๊ฐ™์•„์•ผ ํ•จ์„ ํ™•์ธ(assert)

  • l์˜ h์™€ out_h๊ฐ€ ๊ฐ™์•„์•ผ ํ•จ์„ ํ™•์ธ(assert)

  • l์˜ w์™€ out_w๋ฅผ w๋กœ ์—…๋ฐ์ดํŠธ

  • l์˜ h์™€ out_h๋ฅผ h๋กœ ์—…๋ฐ์ดํŠธ

  • l์˜ outputs๋ฅผ w, h, out_c์˜ ๊ณฑ์œผ๋กœ ์—…๋ฐ์ดํŠธ

  • l์˜ inputs๋ฅผ outputs์™€ ๊ฐ™๊ฒŒ ์—…๋ฐ์ดํŠธ

  • l์˜ delta ๋ฉ”๋ชจ๋ฆฌ๋ฅผ outputs * batch ํฌ๊ธฐ๋งŒํผ ์žฌํ• ๋‹น

  • l์˜ output ๋ฉ”๋ชจ๋ฆฌ๋ฅผ outputs * batch ํฌ๊ธฐ๋งŒํผ ์žฌํ• ๋‹น

์„ค๋ช…:

  • ์ด ํ•จ์ˆ˜๋Š” shortcut layer์˜ ํฌ๊ธฐ๋ฅผ ์กฐ์ •ํ•˜๋Š” ์—ญํ• ์„ ํ•œ๋‹ค.

  • shortcut layer๋Š” input๊ณผ output์˜ ํฌ๊ธฐ๊ฐ€ ๊ฐ™์•„์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— l์˜ w์™€ out_w, h์™€ out_h๊ฐ€ ๊ฐ™์€์ง€ ํ™•์ธํ•˜๊ณ  ๊ฐ™์ง€ ์•Š์œผ๋ฉด ์—๋Ÿฌ๋ฅผ ๋ฐœ์ƒ์‹œํ‚จ๋‹ค.

  • ๊ทธ ํ›„ w์™€ h๋กœ ๊ฐ๊ฐ ํฌ๊ธฐ๋ฅผ ์กฐ์ •ํ•ด์ฃผ๊ณ , outputs์™€ inputs๋ฅผ ์—…๋ฐ์ดํŠธํ•œ๋‹ค.

  • ๋งˆ์ง€๋ง‰์œผ๋กœ, delta์™€ output ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ƒˆ๋กœ์šด outputs * batch ํฌ๊ธฐ๋กœ ์žฌํ• ๋‹นํ•œ๋‹ค.

make_shortcut_layer

layer make_shortcut_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
{
    fprintf(stderr, "res  %3d                %4d x%4d x%4d   ->  %4d x%4d x%4d\n",index, w2,h2,c2, w,h,c);
    layer l = {0};
    l.type = SHORTCUT;
    l.batch = batch;
    l.w = w2;
    l.h = h2;
    l.c = c2;
    l.out_w = w;
    l.out_h = h;
    l.out_c = c;
    l.outputs = w*h*c;
    l.inputs = l.outputs;

    l.index = index;

    l.delta =  calloc(l.outputs*batch, sizeof(float));
    l.output = calloc(l.outputs*batch, sizeof(float));;

    l.forward = forward_shortcut_layer;
    l.backward = backward_shortcut_layer;

    return l;
}

ํ•จ์ˆ˜ ์ด๋ฆ„: make_shortcut_layer

์ž…๋ ฅ:

  • batch: ๋ฐฐ์น˜ ํฌ๊ธฐ

  • index: ๋ ˆ์ด์–ด ์ธ๋ฑ์Šค

  • w: ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๊ฐ€๋กœ ํฌ๊ธฐ

  • h: ์ž…๋ ฅ ์ด๋ฏธ์ง€ ์„ธ๋กœ ํฌ๊ธฐ

  • c: ์ž…๋ ฅ ์ด๋ฏธ์ง€ ์ฑ„๋„ ์ˆ˜

  • w2: shortcut ์—ฐ๊ฒฐ๋˜๋Š” ๋ ˆ์ด์–ด์˜ ๊ฐ€๋กœ ํฌ๊ธฐ

  • h2: shortcut ์—ฐ๊ฒฐ๋˜๋Š” ๋ ˆ์ด์–ด์˜ ์„ธ๋กœ ํฌ๊ธฐ

  • c2: shortcut ์—ฐ๊ฒฐ๋˜๋Š” ๋ ˆ์ด์–ด์˜ ์ฑ„๋„ ์ˆ˜

๋™์ž‘:

  • shortcut ๋ ˆ์ด์–ด๋ฅผ ์ƒ์„ฑํ•˜๊ณ , ํ•„๋“œ ๊ฐ’๋“ค์„ ์ดˆ๊ธฐํ™”ํ•œ๋‹ค.

์„ค๋ช…:

  • shortcut ๋ ˆ์ด์–ด๋Š” skip connection์„ ๊ตฌํ˜„ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” ๋ ˆ์ด์–ด์ด๋‹ค.

  • ์ž…๋ ฅ ์ด๋ฏธ์ง€์˜ ํฌ๊ธฐ์™€ shortcut์œผ๋กœ ์—ฐ๊ฒฐ๋˜๋Š” ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ ํฌ๊ธฐ๊ฐ€ ๊ฐ™์€ ๊ฒฝ์šฐ์— ์‚ฌ์šฉ๋œ๋‹ค.

  • ์ถœ๋ ฅ ํฌ๊ธฐ๋Š” ์ž…๋ ฅ ์ด๋ฏธ์ง€์˜ ํฌ๊ธฐ์™€ ๊ฐ™๊ณ , ์ž…๋ ฅ ์ด๋ฏธ์ง€์™€ shortcut์œผ๋กœ ์—ฐ๊ฒฐ๋˜๋Š” ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ์„ ๋”ํ•œ ๊ฒฐ๊ณผ๊ฐ€ ์ถœ๋ ฅ๊ฐ’์ด ๋œ๋‹ค.

  • l.delta์™€ l.output์€ ๋ชจ๋‘ ์ถœ๋ ฅ๊ฐ’์„ ์ €์žฅํ•˜๋Š” ๋ฐฐ์—ด์ด๋‹ค.

  • l.forward์™€ l.backward๋Š” ํ•ด๋‹น ๋ ˆ์ด์–ด์—์„œ์˜ ์ˆœ์ „ํŒŒ์™€ ์—ญ์ „ํŒŒ ์—ฐ์‚ฐ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ํ•จ์ˆ˜ ํฌ์ธํ„ฐ์ด๋‹ค.

  • fprintf ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ํ˜„์žฌ shortcut ๋ ˆ์ด์–ด์˜ ์ •๋ณด๋ฅผ ์ถœ๋ ฅํ•œ๋‹ค.

Last updated

Was this helpful?