cifar
train_cifar
ํจ์ ์ด๋ฆ: train_cifar
์ ๋ ฅ:
char *cfgfile (์ค์ ํ์ผ ๊ฒฝ๋ก)
char *weightfile (๊ฐ์ค์น ํ์ผ ๊ฒฝ๋ก)
๋์:
CIFAR-10 ๋ฐ์ดํฐ์ ์ ๋ํ ์ ๊ฒฝ๋ง์ ํ๋ จ์ํค๋ ํจ์์ด๋ค.
์ง์ ๋ ๊ตฌ์ฑ ํ์ผ๊ณผ ๊ฐ์ค์น ํ์ผ์ ์ฌ์ฉํ์ฌ ๋คํธ์ํฌ๋ฅผ ๋ก๋ํ๊ณ , SGD(ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ)๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ์ ์ํํ๋ค.
์ฃผ๊ธฐ์ ์ผ๋ก ํ์ฌ ์์ค, ํ๊ท ์์ค, ํ์ต๋ฅ , ๊ฒฝ๊ณผ ์๊ฐ, ์ด๋ฏธ์ง ์ ๋ฑ์ ์ถ๋ ฅํ๊ณ , ์ง์ ๋ ๋ฐฐ์น ์ ๋๋ ์ต๋ ๋ฐฐ์น ์์ ๋๋ฌํ๋ฉด ํ๋ จ์ ์ค์งํ๋ค.
๋ํ ํ๋ จ ์ค์ ์ง์ ๋ ๋ฐฑ์ ๋๋ ํ ๋ฆฌ์ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ ์ฅํ๋ค.
์ค๋ช :
backup_directory: ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ ์ฅํ ๋ฐฑ์ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก
classes: ๋ถ๋ฅ ํด๋์ค ์ (CIFAR-10์ ๊ฒฝ์ฐ 10)
N: ํ์ต ์ธํธ ์ด๋ฏธ์ง ์ (CIFAR-10์ ๊ฒฝ์ฐ 50000)
labels: ํด๋์ค ๋ ์ด๋ธ ๋ฐฐ์ด ํฌ์ธํฐ
epoch: ํ์ฌ ํ์ต epoch ์
train: CIFAR-10 ๋ฐ์ดํฐ์ ์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ํฌํจํ๋ ๋ฐ์ดํฐ ๊ตฌ์กฐ์ฒด
avg_loss: ํ์ฌ๊น์ง์ ํ๊ท ์์ค๊ฐ
time: ํ์ฌ ๋ฐฐ์น์ ํ๋ จ ์๊ฐ
loss: ํ์ฌ ๋ฐฐ์น์ ์์ค๊ฐ
buff: ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ ์ฅํ ํ์ผ ์ด๋ฆ ๋ฐ ๊ฒฝ๋ก
net: ๋ก๋๋ ์ ๊ฒฝ๋ง ๊ตฌ์กฐ์ฒด
train_cifar_distill
ํจ์ ์ด๋ฆ: train_cifar_distill
์ ๋ ฅ:
cfgfile: char ํฌ์ธํฐ. ๋คํธ์ํฌ ์ค์ ํ์ผ์ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ๋ค.
weightfile: char ํฌ์ธํฐ. ํ์ต๋ ๋คํธ์ํฌ ๊ฐ์ค์น ํ์ผ์ ๊ฒฝ๋ก๋ฅผ ์ง์ ํ๋ค.
๋์:
CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ๋คํธ์ํฌ๋ฅผ ํ์ต์ํจ๋ค.
์ง๋ํ์ต(supervised learning)๋ ๋ชจ๋ธ์ ์์ธก ํ๋ฅ ๊ฐ(์ํํธ๋งฅ์ค ์ถ๋ ฅ)์ ์์๋ธ(ensemble) ๋ชจ๋ธ์ ์์ธก ํ๋ฅ ๊ฐ๊ณผ ๊ฒฐํฉํ์ฌ ์์ค์ ๊ณ์ฐํ๊ณ ์ญ์ ํ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ์ฌ ๋คํธ์ํฌ ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธํ๋ค.
์ง์ ๋ ๋ฐฐ์น(batch) ์(max_batches) ๋๋ ๋ฌด์ ํ ๋ฐ๋ณต(iteration)์ ์คํํ๋ฉฐ, ํ์ต ๊ณผ์ ์์ ๋คํธ์ํฌ์ ๊ฐ์ค์น๋ฅผ ์ฃผ๊ธฐ์ ์ผ๋ก ์ ์ฅํ๊ณ , ํ์ต ์๋(learning rate), ๋ชจ๋ฉํ (momentum), ๊ฐ์ค์น ๊ฐ์ (decay) ๋ฑ์ ์ ๋ณด๋ฅผ ์ถ๋ ฅํ๋ค.
์ค๋ช :
์ด ํจ์๋ CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ distillation ๊ธฐ๋ฒ์ ์ฌ์ฉํ์ฌ ๋คํธ์ํฌ๋ฅผ ํ์ต์ํค๋ ํจ์์ด๋ค.
distillation ๊ธฐ๋ฒ์ ์์ ๋ชจ๋ธ(ํ์)์ ํ์ต์ํค๋๋ฐ, ํฐ ๋ชจ๋ธ(์ ์)์ ์์ธก๊ฐ์ ์ฌ์ฉํ์ฌ ํ์ต์ํค๋ ๋ฐฉ๋ฒ์ด๋ค.
์ด ํจ์๋ ์ง๋ํ์ต๋ ๋ชจ๋ธ(์ ์)๊ณผ ์์๋ธ ๋ชจ๋ธ(์ ์)์ ์์ธก๊ฐ์ ๊ฒฐํฉํ์ฌ distillation ๊ธฐ๋ฒ์ ์ฌ์ฉํ์ฌ ์์ ๋ชจ๋ธ(ํ์)์ ํ์ต์ํจ๋ค.
์ด ํจ์๋ ์ฃผ์ด์ง ๋คํธ์ํฌ ์ค์ ํ์ผ๊ณผ ํ์ต๋ ๋คํธ์ํฌ ๊ฐ์ค์น ํ์ผ์ ์ฌ์ฉํ์ฌ ๋คํธ์ํฌ๋ฅผ ๋ก๋ํ๊ณ , CIFAR-10 ๋ฐ์ดํฐ์ ์ ๋ถ๋ฌ์์ ๋คํธ์ํฌ๋ฅผ ํ์ต์ํจ๋ค.
ํ์ต ์ค์๋ ์ง์ ๋ ๋ฐฐ์น ์(max_batches) ๋๋ ๋ฌด์ ํ ๋ฐ๋ณต(iteration)์ ์คํํ๋ฉฐ, ๋คํธ์ํฌ์ ๊ฐ์ค์น๋ฅผ ์ฃผ๊ธฐ์ ์ผ๋ก ์ ์ฅํ๊ณ , ํ์ต ์๋(learning rate), ๋ชจ๋ฉํ (momentum), ๊ฐ์ค์น ๊ฐ์ (decay) ๋ฑ์ ์ ๋ณด๋ฅผ ์ถ๋ ฅํ๋ค.
์ด ํจ์๋ ํ์ต ์ค์ ์์ค์ ๊ณ์ฐํ ๋ ์์๋ธ ๋ชจ๋ธ์ ์์ธก ํ๋ฅ ๊ฐ๊ณผ ์ง๋ํ์ต๋ ๋ชจ๋ธ์ ์์ธก ํ๋ฅ ๊ฐ์ ๊ฒฐํฉํ์ฌ ์ฌ์ฉํ๋ค.
์ด ํจ์๋ ํ์ต์ด ์๋ฃ๋ ํ, ์ฌ์ฉํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ชจ๋ ํด์ ํ๋ค.
test_cifar_multi
ํจ์ ์ด๋ฆ: test_cifar_multi
์ ๋ ฅ:
filename: ํ ์คํธํ ๋ชจ๋ธ์ ์ค์ ํ์ผ ๊ฒฝ๋ก
weightfile: ํ ์คํธํ ๋ชจ๋ธ์ ๊ฐ์ค์น ํ์ผ ๊ฒฝ๋ก
๋์:
์ค์ ํ์ผ๊ณผ ๊ฐ์ค์น ํ์ผ์ ์ด์ฉํด ๋ชจ๋ธ์ ๋ก๋ํ๋ค.
๋ฐฐ์น ํฌ๊ธฐ๋ฅผ 1๋ก ์ค์ ํ๋ค.
CIFAR-10 ๋ฐ์ดํฐ์ ์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๋ค.
๊ฐ๊ฐ์ ํ ์คํธ ์ด๋ฏธ์ง์ ๋ํด ๋ค์์ ์ํํ๋ค:
ํ ์คํธ ์ด๋ฏธ์ง๋ฅผ ๋คํธ์ํฌ์ ์ ๋ ฅ์ผ๋ก ๋ฃ๊ณ , ์ถ๋ ฅ๊ฐ์ ๊ฐ์ ธ์จ๋ค.
์ด๋ฏธ์ง๋ฅผ ์ข์ฐ๋ก ๋ค์ง์ด ๋ค์ ํ ๋ฒ ๋คํธ์ํฌ์ ์ ๋ ฅ์ผ๋ก ๋ฃ๊ณ , ์ถ๋ ฅ๊ฐ์ ๊ฐ์ ธ์จ๋ค.
๋ ๋ฒ์ ์ถ๋ ฅ๊ฐ์ ํ๊ท ๋ด์ด ์์ธก๊ฐ์ ๊ณ์ฐํ๋ค.
์์ธก๊ฐ๊ณผ ์ค์ ๋ ์ด๋ธ์ ๋น๊ตํ์ฌ ์ ํ๋๋ฅผ ๊ณ์ฐํ๊ณ , ์ด๋ฅผ ๋์ ํ๋ค.
์ด๋ฏธ์ง ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํด์ ํ๋ค.
ํ์ฌ๊น์ง์ ์ ํ๋๋ฅผ ์ถ๋ ฅํ๋ค.
์ค๋ช :
์ด ํจ์๋ ๋ก๋ํ ๋ชจ๋ธ์ ์ด์ฉํด CIFAR-10 ๋ฐ์ดํฐ์ ์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ํ๊ฐํ๋ ์ญํ ์ ํ๋ค.
๊ฐ๊ฐ์ ํ ์คํธ ์ด๋ฏธ์ง์ ๋ํด ๋ ๋ฒ์ ์์ธก๊ฐ์ ๊ณ์ฐํ ๋ค ํ๊ท ์ ๋ด์ด ์ต์ข ์์ธก๊ฐ์ ๊ณ์ฐํ๊ณ , ์ด๋ฅผ ์ค์ ๋ ์ด๋ธ๊ณผ ๋น๊ตํ์ฌ ์ ํ๋๋ฅผ ๊ณ์ฐํ๋ค.
์ด๋ ๊ฒ ๊ณ์ฐ๋ ์ ํ๋๋ ๊ฐ๊ฐ์ ์ด๋ฏธ์ง์์ ๋์ ๋์ด ์ ์ฒด ํ ์คํธ ๋ฐ์ดํฐ์ ์ ๋ํ ํ๊ท ์ ํ๋๋ฅผ ๊ณ์ฐํ๊ฒ ๋๋ค.
test_cifar
ํจ์ ์ด๋ฆ: test_cifar
์ ๋ ฅ:
filename: char* ํ์ ์ ํ์ผ ์ด๋ฆ (๋คํธ์ํฌ ๊ตฌ์กฐ๊ฐ ์ ์ฅ๋ ํ์ผ)
weightfile: char* ํ์ ์ ํ์ผ ์ด๋ฆ (๋คํธ์ํฌ ๊ฐ์ค์น๊ฐ ์ ์ฅ๋ ํ์ผ)
๋์:
์ ๋ ฅ์ผ๋ก ๋ฐ์ ํ์ผ์์ ๋คํธ์ํฌ ๊ตฌ์กฐ์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ๊ณ , ์ด๋ฅผ ์ด์ฉํด cifar-10 ๋ฐ์ดํฐ์ ์ ์ ํ๋๋ฅผ ํ๊ฐํฉ๋๋ค.
ํ๊ฐ ๋ฐฉ์์ top-1๊ณผ top-5 ์ ํ๋๋ฅผ ์ธก์ ํฉ๋๋ค.
์ธก์ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๊ณ , ๋ง์ง๋ง์ผ๋ก ์ฌ์ฉํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํด์ ํฉ๋๋ค.
์ค๋ช :
๋คํธ์ํฌ๋ฅผ ๋ก๋ํ๊ณ srand๋ฅผ ์ด์ฉํด ๋์ ๋ฐ์๊ธฐ ์ด๊ธฐํํฉ๋๋ค.
์๊ฐ์ ์ธก์ ํ๊ธฐ ์ํด clock() ํจ์๋ฅผ ์ฌ์ฉํฉ๋๋ค.
load_cifar10_data ํจ์๋ฅผ ์ฌ์ฉํด cifar-10 ๋ฐ์ดํฐ์ ์ ๋ก๋ํฉ๋๋ค.
network_accuracies ํจ์๋ฅผ ์ฌ์ฉํด cifar-10 ๋ฐ์ดํฐ์ ์ top-1๊ณผ top-5 ์ ํ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์ธก์ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๊ณ , ์ฌ์ฉํ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํด์ ํฉ๋๋ค.
extract_cifar
ํจ์ ์ด๋ฆ: extract_cifar
์ ๋ ฅ:
์์
๋์:
CIFAR-10 ๋ฐ์ดํฐ์ ์์ ์ด๋ฏธ์ง๋ฅผ ์ถ์ถํ์ฌ ํด๋์ค ๋ ์ด๋ธ๊ณผ ํจ๊ป ์ ์ฅํฉ๋๋ค.
ํ๋ จ ์ด๋ฏธ์ง๋ 'data/cifar/train' ํด๋์ ์ ์ฅ๋๊ณ , ํ ์คํธ ์ด๋ฏธ์ง๋ 'data/cifar/test' ํด๋์ ์ ์ฅ๋ฉ๋๋ค.
์ด๋ฏธ์ง ํ์ผ ์ด๋ฆ์ ๊ฐ๊ฐ '์ธ๋ฑ์ค_ํด๋์ค๋ช ' ํ์์ผ๋ก ์ง์ ๋ฉ๋๋ค.
์ค๋ช :
labels: ํด๋์ค ๋ ์ด๋ธ์ ์ ์ฅํ๋ ๋ฌธ์์ด ๋ฐฐ์ด
train: ๋ชจ๋ CIFAR-10 ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๋ ๋ฐ ์ฌ์ฉ๋๋ data ๊ตฌ์กฐ์ฒด
test: CIFAR-10 ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๋ ๋ฐ ์ฌ์ฉ๋๋ data ๊ตฌ์กฐ์ฒด
for ๋ฃจํ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ ๋ฐ ํ ์คํธ ๋ฐ์ดํฐ์ ์์ ์ด๋ฏธ์ง๋ฅผ ์ถ์ถํ๊ณ , ํด๋น ์ด๋ฏธ์ง์ ํด๋์ค ๋ ์ด๋ธ์ ๊ฐ์ ธ์์ ์ด๋ฏธ์ง๋ฅผ ์ ์ฅํฉ๋๋ค.
sprintf ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง ํ์ผ ์ด๋ฆ์ ์ง์ ํฉ๋๋ค.
test_cifar_csv
ํจ์ ์ด๋ฆ: test_cifar_csv
์ ๋ ฅ:
filename (char*): ๋คํธ์ํฌ ๋ชจ๋ธ ํ์ผ ๊ฒฝ๋ก
weightfile (char*): ํ์ต๋ ๊ฐ์ค์น ํ์ผ ๊ฒฝ๋ก
๋์:
์ง์ ๋ ๋คํธ์ํฌ ๋ชจ๋ธ ํ์ผ๊ณผ ๊ฐ์ค์น ํ์ผ์ ๋ก๋ํ์ฌ ๋คํธ์ํฌ๋ฅผ ์์ฑํฉ๋๋ค.
์๋ ๊ฐ์ ํ์ฌ ์๊ฐ์ผ๋ก ์ค์ ํ์ฌ ๋์ ์์ฑ๊ธฐ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
CIFAR-10 ๋ฐ์ดํฐ์ ์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํฉ๋๋ค.
๋คํธ์ํฌ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ์์ธกํ๊ณ , ์์ธก ๊ฒฐ๊ณผ๋ฅผ matrix ํ์์ผ๋ก ๋ฐํํฉ๋๋ค.
ํ ์คํธ ๋ฐ์ดํฐ์ ์ด๋ฏธ์ง๋ฅผ ๋ฐ์ ์ํค๊ณ , ๋ค์ ํ ๋ฒ ๋คํธ์ํฌ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํฉ๋๋ค.
์ฒซ ๋ฒ์งธ ์์ธก ๊ฒฐ๊ณผ์ ๋ ๋ฒ์งธ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ํฉ์ฐํ๊ณ , ์ด๋ฅผ csv ํ์ผ๋ก ์ ์ฅํฉ๋๋ค.
ํ ์คํธ ๋ฐ์ดํฐ์ ์ค์ ๋ ์ด๋ธ๊ณผ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ์ฌ top-1 ์ ํ๋๋ฅผ ๊ณ์ฐํ๊ณ , ํ์ค ์ค๋ฅ ์คํธ๋ฆผ(stderr)์ ์ถ๋ ฅํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ ๋น ํด์
์ค๋ช :
์ด ํจ์๋ CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ๋ก๋ํ ํ ์คํธ ๋ฐ์ดํฐ์ ๋ํด ๋คํธ์ํฌ ๋ชจ๋ธ์ ์์ธก ๊ฒฐ๊ณผ๋ฅผ csv ํ์ผ๋ก ์ ์ฅํ๊ณ , ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ํ๋๋ฅผ ๊ณ์ฐํ๋ ์ญํ ์ ํฉ๋๋ค.
ํจ์๋ load_network(), load_cifar10_data(), network_predict_data(), matrix_topk_accuracy(), free_data() ํจ์ ๋ฑ์ ์ฌ์ฉํ์ฌ ๋์ํฉ๋๋ค.
test_cifar_csvtrain
ํจ์ ์ด๋ฆ: test_cifar_csvtrain
์ ๋ ฅ:
cfg (char*): ๋คํธ์ํฌ ๋ชจ๋ธ ์ค์ ํ์ผ ๊ฒฝ๋ก
weights (char*): ํ์ต๋ ๊ฐ์ค์น ํ์ผ ๊ฒฝ๋ก
๋์:
์ง์ ๋ ๋คํธ์ํฌ ๋ชจ๋ธ ์ค์ ํ์ผ๊ณผ ๊ฐ์ค์น ํ์ผ์ ๋ก๋ํ์ฌ ๋คํธ์ํฌ๋ฅผ ์์ฑํฉ๋๋ค.
์๋ ๊ฐ์ ํ์ฌ ์๊ฐ์ผ๋ก ์ค์ ํ์ฌ ๋์ ์์ฑ๊ธฐ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํฉ๋๋ค.
๋คํธ์ํฌ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ์์ธกํ๊ณ , ์์ธก ๊ฒฐ๊ณผ๋ฅผ matrix ํ์์ผ๋ก ๋ฐํํฉ๋๋ค.
์ ์ฒด ๋ฐ์ดํฐ์ ์ด๋ฏธ์ง๋ฅผ ๋ฐ์ ์ํค๊ณ , ๋ค์ ํ ๋ฒ ๋คํธ์ํฌ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ฐํํฉ๋๋ค.
์ฒซ ๋ฒ์งธ ์์ธก ๊ฒฐ๊ณผ์ ๋ ๋ฒ์งธ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ํฉ์ฐํ๊ณ , ์ด๋ฅผ csv ํ์ผ๋ก ์ ์ฅํฉ๋๋ค.
์ ์ฒด ๋ฐ์ดํฐ์ ์ค์ ๋ ์ด๋ธ๊ณผ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํ์ฌ top-1 ์ ํ๋๋ฅผ ๊ณ์ฐํ๊ณ , ํ์ค ์ค๋ฅ ์คํธ๋ฆผ(stderr)์ ์ถ๋ ฅํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ ๋น ํด์
์ค๋ช :
์ด ํจ์๋ CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ์ ์ฒด ๋ฐ์ดํฐ์ ๋ํด ๋คํธ์ํฌ ๋ชจ๋ธ์ ์์ธก ๊ฒฐ๊ณผ๋ฅผ csv ํ์ผ๋ก ์ ์ฅํ๊ณ , ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ํ๋๋ฅผ ๊ณ์ฐํ๋ ์ญํ ์ ํฉ๋๋ค.
ํจ์๋ load_network(), load_all_cifar10(), network_predict_data(), matrix_topk_accuracy(), free_data() ํจ์ ๋ฑ์ ์ฌ์ฉํ์ฌ ๋์ํฉ๋๋ค.
eval_cifar_csv
ํจ์ ์ด๋ฆ: eval_cifar_csv
์ ๋ ฅ:
์์
๋์:
CIFAR-10 ๋ฐ์ดํฐ์ ์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํฉ๋๋ค.
csv ํ์ผ๋ก ์ ์ฅ๋ ์์ธก ๊ฒฐ๊ณผ๋ฅผ matrix ํ์์ผ๋ก ๋ก๋ํฉ๋๋ค.
์์ธก ๊ฒฐ๊ณผ์ ํ๊ณผ ์ด ๊ฐ์๋ฅผ ์ถ๋ ฅํฉ๋๋ค.
์์ธก ๊ฒฐ๊ณผ์ ์ค์ ๋ ์ด๋ธ์ ๋น๊ตํ์ฌ top-1 ์ ํ๋๋ฅผ ๊ณ์ฐํ๊ณ , ํ์ค ์ค๋ฅ ์คํธ๋ฆผ(stderr)์ ์ถ๋ ฅํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ ํ ๋น ํด์
์ค๋ช :
์ด ํจ์๋ CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ์ ์ฅ๋ csv ํ์ผ์ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๋ก๋ํ๊ณ , ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก top-1 ์ ํ๋๋ฅผ ๊ณ์ฐํ์ฌ ์ถ๋ ฅํ๋ ์ญํ ์ ํฉ๋๋ค.
ํจ์๋ load_cifar10_data(), csv_to_matrix(), matrix_topk_accuracy(), free_data(), free_matrix() ํจ์ ๋ฑ์ ์ฌ์ฉํ์ฌ ๋์ํฉ๋๋ค.
run_cifar
ํจ์ ์ด๋ฆ: run_cifar
์ ๋ ฅ:
int argc: ์ ๋ ฅ ์ธ์์ ๊ฐ์
char **argv: ์ ๋ ฅ ์ธ์์ ๋ฐฐ์ด ํฌ์ธํฐ
๋์:
์ ๋ ฅ ์ธ์์ ๊ฐ์๊ฐ 4๋ณด๋ค ์์ผ๋ฉด ์ฌ์ฉ ๋ฐฉ๋ฒ์ ์ถ๋ ฅํ๊ณ ํจ์๋ฅผ ์ข ๋ฃํฉ๋๋ค.
3๋ฒ์งธ ์ ๋ ฅ ์ธ์๋ฅผ cfg ๋ณ์์ ์ ์ฅํฉ๋๋ค.
4๋ฒ์งธ ์ ๋ ฅ ์ธ์๊ฐ ์กด์ฌํ๋ฉด weights ๋ณ์์ ์ ์ฅํฉ๋๋ค.
2๋ฒ์งธ ์ ๋ ฅ ์ธ์์ ๋ฐ๋ผ ๋ค์ ํจ์ ์ค ํ๋๋ฅผ ํธ์ถํฉ๋๋ค.
train_cifar()
extract_cifar()
train_cifar_distill()
test_cifar()
test_cifar_multi()
test_cifar_csv()
test_cifar_csvtrain()
eval_cifar_csv()
๊ฐ ํจ์๋ CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ จํ๊ณ , ์์ธก ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๊ฑฐ๋, ์์ธก ๊ฒฐ๊ณผ๋ฅผ csv ํ์ผ๋ก ์ ์ฅํ๊ฑฐ๋, ์ ์ฅ๋ csv ํ์ผ์ ๋ก๋ํ์ฌ ์ ํ๋๋ฅผ ์ถ๋ ฅํ๋ ๋ฑ์ ๋์์ ์ํํฉ๋๋ค.
์ค๋ช :
์ด ํจ์๋ CIFAR-10 ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ์ฌ ๋ค์ํ ๋์์ ์ํํ๋ ํจ์๋ค์ ํธ์ถํ๋ ์ญํ ์ ํฉ๋๋ค.
ํจ์๋ ์ธ์๋ก ๋ฐ์ ์ ๋ ฅ ์ธ์์ ๊ฐ์์ ๋ฐฐ์ด ํฌ์ธํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ ๋์์ ํ์ํ cfg ํ์ผ๊ณผ weights ํ์ผ์ ๊ฒฐ์ ํ๊ณ , ์ด๋ฅผ ์ด์ฉํ์ฌ train_cifar(), extract_cifar(), train_cifar_distill(), test_cifar(), test_cifar_multi(), test_cifar_csv(), test_cifar_csvtrain(), eval_cifar_csv() ํจ์ ์ค ์ ์ ํ ํจ์๋ฅผ ํธ์ถํฉ๋๋ค.
์ด ํจ์๋ ๋ช ๋ นํ ์ธ์๋ฅผ ์ฒ๋ฆฌํ๋ ๋ฐ์ ์ฌ์ฉ๋๋ main() ํจ์์์ ํธ์ถ๋ฉ๋๋ค.
Last updated
Was this helpful?