gemm
GEMM ์ด๋?
์ฐธ๊ณ ์๋ฃ : https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/
General Matrix to Matrix Multiplication
1979๋ ์ ๋ง๋ค์ด์ง BLAS ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ผ๋ถ ์ ๋๋ค.
๋๊ฐ์ ์ ๋ ฅ ํ๋ ฌ์ ๊ณฑํด์ ์ถ๋ ฅ์ ์ป๋ ๋ฐฉ๋ฒ ์ ๋๋ค.
๋ฅ๋ฌ๋์์ ๋๋ถ๋ถ์ ์ฐ์ฐ์ output = input * weight + bias
๋ก ํํ์ด ๋ฉ๋๋ค. ์ฌ๊ธฐ์ input
, output
, weight
๋ฅผ ํ๋ ฌ๋ก ํํํด์ GEMM์ ์ฌ์ฉํด ์ฐ์ฐํ ์ ์์ต๋๋ค.
Fully Connected Layer
fully connected layer
๋ ์์ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค.
Convolutional Layer
im2col
: 3์ฐจ์ ์ด๋ฏธ์ง ๋ฐฐ์ด์ 2์ฐจ์ ๋ฐฐ์ด๋ก ๋ณํํฉ๋๋ค.
convolutional layer
๋ ์์ ๊ฐ์ด ํํํ ์ ์์ต๋๋ค. ์ ๊ทธ๋ฆผ์ ๊ฒฝ์ฐ๋ stride
๊ฐ kernel size
์ ๊ฐ์ ๊ฒฝ์ฐ๋ฅผ ์๋ฏธํฉ๋๋ค.
gemm.c
gemm
gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
{
gemm_cpu(TA, TB, M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
}
ํจ์ ์ด๋ฆ: gemm
์ ๋ ฅ:
int TA: ํ๋ ฌ A์ ์ ์น ์ฌ๋ถ (0: ์ ์นํ์ง ์์, 1: ์ ์นํจ)
int TB: ํ๋ ฌ B์ ์ ์น ์ฌ๋ถ (0: ์ ์นํ์ง ์์, 1: ์ ์นํจ)
int M: ํ๋ ฌ C์ ํ์ ์
int N: ํ๋ ฌ C์ ์ด์ ์
int K: ํ๋ ฌ A์ ์ด์ ์ (ํ๋ ฌ B์ ํ์ ์์ ๊ฐ์์ผ ํจ)
float ALPHA: ์ค์นผ๋ผ ๊ฐ
float *A: ํ๋ ฌ A์ ํฌ์ธํฐ
int lda: ํ๋ ฌ A์ ํ ๋จ์ ํฌ๊ธฐ
float *B: ํ๋ ฌ B์ ํฌ์ธํฐ
int ldb: ํ๋ ฌ B์ ํ ๋จ์ ํฌ๊ธฐ
float BETA: ์ค์นผ๋ผ ๊ฐ
float *C: ํ๋ ฌ C์ ํฌ์ธํฐ
int ldc: ํ๋ ฌ C์ ํ ๋จ์ ํฌ๊ธฐ
๋์:
ํ๋ ฌ-ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ์ ์ํํจ.
์ค๋ช :
์ด ํจ์๋ CPU ์์์ ํ๋ ฌ-ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ์ ์ํํ๋ ํจ์์ด๋ค.
gemm_cpu ํจ์๋ฅผ ํธ์ถํ์ฌ ์ด ์ฐ์ฐ์ ์ํํ๋ค.
ํ๋ ฌ A์ ํ๋ ฌ B์ ํฌ๊ธฐ์ ์ ์น ์ฌ๋ถ, ์ค์นผ๋ผ ๊ฐ ALPHA์ BETA ๋ฑ์ ์ ๋ ฅ์ผ๋ก ๋ฐ๊ณ , ์ฐ์ฐ ๊ฒฐ๊ณผ์ธ ํ๋ ฌ C๋ฅผ ์ถ๋ ฅ์ผ๋ก ๋ฐํํ๋ค.
gemm_cpu
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float BETA,
float *C, int ldc)
{
//printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
int i, j;
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
C[i*ldc + j] *= BETA;
}
}
if(!TA && !TB)
gemm_nn(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
else if(TA && !TB)
gemm_tn(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
else if(!TA && TB)
gemm_nt(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
else
gemm_tt(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
}
ํจ์ ์ด๋ฆ: gemm_cpu
์ ๋ ฅ:
int TA: A ํ๋ ฌ์ ์ ์น ์ฌ๋ถ๋ฅผ ๋ํ๋ด๋ ํ๋๊ทธ
int TB: B ํ๋ ฌ์ ์ ์น ์ฌ๋ถ๋ฅผ ๋ํ๋ด๋ ํ๋๊ทธ
int M: C ํ๋ ฌ์ ํ ์
int N: C ํ๋ ฌ์ ์ด ์
int K: A, B ํ๋ ฌ์์ ๊ณต์ ํ๋ ์ฐจ์์ ํฌ๊ธฐ
float ALPHA: A, B ํ๋ ฌ์ ๊ณฑ์ ๋ํ ๊ฐ์ค์น
float *A: A ํ๋ ฌ์ ํฌ์ธํฐ
int lda: A ํ๋ ฌ์ ํ ๋น ์์ ์
float *B: B ํ๋ ฌ์ ํฌ์ธํฐ
int ldb: B ํ๋ ฌ์ ํ ๋น ์์ ์
float BETA: C ํ๋ ฌ์ ๋ํ ๊ฐ์ค์น
float *C: C ํ๋ ฌ์ ํฌ์ธํฐ
int ldc: C ํ๋ ฌ์ ํ ๋น ์์ ์
๋์:
CPU์์ ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ์ ์ํํ๋ค.
A, B, C ์ธ ๊ฐ์ ํ๋ ฌ์ ์ธ์๋ก ๋ฐ๊ณ , A์ B์ ๊ณฑ์ ๊ฐ์ค์น ALPHA๋ฅผ ๊ณฑํ ๊ฒฐ๊ณผ๋ฅผ C ํ๋ ฌ์ ๋ํ๋ค.
์ค๋ช :
gemm_cpu ํจ์๋ CPU์์ ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ์ ์ํํ๋ค.
์ด ํจ์๋ A, B, C ์ธ ๊ฐ์ ํฌ์ธํฐ์ ๋ค์ํ ์ธ์๋ฅผ ๋ฐ์์, ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ ๊ฒฐ๊ณผ๋ฅผ C ํ๋ ฌ์ ์ ์ฅํ๋ค.
ํจ์ ๋ด๋ถ์์๋ TA์ TB ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ A์ B ํ๋ ฌ์ด ์ ์น๋์ด ์๋์ง ์ฌ๋ถ๋ฅผ ํ์ธํ๊ณ , ์ด์ ๋ฐ๋ผ gemm_nn, gemm_tn, gemm_nt, gemm_tt ํจ์ ์ค ํ๋๋ฅผ ํธ์ถํ๋ค.
์ด ํจ์๋ค์ ๋ค์ํ ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ ๋ฐฉ๋ฒ์ ๊ตฌํํ๊ณ ์๋ค.
๋ฐ๋ผ์ gemm_cpu ํจ์๋ ์ด๋ฅผ ์ด์ฉํ์ฌ ์ ๋ ฅ์ผ๋ก ๋ฐ์ ํ๋ ฌ A, B์ ๊ณฑ์ ๊ฐ์ค์น ALPHA๋ฅผ ๊ณฑํ ๊ฒฐ๊ณผ๋ฅผ C ํ๋ ฌ์ ๋ํ๋ค.
์ด ๋ BETA ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ์กด์ C ํ๋ ฌ ๊ฐ์ ๋ํ ๊ฐ์ค์น๋ฅผ ์กฐ์ ํ ์ ์๋ค.
gemm_nn
void gemm_nn(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
#pragma omp parallel for
for(i = 0; i < M; ++i){
for(k = 0; k < K; ++k){
register float A_PART = ALPHA*A[i*lda+k];
for(j = 0; j < N; ++j){
C[i*ldc+j] += A_PART*B[k*ldb+j];
}
}
}
}
ํจ์ ์ด๋ฆ: gemm_nn
์ ๋ ฅ:
M: ํ๋ ฌ A์ ํ์ ๊ฐ์
N: ํ๋ ฌ B์ ์ด์ ๊ฐ์
K: ํ๋ ฌ A์ ์ด์ ๊ฐ์ ๋๋ ํ๋ ฌ B์ ํ์ ๊ฐ์
ALPHA: ํ๋ ฌ A์ ํ๋ ฌ B์ ๊ณฑ์ ๊ฒฐ๊ณผ์ ๊ณฑํด์ง๋ ์ค์นผ๋ผ ๊ฐ
A: ํฌ๊ธฐ M x K์ ํ๋ ฌ A
lda: ํ๋ ฌ A์ leading dimension
B: ํฌ๊ธฐ K x N์ ํ๋ ฌ B
ldb: ํ๋ ฌ B์ leading dimension
C: ํฌ๊ธฐ M x N์ ํ๋ ฌ C
๋์:
ํ๋ ฌ A์ B๋ฅผ ๊ณฑํ์ฌ ํ๋ ฌ C๋ฅผ ๊ณ์ฐํ๋ General Matrix Multiply(GEMM) ์ฐ์ฐ์ ์ํํ๋ค.
i, k, j ์ธ ๊ฐ์ for ๋ฃจํ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ ฌ C์ ๊ฐ ์์๋ฅผ ๊ณ์ฐํ๋ค.
i ๋ฃจํ์์๋ ํ๋ ฌ A์ ๊ฐ ํ์ ์ํํ๋ฉฐ, k ๋ฃจํ์์๋ ํ๋ ฌ A์ ๊ฐ ์ด๊ณผ ํ๋ ฌ B์ ๊ฐ ํ์ ์ํํ๋ฉฐ, j ๋ฃจํ์์๋ ํ๋ ฌ B์ ๊ฐ ์ด์ ์ํํ๋ฉฐ ํ๋ ฌ C์ ๊ฐ ์์๋ฅผ ๊ณ์ฐํ๋ค.
OpenMP๋ฅผ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ ์ฒ๋ฆฌํ๋ค.
์ค๋ช :
General Matrix Multiply(GEMM) ์ฐ์ฐ์ ์ธ๊ณต ์ ๊ฒฝ๋ง์์ ๊ฐ์ฅ ๋ง์ด ์ฌ์ฉ๋๋ ์ฐ์ฐ ์ค ํ๋์ด๋ค.
GEMM ์ฐ์ฐ์ ์ํํ๋ ๋ฐฉ๋ฒ์ ์ฌ๋ฌ ๊ฐ์ง๊ฐ ์์ผ๋ฉฐ, ์ด ํจ์์์๋ A ํ๋ ฌ์ ์ํํ๋ฉด์ A์ B์ ๊ณฑ์ ๊ณ์ฐํ๋ค.
OpenMP๋ ๋ฉํฐ์ฝ์ด CPU์์ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ํํ ์ ์๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก, ์ด๋ฅผ ์ฌ์ฉํ์ฌ ์ฑ๋ฅ์ ํฅ์์ํจ๋ค.
gemm_nt
void gemm_nt(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
#pragma omp parallel for
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
register float sum = 0;
for(k = 0; k < K; ++k){
sum += ALPHA*A[i*lda+k]*B[j*ldb + k];
}
C[i*ldc+j] += sum;
}
}
}
ํจ์ ์ด๋ฆ: gemm_nt
์ ๋ ฅ:
M: A ํ๋ ฌ์ ํ ๊ฐ์
N: B ํ๋ ฌ์ ์ด ๊ฐ์
K: A ํ๋ ฌ์ ์ด ๊ฐ์ (๋์์ B ํ๋ ฌ์ ํ ๊ฐ์)
ALPHA: A์ B ํ๋ ฌ์ ๊ณฑ์ ๊ฒฐ๊ณผ์ ๊ณฑํด์ง ์ค์นผ๋ผ ๊ฐ
*A: A ํ๋ ฌ์ ํฌ์ธํฐ
lda: A ํ๋ ฌ์ ํ ๋น ์์ ๊ฐ์
*B: B ํ๋ ฌ์ ํฌ์ธํฐ
ldb: B ํ๋ ฌ์ ํ ๋น ์์ ๊ฐ์
*C: C ํ๋ ฌ์ ํฌ์ธํฐ
ldc: C ํ๋ ฌ์ ํ ๋น ์์ ๊ฐ์
๋์:
ํ๋ ฌ A์ B๋ฅผ ๊ณฑํ ํ, C ํ๋ ฌ์ ๋ํด์ฃผ๋ ์ฐ์ฐ์ ์ํํ๋ค.
A์ B ํ๋ ฌ์ ๊ณฑํ๊ธฐ ์ํด A๋ ๊ทธ๋๋ก, B๋ ์ ์น(transpose)๋ ํํ๋ก ์ฌ์ฉ๋๋ค.
A์ i๋ฒ์งธ ํ๊ณผ B์ j๋ฒ์งธ ์ด์ ๊ณฑํ ๊ฐ์ C์ i๋ฒ์งธ ํ j๋ฒ์งธ ์ด์ ๋์ ํ์ฌ ๋ํด์ค๋ค.
์ค๋ช :
์ด ํจ์๋ B ํ๋ ฌ์ด ์ ์น๋ ํํ๋ก ์ ๋ ฅ์ผ๋ก ๋ค์ด์ฌ ๋ A์ B๋ฅผ ๊ณฑํ ํ C ํ๋ ฌ์ ๋ํด์ฃผ๋ ์ฐ์ฐ์ ์ํํ๋ค.
ํจ์ ๋ด๋ถ์์๋ OpenMP๋ฅผ ์ด์ฉํ์ฌ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ํํ๋ฉฐ, i, j, k ์ธ ๊ฐ์ for ๋ฃจํ๋ฅผ ์ด์ฉํ์ฌ ํ๋ ฌ์ ์์ ๊ณฑ์ ๋ฐ ๋ง์ ์ฐ์ฐ์ ์ํํ๋ค.
gemm_tn
void gemm_tn(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
#pragma omp parallel for
for(i = 0; i < M; ++i){
for(k = 0; k < K; ++k){
register float A_PART = ALPHA*A[k*lda+i];
for(j = 0; j < N; ++j){
C[i*ldc+j] += A_PART*B[k*ldb+j];
}
}
}
}
ํจ์ ์ด๋ฆ: gemm_tn
์ ๋ ฅ:
int M: ํ๋ ฌ A์ ํ์ ์
int N: ํ๋ ฌ B์ ์ด์ ์
int K: ํ๋ ฌ A์ ์ด์ ์ ๋๋ ํ๋ ฌ B์ ํ์ ์
float ALPHA: ๊ณฑํด์ง๋ ์์
float *A: ํ๋ ฌ A์ ๋ฐ์ดํฐ ํฌ์ธํฐ
int lda: ํ๋ ฌ A์ ํ ๊ฐ๊ฒฉ
float *B: ํ๋ ฌ B์ ๋ฐ์ดํฐ ํฌ์ธํฐ
int ldb: ํ๋ ฌ B์ ํ ๊ฐ๊ฒฉ
float *C: ์ถ๋ ฅ ํ๋ ฌ C์ ๋ฐ์ดํฐ ํฌ์ธํฐ
int ldc: ์ถ๋ ฅ ํ๋ ฌ C์ ํ ๊ฐ๊ฒฉ
๋์:
ํ๋ ฌ A์ B์ ์ ์นํ๋ ฌ์ธ AT์ BT๋ฅผ ๊ณฑํ๊ณ ALPHA๋ฅผ ๊ณฑํ ๊ฐ์ ์ถ๋ ฅ ํ๋ ฌ C์ ๋ํ๋ค.
์ค๋ช :
์ด ํจ์๋ ํ๋ ฌ A์ B์ ์ ์นํ๋ ฌ์ธ AT์ BT๋ฅผ ๊ณฑํ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋ ํจ์์ด๋ค.
์ด๋ ALPHA๋ฅผ ๊ณฑํ ๊ฐ์ด ์ถ๋ ฅ ํ๋ ฌ C์ ๋ํด์ง๋ค.
๋ด๋ถ์ ์ผ๋ก๋ OpenMP๋ฅผ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ํํ๋ค.
gemm_tt
void gemm_tt(int M, int N, int K, float ALPHA,
float *A, int lda,
float *B, int ldb,
float *C, int ldc)
{
int i,j,k;
#pragma omp parallel for
for(i = 0; i < M; ++i){
for(j = 0; j < N; ++j){
register float sum = 0;
for(k = 0; k < K; ++k){
sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
}
C[i*ldc+j] += sum;
}
}
}
ํจ์ ์ด๋ฆ: gemm_tt
์ ๋ ฅ:
int M: ํ๋ ฌ C์ ํ ๊ฐ์
int N: ํ๋ ฌ C์ ์ด ๊ฐ์
int K: ํ๋ ฌ A์ ์ด ๊ฐ์ (ํ๋ ฌ B์ ํ ๊ฐ์)
float ALPHA: ์ค์นผ๋ผ ๊ฐ
float *A: M x K ํฌ๊ธฐ์ ํ๋ ฌ A
int lda: ํ๋ ฌ A์ ์ด ๊ฐ์
float *B: K x N ํฌ๊ธฐ์ ํ๋ ฌ B
int ldb: ํ๋ ฌ B์ ์ด ๊ฐ์
float *C: M x N ํฌ๊ธฐ์ ํ๋ ฌ C
int ldc: ํ๋ ฌ C์ ์ด ๊ฐ์
๋์:
๋ ๊ฐ์ ํ๋ ฌ A์ B๋ฅผ ๊ณฑํ ๊ฒฐ๊ณผ๋ฅผ ํ๋ ฌ C์ ๋์ ํ๋ค.
A์ B๋ ์ ์น(transpose)๋์ด ์๋ค๊ณ ๊ฐ์ ํ๋ค.
OpenMP๋ฅผ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ์ฒ๋ฆฌํ๋ค.
์ค๋ช :
์ผ๋ฐ์ ์ผ๋ก ํ๋ ฌ ๊ณฑ์ ์ฐ์ฐ์์๋ A x B์ B x A๋ ๋ค๋ฅด๋ค. ๊ทธ๋ฌ๋ gemm_tt ํจ์์์๋ A์ B ๋ชจ๋ ์ ์น๋ ์ํ์์ ๊ณฑ์ ์ ์ํํ๊ธฐ ๋๋ฌธ์ A^T x B^T = (B x A)^T์ ๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ์ป๋๋ค.
i, j, k์ ์์๋ก 3์ค for ๋ฃจํ๋ฅผ ์ํํ๋ฉฐ, ๊ฐ๊ฐ C[ildc+j], A[i+klda], B[k+j*ldb]์ ๊ฐ์ ์ฐธ์กฐํ๋ค.
๊ฐ๊ฐ์ C[i*ldc+j]์ ๊ฐ์ ๊ณ์ฐํ๊ธฐ ์ํด sum ๋ณ์๋ฅผ ์ฌ์ฉํ์ฌ ๋์ ํฉ์ ๊ณ์ฐํ๋ค.
OpenMP๋ฅผ ์ฌ์ฉํ์ฌ ๋ณ๋ ฌ์ฒ๋ฆฌํ์ฌ ์ฑ๋ฅ์ ํฅ์์ํจ๋ค.
Last updated
Was this helpful?