gemm
Last updated
Was this helpful?
Last updated
Was this helpful?
μ°Έκ³ μλ£ :
General Matrix to Matrix Multiplication
1979λ μ λ§λ€μ΄μ§ BLAS λΌμ΄λΈλ¬λ¦¬μ μΌλΆ μ λλ€.
λκ°μ μ λ ₯ νλ ¬μ κ³±ν΄μ μΆλ ₯μ μ»λ λ°©λ² μ λλ€.
λ₯λ¬λμμ λλΆλΆμ μ°μ°μ output = input * weight + bias
λ‘ ννμ΄ λ©λλ€. μ¬κΈ°μ input
, output
, weight
λ₯Ό νλ ¬λ‘ ννν΄μ GEMMμ μ¬μ©ν΄ μ°μ°ν μ μμ΅λλ€.
fully connected layer
λ μμ κ°μ΄ ννν μ μμ΅λλ€.
im2col
: 3μ°¨μ μ΄λ―Έμ§ λ°°μ΄μ 2μ°¨μ λ°°μ΄λ‘ λ³νν©λλ€.
convolutional layer
λ μμ κ°μ΄ ννν μ μμ΅λλ€. μ κ·Έλ¦Όμ κ²½μ°λ stride
κ° kernel size
μ κ°μ κ²½μ°λ₯Ό μλ―Έν©λλ€.
gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
ν¨μ μ΄λ¦: gemm
μ λ ₯:
int TA: νλ ¬ Aμ μ μΉ μ¬λΆ (0: μ μΉνμ§ μμ, 1: μ μΉν¨)
int TB: νλ ¬ Bμ μ μΉ μ¬λΆ (0: μ μΉνμ§ μμ, 1: μ μΉν¨)
int M: νλ ¬ Cμ νμ μ
int N: νλ ¬ Cμ μ΄μ μ
int K: νλ ¬ Aμ μ΄μ μ (νλ ¬ Bμ νμ μμ κ°μμΌ ν¨)
float ALPHA: μ€μΉΌλΌ κ°
float *A: νλ ¬ Aμ ν¬μΈν°
int lda: νλ ¬ Aμ ν λ¨μ ν¬κΈ°
float *B: νλ ¬ Bμ ν¬μΈν°
int ldb: νλ ¬ Bμ ν λ¨μ ν¬κΈ°
float BETA: μ€μΉΌλΌ κ°
float *C: νλ ¬ Cμ ν¬μΈν°
int ldc: νλ ¬ Cμ ν λ¨μ ν¬κΈ°
λμ:
νλ ¬-νλ ¬ κ³±μ μ°μ°μ μνν¨.
μ€λͺ :
μ΄ ν¨μλ CPU μμμ νλ ¬-νλ ¬ κ³±μ μ°μ°μ μννλ ν¨μμ΄λ€.
gemm_cpu ν¨μλ₯Ό νΈμΆνμ¬ μ΄ μ°μ°μ μννλ€.
νλ ¬ Aμ νλ ¬ Bμ ν¬κΈ°μ μ μΉ μ¬λΆ, μ€μΉΌλΌ κ° ALPHAμ BETA λ±μ μ λ ₯μΌλ‘ λ°κ³ , μ°μ° κ²°κ³ΌμΈ νλ ¬ Cλ₯Ό μΆλ ₯μΌλ‘ λ°ννλ€.
ν¨μ μ΄λ¦: gemm_cpu
μ λ ₯:
int TA: A νλ ¬μ μ μΉ μ¬λΆλ₯Ό λνλ΄λ νλκ·Έ
int TB: B νλ ¬μ μ μΉ μ¬λΆλ₯Ό λνλ΄λ νλκ·Έ
int M: C νλ ¬μ ν μ
int N: C νλ ¬μ μ΄ μ
int K: A, B νλ ¬μμ 곡μ νλ μ°¨μμ ν¬κΈ°
float ALPHA: A, B νλ ¬μ κ³±μ λν κ°μ€μΉ
float *A: A νλ ¬μ ν¬μΈν°
int lda: A νλ ¬μ ν λΉ μμ μ
float *B: B νλ ¬μ ν¬μΈν°
int ldb: B νλ ¬μ ν λΉ μμ μ
float BETA: C νλ ¬μ λν κ°μ€μΉ
float *C: C νλ ¬μ ν¬μΈν°
int ldc: C νλ ¬μ ν λΉ μμ μ
λμ:
CPUμμ νλ ¬ κ³±μ μ°μ°μ μννλ€.
A, B, C μΈ κ°μ νλ ¬μ μΈμλ‘ λ°κ³ , Aμ Bμ κ³±μ κ°μ€μΉ ALPHAλ₯Ό κ³±ν κ²°κ³Όλ₯Ό C νλ ¬μ λνλ€.
μ€λͺ :
gemm_cpu ν¨μλ CPUμμ νλ ¬ κ³±μ μ°μ°μ μννλ€.
μ΄ ν¨μλ A, B, C μΈ κ°μ ν¬μΈν°μ λ€μν μΈμλ₯Ό λ°μμ, νλ ¬ κ³±μ μ°μ° κ²°κ³Όλ₯Ό C νλ ¬μ μ μ₯νλ€.
ν¨μ λ΄λΆμμλ TAμ TB μΈμλ₯Ό μ¬μ©νμ¬ Aμ B νλ ¬μ΄ μ μΉλμ΄ μλμ§ μ¬λΆλ₯Ό νμΈνκ³ , μ΄μ λ°λΌ gemm_nn, gemm_tn, gemm_nt, gemm_tt ν¨μ μ€ νλλ₯Ό νΈμΆνλ€.
μ΄ ν¨μλ€μ λ€μν νλ ¬ κ³±μ μ°μ° λ°©λ²μ ꡬννκ³ μλ€.
λ°λΌμ gemm_cpu ν¨μλ μ΄λ₯Ό μ΄μ©νμ¬ μ λ ₯μΌλ‘ λ°μ νλ ¬ A, Bμ κ³±μ κ°μ€μΉ ALPHAλ₯Ό κ³±ν κ²°κ³Όλ₯Ό C νλ ¬μ λνλ€.
μ΄ λ BETA μΈμλ₯Ό μ¬μ©νμ¬ κΈ°μ‘΄μ C νλ ¬ κ°μ λν κ°μ€μΉλ₯Ό μ‘°μ ν μ μλ€.
ν¨μ μ΄λ¦: gemm_nn
μ λ ₯:
M: νλ ¬ Aμ νμ κ°μ
N: νλ ¬ Bμ μ΄μ κ°μ
K: νλ ¬ Aμ μ΄μ κ°μ λλ νλ ¬ Bμ νμ κ°μ
ALPHA: νλ ¬ Aμ νλ ¬ Bμ κ³±μ κ²°κ³Όμ κ³±ν΄μ§λ μ€μΉΌλΌ κ°
A: ν¬κΈ° M x Kμ νλ ¬ A
lda: νλ ¬ Aμ leading dimension
B: ν¬κΈ° K x Nμ νλ ¬ B
ldb: νλ ¬ Bμ leading dimension
C: ν¬κΈ° M x Nμ νλ ¬ C
λμ:
νλ ¬ Aμ Bλ₯Ό κ³±νμ¬ νλ ¬ Cλ₯Ό κ³μ°νλ General Matrix Multiply(GEMM) μ°μ°μ μννλ€.
i, k, j μΈ κ°μ for 루νλ₯Ό μ¬μ©νμ¬ νλ ¬ Cμ κ° μμλ₯Ό κ³μ°νλ€.
i 루νμμλ νλ ¬ Aμ κ° νμ μννλ©°, k 루νμμλ νλ ¬ Aμ κ° μ΄κ³Ό νλ ¬ Bμ κ° νμ μννλ©°, j 루νμμλ νλ ¬ Bμ κ° μ΄μ μννλ©° νλ ¬ Cμ κ° μμλ₯Ό κ³μ°νλ€.
OpenMPλ₯Ό μ¬μ©νμ¬ λ³λ ¬ μ²λ¦¬νλ€.
μ€λͺ :
General Matrix Multiply(GEMM) μ°μ°μ μΈκ³΅ μ κ²½λ§μμ κ°μ₯ λ§μ΄ μ¬μ©λλ μ°μ° μ€ νλμ΄λ€.
GEMM μ°μ°μ μννλ λ°©λ²μ μ¬λ¬ κ°μ§κ° μμΌλ©°, μ΄ ν¨μμμλ A νλ ¬μ μννλ©΄μ Aμ Bμ κ³±μ κ³μ°νλ€.
OpenMPλ λ©ν°μ½μ΄ CPUμμ λ³λ ¬ μ²λ¦¬λ₯Ό μνν μ μλ λΌμ΄λΈλ¬λ¦¬λ‘, μ΄λ₯Ό μ¬μ©νμ¬ μ±λ₯μ ν₯μμν¨λ€.
ν¨μ μ΄λ¦: gemm_nt
μ λ ₯:
M: A νλ ¬μ ν κ°μ
N: B νλ ¬μ μ΄ κ°μ
K: A νλ ¬μ μ΄ κ°μ (λμμ B νλ ¬μ ν κ°μ)
ALPHA: Aμ B νλ ¬μ κ³±μ κ²°κ³Όμ κ³±ν΄μ§ μ€μΉΌλΌ κ°
*A: A νλ ¬μ ν¬μΈν°
lda: A νλ ¬μ ν λΉ μμ κ°μ
*B: B νλ ¬μ ν¬μΈν°
ldb: B νλ ¬μ ν λΉ μμ κ°μ
*C: C νλ ¬μ ν¬μΈν°
ldc: C νλ ¬μ ν λΉ μμ κ°μ
λμ:
νλ ¬ Aμ Bλ₯Ό κ³±ν ν, C νλ ¬μ λν΄μ£Όλ μ°μ°μ μννλ€.
Aμ B νλ ¬μ κ³±νκΈ° μν΄ Aλ κ·Έλλ‘, Bλ μ μΉ(transpose)λ ννλ‘ μ¬μ©λλ€.
Aμ iλ²μ§Έ νκ³Ό Bμ jλ²μ§Έ μ΄μ κ³±ν κ°μ Cμ iλ²μ§Έ ν jλ²μ§Έ μ΄μ λμ νμ¬ λν΄μ€λ€.
μ€λͺ :
μ΄ ν¨μλ B νλ ¬μ΄ μ μΉλ ννλ‘ μ λ ₯μΌλ‘ λ€μ΄μ¬ λ Aμ Bλ₯Ό κ³±ν ν C νλ ¬μ λν΄μ£Όλ μ°μ°μ μννλ€.
ν¨μ λ΄λΆμμλ OpenMPλ₯Ό μ΄μ©νμ¬ λ³λ ¬ μ²λ¦¬λ₯Ό μννλ©°, i, j, k μΈ κ°μ for 루νλ₯Ό μ΄μ©νμ¬ νλ ¬μ μμ κ³±μ λ° λ§μ μ°μ°μ μννλ€.
ν¨μ μ΄λ¦: gemm_tn
μ λ ₯:
int M: νλ ¬ Aμ νμ μ
int N: νλ ¬ Bμ μ΄μ μ
int K: νλ ¬ Aμ μ΄μ μ λλ νλ ¬ Bμ νμ μ
float ALPHA: κ³±ν΄μ§λ μμ
float *A: νλ ¬ Aμ λ°μ΄ν° ν¬μΈν°
int lda: νλ ¬ Aμ ν κ°κ²©
float *B: νλ ¬ Bμ λ°μ΄ν° ν¬μΈν°
int ldb: νλ ¬ Bμ ν κ°κ²©
float *C: μΆλ ₯ νλ ¬ Cμ λ°μ΄ν° ν¬μΈν°
int ldc: μΆλ ₯ νλ ¬ Cμ ν κ°κ²©
λμ:
νλ ¬ Aμ Bμ μ μΉνλ ¬μΈ ATμ BTλ₯Ό κ³±νκ³ ALPHAλ₯Ό κ³±ν κ°μ μΆλ ₯ νλ ¬ Cμ λνλ€.
μ€λͺ :
μ΄ ν¨μλ νλ ¬ Aμ Bμ μ μΉνλ ¬μΈ ATμ BTλ₯Ό κ³±ν κ²°κ³Όλ₯Ό μΆλ ₯νλ ν¨μμ΄λ€.
μ΄λ ALPHAλ₯Ό κ³±ν κ°μ΄ μΆλ ₯ νλ ¬ Cμ λν΄μ§λ€.
λ΄λΆμ μΌλ‘λ OpenMPλ₯Ό μ¬μ©νμ¬ λ³λ ¬ μ²λ¦¬λ₯Ό μννλ€.
ν¨μ μ΄λ¦: gemm_tt
μ λ ₯:
int M: νλ ¬ Cμ ν κ°μ
int N: νλ ¬ Cμ μ΄ κ°μ
int K: νλ ¬ Aμ μ΄ κ°μ (νλ ¬ Bμ ν κ°μ)
float ALPHA: μ€μΉΌλΌ κ°
float *A: M x K ν¬κΈ°μ νλ ¬ A
int lda: νλ ¬ Aμ μ΄ κ°μ
float *B: K x N ν¬κΈ°μ νλ ¬ B
int ldb: νλ ¬ Bμ μ΄ κ°μ
float *C: M x N ν¬κΈ°μ νλ ¬ C
int ldc: νλ ¬ Cμ μ΄ κ°μ
λμ:
λ κ°μ νλ ¬ Aμ Bλ₯Ό κ³±ν κ²°κ³Όλ₯Ό νλ ¬ Cμ λμ νλ€.
Aμ Bλ μ μΉ(transpose)λμ΄ μλ€κ³ κ°μ νλ€.
OpenMPλ₯Ό μ¬μ©νμ¬ λ³λ ¬μ²λ¦¬νλ€.
μ€λͺ :
μΌλ°μ μΌλ‘ νλ ¬ κ³±μ μ°μ°μμλ A x Bμ B x Aλ λ€λ₯΄λ€. κ·Έλ¬λ gemm_tt ν¨μμμλ Aμ B λͺ¨λ μ μΉλ μνμμ κ³±μ μ μννκΈ° λλ¬Έμ A^T x B^T = (B x A)^Tμ κ°μ κ²°κ³Όλ₯Ό μ»λλ€.
i, j, kμ μμλ‘ 3μ€ for 루νλ₯Ό μννλ©°, κ°κ° C[ildc+j], A[i+klda], B[k+j*ldb]μ κ°μ μ°Έμ‘°νλ€.
κ°κ°μ C[i*ldc+j]μ κ°μ κ³μ°νκΈ° μν΄ sum λ³μλ₯Ό μ¬μ©νμ¬ λμ ν©μ κ³μ°νλ€.
OpenMPλ₯Ό μ¬μ©νμ¬ λ³λ ¬μ²λ¦¬νμ¬ μ±λ₯μ ν₯μμν¨λ€.