Created
September 28, 2022 16:18
-
-
Save bjourne/ce1e189b926d3aa3ff8772e9f5252cc2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Notes: | |
// | |
// * unsigned int vs. int: makes a small difference for clang but | |
// probably not for gcc. | |
// * best tiling appears to be 256x256x256. | |
// | |
// 12.31 for two 8192 matrices | |
// | |
// | |
#include <assert.h> | |
#include <math.h> | |
#include <pthread.h> | |
#include <stdbool.h> | |
#include <stdlib.h> | |
#include <stdio.h> | |
#include <time.h> | |
#include <xmmintrin.h> | |
/*
 * Problem size. The product computed is
 *   C (C_ROWS x C_COLS) = A (A_ROWS x A_COLS) * B (B_ROWS x B_COLS),
 * with all matrices stored row-major as float.
 */
#define A_ROWS 101
#define A_COLS 64
#define B_ROWS A_COLS
#define B_COLS 64
#define C_ROWS A_ROWS
#define C_COLS B_COLS

/* Storage size, in bytes, of each matrix. */
#define A_N_BYTES (A_ROWS * A_COLS * sizeof(float))
#define B_N_BYTES (B_ROWS * B_COLS * sizeof(float))
#define C_N_BYTES (C_ROWS * C_COLS * sizeof(float))

/*
 * Blocking factors along i (rows of A and C), j (columns of B and C) and
 * k (the shared dimension). Each may be overridden on the compiler
 * command line with -DTILE_x=n.
 */
#ifndef TILE_I
#define TILE_I 32
#endif
#ifndef TILE_J
#define TILE_J 32
#endif
#ifndef TILE_K
#define TILE_K 32
#endif

/* Number of worker threads; may be overridden with -DN_THREADS=n. */
#ifndef N_THREADS
#define N_THREADS 1
#endif

/*
 * Shape of the register block handled per step of the SIMD kernel:
 * SIMD_HEIGHT rows of C by SIMD_WIDTH columns of C.
 */
#define SIMD_HEIGHT 2
#define SIMD_WIDTH 16

/*
 * Smaller of a and b. Each argument is parenthesized wherever it is used so
 * that operands containing low-precedence operators (|, &, ?:, ...) compare
 * as a whole. Arguments are evaluated more than once: do not pass
 * expressions with side effects.
 */
#define MIN(a, b) (((a) > (b)) ? (b) : (a))
/* Unsigned type used for matrix dimensions and loop indices. */
typedef unsigned int uint_t;

/*
 * Reference matrix product C = A * B using the plain triple loop.
 *
 * A is a_rows x a_cols, B is b_rows x b_cols and C is a_rows x b_cols, all
 * row-major. Every element of C is overwritten; A and B are only read.
 * The inner dimension must agree (a_cols == b_rows), which is asserted.
 */
void
mul_slow(float * restrict A,
         float * restrict B,
         float * restrict C,
         uint_t a_rows, uint_t a_cols,
         uint_t b_rows, uint_t b_cols) {
    /* A is strided by a_cols but walked for b_rows steps: they must match. */
    assert(a_cols == b_rows);

    for (uint_t i = 0; i < a_rows; i++) {
        /* Widen before multiplying so the flat index cannot wrap at 32 bits. */
        const float *a_row = &A[(size_t)a_cols * i];
        float *c_row = &C[(size_t)b_cols * i];
        for (uint_t j = 0; j < b_cols; j++) {
            float v = 0;
            for (uint_t k = 0; k < b_rows; k++) {
                v += a_row[k] * B[(size_t)b_cols * k + j];
            }
            c_row[j] = v;
        }
    }
}
/*
 * Write the size x size row-major matrix M to stdout, one matrix row per
 * line with each value rounded to a whole number in a 3-character field,
 * followed by one blank line.
 */
void
print_mat(float *M, int size) {
    for (int row = 0; row < size; row++) {
        const float *cells = M + row * size;
        for (int col = 0; col < size; col++) {
            printf("%3.0f ", cells[col]);
        }
        printf("\n");
    }
    printf("\n");
}
static void | |
mul_fast_tile_16x2( | |
uint_t i0, uint_t i1, | |
uint_t j0, uint_t j1, | |
uint_t k0, uint_t k1, | |
float * restrict Apt, | |
float * restrict Bpt, | |
float * restrict C, | |
uint_t a_rows, uint_t a_cols, | |
uint_t b_rows, uint_t b_cols | |
) { | |
for (uint_t i = i0; i < i1; i += SIMD_HEIGHT) { | |
float * restrict Bptr = Bpt; | |
float * restrict Cptr0 = &C[b_cols * (i + 0) + j0]; | |
float * restrict Cptr1 = &C[b_cols * (i + 1) + j0]; | |
for (uint_t j = j0; j < j1; j += SIMD_WIDTH) { | |
__m128 acc00 = _mm_load_ps(Cptr0 + 0); | |
__m128 acc01 = _mm_load_ps(Cptr0 + 4); | |
__m128 acc02 = _mm_load_ps(Cptr0 + 8); | |
__m128 acc03 = _mm_load_ps(Cptr0 + 12); | |
__m128 acc10 = _mm_load_ps(Cptr1 + 0); | |
__m128 acc11 = _mm_load_ps(Cptr1 + 4); | |
__m128 acc12 = _mm_load_ps(Cptr1 + 8); | |
__m128 acc13 = _mm_load_ps(Cptr1 + 12); | |
float * restrict Aptr = &Apt[a_cols * i + SIMD_HEIGHT * k0]; | |
for (uint_t k = k0; k < k1; k++) { | |
__m128 a0 = _mm_set1_ps(*Aptr++); | |
__m128 a1 = _mm_set1_ps(*Aptr++); | |
__m128 b0 = _mm_load_ps(Bptr + 0); | |
__m128 b1 = _mm_load_ps(Bptr + 4); | |
__m128 b2 = _mm_load_ps(Bptr + 8); | |
__m128 b3 = _mm_load_ps(Bptr + 12); | |
Bptr += SIMD_WIDTH; | |
acc00 = _mm_add_ps(acc00, _mm_mul_ps(a0, b0)); | |
acc01 = _mm_add_ps(acc01, _mm_mul_ps(a0, b1)); | |
acc02 = _mm_add_ps(acc02, _mm_mul_ps(a0, b2)); | |
acc03 = _mm_add_ps(acc03, _mm_mul_ps(a0, b3)); | |
acc10 = _mm_add_ps(acc10, _mm_mul_ps(a1, b0)); | |
acc11 = _mm_add_ps(acc11, _mm_mul_ps(a1, b1)); | |
acc12 = _mm_add_ps(acc12, _mm_mul_ps(a1, b2)); | |
acc13 = _mm_add_ps(acc13, _mm_mul_ps(a1, b3)); | |
} | |
_mm_store_ps(Cptr0 + 0, acc00); | |
_mm_store_ps(Cptr0 + 4, acc01); | |
_mm_store_ps(Cptr0 + 8, acc02); | |
_mm_store_ps(Cptr0 + 12, acc03); | |
_mm_store_ps(Cptr1 + 0, acc10); | |
_mm_store_ps(Cptr1 + 4, acc11); | |
_mm_store_ps(Cptr1 + 8, acc12); | |
_mm_store_ps(Cptr1 + 12, acc13); | |
Cptr0 += SIMD_WIDTH; | |
Cptr1 += SIMD_WIDTH; | |
} | |
} | |
} | |
typedef struct { | |
float *A, *Bp, *C; | |
uint_t start_i, end_i; | |
uint_t a_rows, a_cols; | |
uint_t b_rows, b_cols; | |
} mul_job_t; | |
static void * | |
mul_thread(void *arg) { | |
mul_job_t *job = (mul_job_t *)arg; | |
float *A = job->A; | |
float *Bp = job->Bp; | |
float *C = job->C; | |
uint_t start_i = job->start_i; | |
uint_t end_i = job->end_i; | |
uint_t a_rows = job->a_rows; | |
uint_t a_cols = job->a_cols; | |
uint_t b_rows = job->b_rows; | |
uint_t b_cols = job->b_cols; | |
for (uint_t i = start_i; i < end_i; i += TILE_I) { | |
uint_t imax = MIN(i + TILE_I, a_rows); | |
for (uint_t j = 0; j < b_cols; j += TILE_J) { | |
uint_t jmax = MIN(j + TILE_J, b_cols); | |
for (uint_t k = 0; k < a_cols; k += TILE_K) { | |
uint_t kmax = MIN(k + TILE_K, a_cols); | |
float *Bptr = &Bp[k * b_cols + j * TILE_K]; | |
mul_fast_tile_16x2(i, imax, j, jmax, k, kmax, | |
A, Bptr, C, | |
a_rows, a_cols, | |
b_rows, b_cols); | |
} | |
} | |
} | |
return 0; | |
} | |
/* static float *Abuf = NULL; */ | |
/* static float *Bbuf = NULL; */ | |
static void | |
mul_fast(float * restrict A, | |
float * restrict B, | |
float * restrict C, | |
uint_t a_rows, uint_t a_cols, | |
uint_t b_rows, uint_t b_cols) { | |
assert(TILE_I % SIMD_HEIGHT == 0); | |
assert(TILE_J % SIMD_WIDTH == 0); | |
float *Bbuf = malloc(sizeof(float) * b_rows * b_cols); | |
float *Bptr = Bbuf; | |
for (uint_t k = 0; k < b_rows; k += TILE_K) { | |
for (uint_t j = 0; j < b_cols; j += TILE_J) { | |
for (uint_t y = 0; y < TILE_J; y += SIMD_WIDTH) { | |
for (uint_t x = 0; x < TILE_K; x++) { | |
uint_t row = k + x; | |
uint_t col = j + y; | |
for (uint_t o = 0; o < SIMD_WIDTH; o++) { | |
*Bptr++ = B[row * b_cols + col + o]; | |
} | |
} | |
} | |
} | |
} | |
assert(Bptr - Bbuf == b_rows * b_cols); | |
int a_padded_rows = ceil((float)a_rows / (float)SIMD_HEIGHT) * SIMD_HEIGHT; | |
float *Abuf = malloc(a_padded_rows * a_cols * sizeof(float)); | |
float *Aptr = Abuf; | |
for (int i = 0; i < a_padded_rows; i += SIMD_HEIGHT) { | |
for (int j = 0; j < a_cols; j++) { | |
for (int k = i; k < i + SIMD_HEIGHT; k++) { | |
if (k < a_rows && j < a_cols) { | |
*Aptr++ = A[a_cols * k + j]; | |
} else { | |
*Aptr++ = 0.0; | |
} | |
} | |
} | |
} | |
assert(Aptr - Abuf == a_padded_rows * a_cols); | |
pthread_t threads[N_THREADS]; | |
mul_job_t jobs[N_THREADS]; | |
int n_i_tiles = (int)ceil((float)a_rows / (float)TILE_I); | |
int tiles_per_thread = (int)ceil((float)n_i_tiles / (float)N_THREADS); | |
for (int i = 0; i < N_THREADS; i++) { | |
int start = TILE_I * i * tiles_per_thread; | |
int end = MIN(TILE_I * (i + 1) * tiles_per_thread, a_rows); | |
jobs[i] = (mul_job_t){ | |
Abuf, Bbuf, C, start, end, | |
a_rows, a_cols, | |
b_rows, b_cols | |
}; | |
pthread_create(&threads[i], NULL, mul_thread, &jobs[i]); | |
} | |
for (int i = 0; i < N_THREADS; i++) { | |
pthread_join(threads[i], NULL); | |
} | |
free(Abuf); | |
free(Bbuf); | |
} | |
int | |
main(int argc, char *argv[]) { | |
float *A = malloc(A_N_BYTES); | |
float *B = malloc(B_N_BYTES); | |
float *C = calloc(C_N_BYTES, 1); | |
float *c_ref = calloc(C_N_BYTES, 1); | |
/* Abuf = malloc(A_N_BYTES); */ | |
/* Bbuf = malloc(B_N_BYTES); */ | |
for (int i = 0; i < A_ROWS * A_COLS; i++) { | |
A[i] = (rand() / (float)RAND_MAX) * 5; | |
} | |
for (int i = 0; i < B_ROWS * B_COLS; i++) { | |
B[i] = (rand() / (float)RAND_MAX) * 5; | |
} | |
mul_slow(A, B, c_ref, | |
A_ROWS, A_COLS, B_ROWS, B_COLS); | |
struct timespec begin, end; | |
clock_gettime(CLOCK_MONOTONIC_RAW, &begin); | |
mul_fast(A, B, C, | |
A_ROWS, A_COLS, B_ROWS, B_COLS); | |
clock_gettime(CLOCK_MONOTONIC_RAW, &end); | |
double delta = (end.tv_nsec - begin.tv_nsec) / 1000000000.0 + | |
(end.tv_sec - begin.tv_sec); | |
float gflops = (long)A_ROWS * (long)A_COLS * (long)B_COLS | |
/ (delta * 1000.0 * 1000.0 * 1000.0); | |
printf("[%4d,%4d] * [%4d,%4d] = [%4d, %4d] %8d %6d %6d %6d %6.2f %7.2f\n", | |
A_ROWS, A_COLS, B_ROWS, B_COLS, C_ROWS, C_COLS, | |
N_THREADS, TILE_I, TILE_J, TILE_K, delta, gflops); | |
for (int i = 0; i < C_ROWS; i++) { | |
for (int j = 0; j < C_COLS; j++) { | |
float v = C[C_COLS * i + j]; | |
float v2 = c_ref[C_COLS * i + j]; | |
float diff = fabs(v - v2); | |
if (diff > 0.1) { | |
printf("%d %d, %6.2f %6.2f\n", i, j, v, v2); | |
assert(false); | |
} | |
} | |
} | |
free(A); | |
free(B); | |
free(C); | |
free(c_ref); | |
/* free(Abuf); */ | |
/* free(Bbuf); */ | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment