gemm
GEMM 이란?
참고 자료 : https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/
General Matrix to Matrix Multiplication
1979년에 만들어진 BLAS 라이브러리의 일부 입니다.
두개의 입력 행렬을 곱해서 출력을 얻는 방법 입니다.
딥러닝에서 대부분의 연산은 output = input * weight + bias
로 표현이 됩니다. 여기서 input
, output
, weight
를 행렬로 표현해서 GEMM을 사용해 연산할 수 있습니다.
Fully Connected Layer
fully connected layer
는 위와 같이 표현할 수 있습니다.
Convolutional Layer
im2col
: 3차원 이미지 배열을 2차원 배열로 변환합니다.
convolutional layer
는 위와 같이 표현할 수 있습니다. 위 그림의 경우는 stride
가 kernel size
와 같은 경우를 의미합니다.
gemm.c
gemm
gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
함수 이름: gemm
입력:
int TA: 행렬 A의 전치 여부 (0: 전치하지 않음, 1: 전치함)
int TB: 행렬 B의 전치 여부 (0: 전치하지 않음, 1: 전치함)
int M: 행렬 C의 행의 수
int N: 행렬 C의 열의 수
int K: 행렬 A의 열의 수 (행렬 B의 행의 수와 같아야 함)
float ALPHA: 스칼라 값
float *A: 행렬 A의 포인터
int lda: 행렬 A의 행 단위 크기
float *B: 행렬 B의 포인터
int ldb: 행렬 B의 행 단위 크기
float BETA: 스칼라 값
float *C: 행렬 C의 포인터
int ldc: 행렬 C의 행 단위 크기
동작:
행렬-행렬 곱셈 연산을 수행함.
설명:
이 함수는 CPU 상에서 행렬-행렬 곱셈 연산을 수행하는 함수이다.
gemm_cpu 함수를 호출하여 이 연산을 수행한다.
행렬 A와 행렬 B의 크기와 전치 여부, 스칼라 값 ALPHA와 BETA 등을 입력으로 받고, 연산 결과인 행렬 C를 출력으로 반환한다.
gemm_cpu
함수 이름: gemm_cpu
입력:
int TA: A 행렬의 전치 여부를 나타내는 플래그
int TB: B 행렬의 전치 여부를 나타내는 플래그
int M: C 행렬의 행 수
int N: C 행렬의 열 수
int K: A, B 행렬에서 공유하는 차원의 크기
float ALPHA: A, B 행렬의 곱에 대한 가중치
float *A: A 행렬의 포인터
int lda: A 행렬의 행 당 원소 수
float *B: B 행렬의 포인터
int ldb: B 행렬의 행 당 원소 수
float BETA: C 행렬에 대한 가중치
float *C: C 행렬의 포인터
int ldc: C 행렬의 행 당 원소 수
동작:
CPU에서 행렬 곱셈 연산을 수행한다.
A, B, C 세 개의 행렬을 인자로 받고, A와 B의 곱에 가중치 ALPHA를 곱한 결과를 C 행렬에 더한다.
설명:
gemm_cpu 함수는 CPU에서 행렬 곱셈 연산을 수행한다.
이 함수는 A, B, C 세 개의 포인터와 다양한 인자를 받아서, 행렬 곱셈 연산 결과를 C 행렬에 저장한다.
함수 내부에서는 TA와 TB 인자를 사용하여 A와 B 행렬이 전치되어 있는지 여부를 확인하고, 이에 따라 gemm_nn, gemm_tn, gemm_nt, gemm_tt 함수 중 하나를 호출한다.
이 함수들은 다양한 행렬 곱셈 연산 방법을 구현하고 있다.
따라서 gemm_cpu 함수는 이를 이용하여 입력으로 받은 행렬 A, B의 곱에 가중치 ALPHA를 곱한 결과를 C 행렬에 더한다.
이 때 BETA 인자를 사용하여 기존의 C 행렬 값에 대한 가중치를 조절할 수 있다.
gemm_nn
함수 이름: gemm_nn
입력:
M: 행렬 A의 행의 개수
N: 행렬 B의 열의 개수
K: 행렬 A의 열의 개수 또는 행렬 B의 행의 개수
ALPHA: 행렬 A와 행렬 B의 곱셈 결과에 곱해지는 스칼라 값
A: 크기 M x K의 행렬 A
lda: 행렬 A의 leading dimension
B: 크기 K x N의 행렬 B
ldb: 행렬 B의 leading dimension
C: 크기 M x N의 행렬 C
동작:
행렬 A와 B를 곱하여 행렬 C를 계산하는 General Matrix Multiply(GEMM) 연산을 수행한다.
i, k, j 세 개의 for 루프를 사용하여 행렬 C의 각 요소를 계산한다.
i 루프에서는 행렬 A의 각 행을 순회하며, k 루프에서는 행렬 A의 각 열과 행렬 B의 각 행을 순회하며, j 루프에서는 행렬 B의 각 열을 순회하며 행렬 C의 각 요소를 계산한다.
OpenMP를 사용하여 병렬 처리한다.
설명:
General Matrix Multiply(GEMM) 연산은 인공 신경망에서 가장 많이 사용되는 연산 중 하나이다.
GEMM 연산을 수행하는 방법은 여러 가지가 있으며, 이 함수에서는 A 행렬을 순회하면서 A와 B의 곱을 계산한다.
OpenMP는 멀티코어 CPU에서 병렬 처리를 수행할 수 있는 라이브러리로, 이를 사용하여 성능을 향상시킨다.
gemm_nt
함수 이름: gemm_nt
입력:
M: A 행렬의 행 개수
N: B 행렬의 열 개수
K: A 행렬의 열 개수 (동시에 B 행렬의 행 개수)
ALPHA: A와 B 행렬의 곱셈 결과에 곱해질 스칼라 값
*A: A 행렬의 포인터
lda: A 행렬의 행 당 원소 개수
*B: B 행렬의 포인터
ldb: B 행렬의 행 당 원소 개수
*C: C 행렬의 포인터
ldc: C 행렬의 행 당 원소 개수
동작:
행렬 A와 B를 곱한 후, C 행렬에 더해주는 연산을 수행한다.
A와 B 행렬을 곱하기 위해 A는 그대로, B는 전치(transpose)된 형태로 사용된다.
A의 i번째 행과 B의 j번째 열을 곱한 값을 C의 i번째 행 j번째 열에 누적하여 더해준다.
설명:
이 함수는 B 행렬이 전치된 형태로 입력으로 들어올 때 A와 B를 곱한 후 C 행렬에 더해주는 연산을 수행한다.
함수 내부에서는 OpenMP를 이용하여 병렬 처리를 수행하며, i, j, k 세 개의 for 루프를 이용하여 행렬의 원소 곱셈 및 덧셈 연산을 수행한다.
gemm_tn
함수 이름: gemm_tn
입력:
int M: 행렬 A의 행의 수
int N: 행렬 B의 열의 수
int K: 행렬 A의 열의 수 또는 행렬 B의 행의 수
float ALPHA: 곱해지는 상수
float *A: 행렬 A의 데이터 포인터
int lda: 행렬 A의 행 간격
float *B: 행렬 B의 데이터 포인터
int ldb: 행렬 B의 행 간격
float *C: 출력 행렬 C의 데이터 포인터
int ldc: 출력 행렬 C의 행 간격
동작:
행렬 A와 B의 전치행렬인 AT와 BT를 곱하고 ALPHA를 곱한 값을 출력 행렬 C에 더한다.
설명:
이 함수는 행렬 A와 B의 전치행렬인 AT와 BT를 곱한 결과를 출력하는 함수이다.
이때 ALPHA를 곱한 값이 출력 행렬 C에 더해진다.
내부적으로는 OpenMP를 사용하여 병렬 처리를 수행한다.
gemm_tt
함수 이름: gemm_tt
입력:
int M: 행렬 C의 행 개수
int N: 행렬 C의 열 개수
int K: 행렬 A의 열 개수 (행렬 B의 행 개수)
float ALPHA: 스칼라 값
float *A: M x K 크기의 행렬 A
int lda: 행렬 A의 열 개수
float *B: K x N 크기의 행렬 B
int ldb: 행렬 B의 열 개수
float *C: M x N 크기의 행렬 C
int ldc: 행렬 C의 열 개수
동작:
두 개의 행렬 A와 B를 곱한 결과를 행렬 C에 누적한다.
A와 B는 전치(transpose)되어 있다고 가정한다.
OpenMP를 사용하여 병렬처리한다.
설명:
일반적으로 행렬 곱셈 연산에서는 A x B와 B x A는 다르다. 그러나 gemm_tt 함수에서는 A와 B 모두 전치된 상태에서 곱셈을 수행하기 때문에 A^T x B^T = (B x A)^T와 같은 결과를 얻는다.
i, j, k의 순서로 3중 for 루프를 수행하며, 각각 C[ildc+j], A[i+klda], B[k+j*ldb]의 값을 참조한다.
각각의 C[i*ldc+j]의 값을 계산하기 위해 sum 변수를 사용하여 누적 합을 계산한다.
OpenMP를 사용하여 병렬처리하여 성능을 향상시킨다.
Last updated