gemm

What is GEMM?

Reference: https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/

  • General Matrix to Matrix Multiplication

  • It is part of BLAS (Basic Linear Algebra Subprograms), a library first created in 1979.

  • It multiplies two input matrices to produce an output matrix.

In deep learning, most operations can be written as output = input * weight + bias. By representing input, output, and weight as matrices, these operations can be computed with GEMM.

Fully Connected Layer

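A fully connected layer can be expressed as shown in the figure above: when a batch of input vectors is stacked as the rows of one matrix, the whole layer becomes a single matrix product with the weight matrix, followed by adding the bias.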
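A minimal sketch of this idea, assuming row-major storage and a weight matrix laid out as inputs x outputs (a real implementation may store the weights differently, for example transposed). The function name fc_forward is hypothetical. The triple loop is a GEMM with M = batch, N = outputs, K = inputs, and the bias is folded into the starting value of each sum:

```c
/* Hypothetical helper: forward pass of a fully connected layer.
 * input  : batch x in_dim   (row-major)
 * weight : in_dim x out_dim (row-major, assumed layout)
 * bias   : out_dim
 * output : batch x out_dim  (row-major), overwritten
 * The loops form a GEMM with M = batch, N = out_dim, K = in_dim. */
void fc_forward(int batch, int in_dim, int out_dim,
                const float *input, const float *weight,
                const float *bias, float *output)
{
    for (int b = 0; b < batch; ++b) {
        for (int o = 0; o < out_dim; ++o) {
            float sum = bias[o];                      /* start from the bias term */
            for (int i = 0; i < in_dim; ++i) {
                sum += input[b * in_dim + i] * weight[i * out_dim + o];
            }
            output[b * out_dim + o] = sum;
        }
    }
}
```

With input rows {1, 2, 3} and {4, 5, 6}, weights {{1, 0}, {0, 1}, {1, 1}} and bias {10, 20}, the outputs are {14, 25} and {20, 31}.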

Convolutional Layer

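  • im2col: converts a 3D image array (channels x height x width) into a 2D array in which each image patch covered by the kernel is laid out as one column (or one row, depending on the layout).

A convolutional layer can be expressed as shown in the figure above: after im2col, the convolution is a single GEMM between the kernel weights and the 2D array. The figure shows the case where the stride equals the kernel size, so the patches do not overlap; with a smaller stride, overlapping pixels are copied into several patches.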
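The following is a simplified im2col sketch, assuming no padding and a row-major channels x height x width image. The name im2col_simple is hypothetical; it is not darknet's own im2col routine. Each row of the result corresponds to one kernel element (channel, ky, kx) and each column to one output position, so the convolution becomes a GEMM of the weights (filters x channels*ksize*ksize) with this array (channels*ksize*ksize x out_h*out_w):

```c
/* Hypothetical, simplified im2col (no padding).
 * image : channels x height x width, row-major
 * col   : (channels*ksize*ksize) x (out_h*out_w), row-major
 * out_h = (height - ksize) / stride + 1, out_w likewise. */
void im2col_simple(const float *image, int channels, int height, int width,
                   int ksize, int stride, float *col)
{
    int out_h = (height - ksize) / stride + 1;
    int out_w = (width - ksize) / stride + 1;
    int out_size = out_h * out_w;
    for (int c = 0; c < channels; ++c) {
        for (int ky = 0; ky < ksize; ++ky) {
            for (int kx = 0; kx < ksize; ++kx) {
                /* one row of col per kernel element (c, ky, kx) */
                int row = (c * ksize + ky) * ksize + kx;
                for (int oy = 0; oy < out_h; ++oy) {
                    for (int ox = 0; ox < out_w; ++ox) {
                        int iy = oy * stride + ky;   /* input pixel under this kernel element */
                        int ix = ox * stride + kx;
                        col[row * out_size + oy * out_w + ox] =
                            image[(c * height + iy) * width + ix];
                    }
                }
            }
        }
    }
}
```

For a 4 x 4 single-channel image holding 0..15 with ksize = 2 and stride = 2 (the non-overlapping case in the figure), the result is 4 x 4 and its first row is {0, 2, 8, 10}. Multiplying by a 1 x 4 filter of all ones (a GEMM with M = 1, K = 4, N = 4) yields the 2 x 2 block sums {10, 18, 42, 50}.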


gemm.c

gemm

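gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);

This example call uses no transposes (TA = TB = 0), ALPHA = 1, and BETA = 1, so it computes C = A * B + C, where A is m x k (lda = k), B is k x n (ldb = n), and C is m x n (ldc = n).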

void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    gemm_cpu(TA,  TB,  M, N, K, ALPHA, A, lda, B, ldb, BETA, C, ldc);
}

Function name: gemm

Input:

  • int TA: whether matrix A is transposed (0: not transposed, 1: transposed)

  • int TB: whether matrix B is transposed (0: not transposed, 1: transposed)

  • int M: number of rows of matrix C

  • int N: number of columns of matrix C

  • int K: number of columns of A after the optional transpose, which must equal the number of rows of B after the optional transpose

  • float ALPHA: scalar that multiplies the product of A and B

  • float *A: pointer to matrix A

  • int lda: leading dimension of A, i.e. its row stride: element (r, c) of A as stored is at A[r*lda + c]

  • float *B: pointer to matrix B

  • int ldb: leading dimension (row stride) of B

  • float BETA: scalar that multiplies the existing contents of C

  • float *C: pointer to matrix C

  • int ldc: leading dimension (row stride) of C
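To illustrate what the leading dimension makes possible, here is a hedged sketch: a naive non-transposed kernel with the same indexing as gemm_nn (the name gemm_nn_ref is hypothetical, and OpenMP is omitted). Because element (r, c) is read as p[r*ld + c], passing a pointer into a larger matrix together with that matrix's row stride multiplies a sub-block without copying it:

```c
/* Element (r, c) of a row-major matrix whose rows are ld elements apart. */
static float at(const float *p, int ld, int r, int c) { return p[r * ld + c]; }

/* Hypothetical naive kernel: C += ALPHA * A * B, same indexing as gemm_nn. */
void gemm_nn_ref(int M, int N, int K, float ALPHA,
                 const float *A, int lda, const float *B, int ldb,
                 float *C, int ldc)
{
    for (int i = 0; i < M; ++i) {
        for (int k = 0; k < K; ++k) {
            float a = ALPHA * at(A, lda, i, k);
            for (int j = 0; j < N; ++j) {
                C[i * ldc + j] += a * at(B, ldb, k, j);
            }
        }
    }
}
```

For the 3 x 3 matrix big = {1, ..., 9}, calling gemm_nn_ref(2, 2, 2, 1.0f, big, 3, big + 4, 3, C, 2) on a zeroed C multiplies the top-left block {{1, 2}, {4, 5}} by the bottom-right block {{5, 6}, {8, 9}} and leaves C = {21, 24, 60, 69}.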

Operation:

  • Performs a matrix-matrix multiplication.

Description:

  • This function performs matrix-matrix multiplication on the CPU.

  • It does so by passing all of its arguments unchanged to gemm_cpu.

  • It takes the sizes and transpose flags of A and B and the scalars ALPHA and BETA. It returns nothing (void); the result is written into C in place as C = ALPHA * op(A) * op(B) + BETA * C, where op(X) is X or its transpose depending on the corresponding flag.
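As a sketch of the full contract, here is a naive reference version (gemm_ref is a hypothetical name, not part of darknet) that evaluates C = ALPHA * op(A) * op(B) + BETA * C for all four transpose combinations, using the same row-major indexing as the four kernels below. It applies ALPHA once per element instead of once per term, so its results can differ from gemm_cpu in the last bits of rounding:

```c
/* Hypothetical reference for the gemm contract:
 *   C = ALPHA * op(A) * op(B) + BETA * C
 * op(X) = X when the flag is 0, X^T when it is 1. All matrices are row-major.
 * op(A) is M x K, op(B) is K x N, C is M x N. */
void gemm_ref(int TA, int TB, int M, int N, int K, float ALPHA,
              const float *A, int lda, const float *B, int ldb,
              float BETA, float *C, int ldc)
{
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                float a = TA ? A[k * lda + i] : A[i * lda + k];  /* op(A)[i][k] */
                float b = TB ? B[j * ldb + k] : B[k * ldb + j];  /* op(B)[k][j] */
                sum += a * b;
            }
            C[i * ldc + j] = ALPHA * sum + BETA * C[i * ldc + j];
        }
    }
}
```

Mirroring the call gemm(0,0,m,n,k,1,a,k,b,n,1,c,n) with m = 2, n = 2, k = 3, a = {1, 2, 3, 4, 5, 6}, b = {1, 0, 0, 1, 1, 1} and c initialized to ones gives c = {5, 6, 11, 12}.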

gemm_cpu

void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    //printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
    int i, j;
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            C[i*ldc + j] *= BETA;
        }
    }
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
    else
        gemm_tt(M, N, K, ALPHA, A,lda, B, ldb, C, ldc);
}

Function name: gemm_cpu

Input:

  • int TA: flag indicating whether matrix A is transposed

  • int TB: flag indicating whether matrix B is transposed

  • int M: number of rows of C

  • int N: number of columns of C

  • int K: size of the inner dimension shared by A and B

  • float ALPHA: weight applied to the product of A and B

  • float *A: pointer to matrix A

  • int lda: number of elements per row of A as stored (row stride)

  • float *B: pointer to matrix B

  • int ldb: number of elements per row of B as stored (row stride)

  • float BETA: weight applied to the existing contents of C

  • float *C: pointer to matrix C

  • int ldc: number of elements per row of C (row stride)

Operation:

  • Performs matrix multiplication on the CPU.

  • Takes the three matrices A, B, and C, first multiplies every element of C by BETA, and then adds the product of A and B, weighted by ALPHA, to C: C = ALPHA * op(A) * op(B) + BETA * C.

Description:

  • gemm_cpu is the CPU implementation that gemm forwards to.

  • It receives the pointers A, B, and C together with the size and stride arguments, and stores the result of the multiplication in C in place.

  • It first scales C by BETA in a double loop. It then checks the TA and TB flags to see whether A and B are transposed, and calls one of gemm_nn, gemm_tn, gemm_nt, or gemm_tt accordingly.

  • Each of these kernels handles one transpose combination (nn: neither transposed, tn: A transposed, nt: B transposed, tt: both transposed) and adds ALPHA * op(A) * op(B) to C.

  • BETA controls how much of the previous contents of C is kept: BETA = 1 accumulates into the existing values, and BETA = 0 discards them (as long as they are finite, since the scaling is a multiplication).
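The BETA handling can be isolated in a short sketch that mirrors gemm_cpu's two phases for the non-transposed case only (gemm_two_phase is a hypothetical name): C is first scaled by BETA, then ALPHA * A * B is added, as gemm_nn does.

```c
/* Hypothetical two-phase sketch mirroring gemm_cpu (non-transposed case only):
 *   phase 1: C *= BETA
 *   phase 2: C += ALPHA * A * B */
void gemm_two_phase(int M, int N, int K, float ALPHA,
                    const float *A, int lda, const float *B, int ldb,
                    float BETA, float *C, int ldc)
{
    for (int i = 0; i < M; ++i)              /* phase 1: scale the existing C */
        for (int j = 0; j < N; ++j)
            C[i * ldc + j] *= BETA;

    for (int i = 0; i < M; ++i)              /* phase 2: accumulate the product */
        for (int k = 0; k < K; ++k) {
            float a = ALPHA * A[i * lda + k];
            for (int j = 0; j < N; ++j)
                C[i * ldc + j] += a * B[k * ldb + j];
        }
}
```

With A = {1, 2} (1 x 2), B = {3, 4} (2 x 1, so A * B = 11) and C = {5}, two calls with BETA = 1 give 16 and then 27, while a call with BETA = 0 resets C to 11. One caveat of scaling by multiplication: if C holds NaN or infinity, BETA = 0 does not clear it (0 times NaN or infinity is NaN), whereas the reference BLAS documentation states that C need not be set on input when BETA is zero.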

gemm_nn

void gemm_nn(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    // j and k are declared outside the loop, so they must be private per thread
    #pragma omp parallel for private(j, k)
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            register float A_PART = ALPHA*A[i*lda+k];
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}

Function name: gemm_nn

Input:

  • M: number of rows of A (and of C)

  • N: number of columns of B (and of C)

  • K: number of columns of A, equal to the number of rows of B

  • ALPHA: scalar multiplied into the product of A and B

  • A: the M x K matrix A

  • lda: leading dimension of A

  • B: the K x N matrix B

  • ldb: leading dimension of B

  • C: the M x N matrix C

  • ldc: leading dimension of C

Operation:

  • Performs the General Matrix Multiply (GEMM) operation: multiplies A by B and adds the result into C (C += ALPHA * A * B).

  • Three for loops, in i, k, j order, compute the elements of C.

  • The i loop walks the rows of A, the k loop walks the columns of A together with the rows of B, and the j loop walks the columns of B, accumulating into row i of C.

  • Parallelized with OpenMP.

Description:

  • GEMM is one of the most heavily used operations in artificial neural networks.

  • GEMM can be computed with several loop orders. This function uses the i-k-j order: ALPHA * A[i][k] is hoisted into the register variable A_PART, and the innermost j loop then reads row k of B and updates row i of C contiguously (stride 1), which is cache-friendly for row-major data.

  • OpenMP is an API (compiler directives plus a runtime library) for parallel processing on multi-core CPUs. Here #pragma omp parallel for splits the iterations of the outer i loop across threads; each iteration writes only row i of C, so different threads never write the same element, provided the inner counters j and k are private to each thread. The pragma takes effect only when the code is compiled with OpenMP enabled (for example -fopenmp in GCC); otherwise it is ignored and the loop runs serially.
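To show that the i-k-j order is only a reordering of the textbook i-j-k loop, here are both versions side by side (matmul_ijk and matmul_ikj are hypothetical names; ALPHA and OpenMP are left out). Mathematically they produce the same C; in floating point the additions are grouped differently, so the last bits can differ:

```c
/* i-j-k order: the inner loop walks B down a column (stride ldb). */
void matmul_ijk(int M, int N, int K, const float *A, int lda,
                const float *B, int ldb, float *C, int ldc)
{
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k)
                sum += A[i * lda + k] * B[k * ldb + j];
            C[i * ldc + j] += sum;
        }
}

/* i-k-j order (as in gemm_nn): the inner loop walks B and C along a row (stride 1). */
void matmul_ikj(int M, int N, int K, const float *A, int lda,
                const float *B, int ldb, float *C, int ldc)
{
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k) {
            float a = A[i * lda + k];            /* hoisted, like A_PART */
            for (int j = 0; j < N; ++j)
                C[i * ldc + j] += a * B[k * ldb + j];
        }
}
```

For A = {{1, 2, 3}, {4, 5, 6}} and B = {{7, 8}, {9, 10}, {11, 12}}, both add {58, 64, 139, 154} to a zeroed C.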

gemm_nt

void gemm_nt(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    // j and k are declared outside the loop, so they must be private per thread
    #pragma omp parallel for private(j, k)
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            register float sum = 0;
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i*lda+k]*B[j*ldb + k];
            }
            C[i*ldc+j] += sum;
        }
    }
}

Function name: gemm_nt

Input:

  • M: number of rows of A (and of C)

  • N: number of columns of C; B is stored as N x K, so this is also the number of rows of the stored B

  • K: number of columns of A, which is also the number of columns of the stored B

  • ALPHA: scalar multiplied into the product of A and B^T

  • *A: pointer to matrix A

  • lda: number of elements per row of A (row stride)

  • *B: pointer to matrix B

  • ldb: number of elements per row of B as stored (row stride)

  • *C: pointer to matrix C

  • ldc: number of elements per row of C (row stride)

Operation:

  • Multiplies A by the transpose of B and adds the result to C: C += ALPHA * A * B^T.

  • A is used as is, while B is used in transposed form.

  • The dot product of row i of A with row j of the stored B (that is, column j of B^T) is accumulated into element (i, j) of C.

Description:

  • This function handles the case where B is transposed (TB = 1): it multiplies A by B^T and adds the result to C.

  • OpenMP parallelizes the outer i loop, and three nested for loops over i, j, and k perform the element-wise multiply and add. Because k is the innermost index, both A[i*lda+k] and B[j*ldb+k] are read with stride 1, and the dot product is collected in the local variable sum before a single update of C.
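A hedged check of the nt indexing: the sketch below (transpose and gemm_nt_ref are hypothetical helpers) builds the N x K storage of B^T explicitly and then multiplies with gemm_nt's indexing, which should reproduce the ordinary product A * B:

```c
/* Hypothetical helper: write the rows x cols row-major matrix src
 * into dst as its cols x rows transpose. */
void transpose(int rows, int cols, const float *src, float *dst)
{
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            dst[c * rows + r] = src[r * cols + c];
}

/* Hypothetical copy of gemm_nt's indexing: B is stored as N x K (it holds B^T). */
void gemm_nt_ref(int M, int N, int K, float ALPHA,
                 const float *A, int lda, const float *B, int ldb,
                 float *C, int ldc)
{
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k)   /* row i of A dot row j of stored B */
                sum += ALPHA * A[i * lda + k] * B[j * ldb + k];
            C[i * ldc + j] += sum;
        }
}
```

For A = {{1, 2, 3}, {4, 5, 6}} and B = {{7, 8}, {9, 10}, {11, 12}}, transposing B gives Bt = {{7, 9, 11}, {8, 10, 12}}, and gemm_nt_ref(2, 2, 3, 1.0f, A, 3, Bt, 3, C, 2) on a zeroed C yields {58, 64, 139, 154}, the same as A * B.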

gemm_tn

void gemm_tn(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    // j and k are declared outside the loop, so they must be private per thread
    #pragma omp parallel for private(j, k)
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            register float A_PART = ALPHA*A[k*lda+i];
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}

Function name: gemm_tn

Input:

  • int M: number of rows of C; A is stored as K x M, so this is the number of columns of the stored A

  • int N: number of columns of B (and of C)

  • int K: number of rows of the stored A, equal to the number of rows of B

  • float ALPHA: scalar multiplied into the product

  • float *A: data pointer of matrix A

  • int lda: row stride of A as stored

  • float *B: data pointer of matrix B

  • int ldb: row stride of B

  • float *C: data pointer of the output matrix C

  • int ldc: row stride of the output matrix C

Operation:

  • Multiplies the transpose of A (A^T) by B, scales the product by ALPHA, and adds it to the output matrix C: C += ALPHA * A^T * B. Only A is transposed; B is used as is.

Description:

  • This function handles the case where only A is transposed (TA = 1, TB = 0). Element A^T[i][k] is read as A[k*lda+i].

  • The product, multiplied by ALPHA, is added to the output matrix C.

  • Like gemm_nn, it uses the i-k-j loop order with ALPHA * A^T[i][k] hoisted into A_PART, and OpenMP parallelizes the outer i loop.
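For gemm_tn, a minimal sketch (gemm_tn_ref is a hypothetical name) with the same A[k*lda+i] indexing: A is stored as K x M, so passing the stored transpose of a matrix yields the ordinary product of that matrix with B:

```c
/* Hypothetical copy of gemm_tn's indexing:
 * A is stored as K x M, and A[k*lda + i] is A^T[i][k]. */
void gemm_tn_ref(int M, int N, int K, float ALPHA,
                 const float *A, int lda, const float *B, int ldb,
                 float *C, int ldc)
{
    for (int i = 0; i < M; ++i)
        for (int k = 0; k < K; ++k) {
            float a_part = ALPHA * A[k * lda + i];   /* A^T[i][k] */
            for (int j = 0; j < N; ++j)
                C[i * ldc + j] += a_part * B[k * ldb + j];
        }
}
```

Storing A = {{1, 2, 3}, {4, 5, 6}} transposed as At = {1, 4, 2, 5, 3, 6} (3 x 2, lda = 2) and multiplying by B = {{7, 8}, {9, 10}, {11, 12}} adds {58, 64, 139, 154} to a zeroed C.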

gemm_tt

void gemm_tt(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    // j and k are declared outside the loop, so they must be private per thread
    #pragma omp parallel for private(j, k)
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            register float sum = 0;
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
            }
            C[i*ldc+j] += sum;
        }
    }
}

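Function name: gemm_tt

Input:

  • int M: number of rows of C (the number of columns of A as stored)

  • int N: number of columns of C (the number of rows of B as stored)

  • int K: inner dimension: the number of columns of A^T and rows of B^T (rows of the stored A, columns of the stored B)

  • float ALPHA: scalar multiplied into the product

  • float *A: matrix A, stored as K x M so that A^T is M x K

  • int lda: number of elements per row of A as stored (at least M)

  • float *B: matrix B, stored as N x K so that B^T is K x N

  • int ldb: number of elements per row of B as stored (at least K)

  • float *C: the M x N matrix C

  • int ldc: number of elements per row of C (at least N)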

Operation:

  • Accumulates the product of the two matrices A and B into C.

  • Both A and B are assumed to be stored transposed, so the kernel computes C += ALPHA * A^T * B^T.

  • Parallelized with OpenMP.

Description:

  • Matrix multiplication is generally not commutative: A x B differs from B x A. Because gemm_tt multiplies with both A and B transposed, its result A^T x B^T equals (B x A)^T.

  • The triple for loop runs in i, j, k order and accesses C[i*ldc+j], A[i+k*lda], and B[k+j*ldb].

  • For each C[i*ldc+j], the running total is accumulated in the variable sum and then added to C once.

  • OpenMP parallelization of the outer i loop improves performance on multi-core CPUs.
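A numerical check of the identity A^T x B^T = (B x A)^T, using a sketch with gemm_tt's indexing (gemm_tt_ref is a hypothetical name). Here the stored arrays are P (K x M) and Q (N x K), so the kernel computes P^T * Q^T:

```c
/* Hypothetical copy of gemm_tt's indexing:
 * A is stored as K x M and B as N x K, so
 * A[i + k*lda] is A^T[i][k] and B[k + j*ldb] is B^T[k][j]. */
void gemm_tt_ref(int M, int N, int K, float ALPHA,
                 const float *A, int lda, const float *B, int ldb,
                 float *C, int ldc)
{
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k)
                sum += ALPHA * A[i + k * lda] * B[k + j * ldb];
            C[i * ldc + j] += sum;
        }
}
```

With P = {1, 4, 2, 5, 3, 6} (3 x 2, lda = 2) and Q = {7, 9, 11, 8, 10, 12} (2 x 3, ldb = 3), gemm_tt_ref adds {58, 64, 139, 154} to a zeroed C, and the ordinary product Q * P = {58, 139, 64, 154} is exactly its transpose.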
