Created March 19, 2020 14:52
-
-
Save ppetrushkov/694bc7ec0f7663c63e067e9ecfdc7d99 to your computer and use it in GitHub Desktop.
DNNL vs FBGEMM u8s8s32 single core performance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

#include "dnnl.hpp"
#include "fbgemm/Fbgemm.h"
/// Fills `size` elements starting at `ptr` with pseudo-random values.
///
/// Each value is rand() scaled into [0, 127] in float arithmetic and then
/// converted to T. Nothing is written when size <= 0. The sequence depends
/// on the current rand() seed.
template<typename T>
void init_random(int size, T* ptr) {
    int idx = 0;
    while (idx < size) {
        // Same evaluation order as (127 * r) / RAND_MAX, so rounding is unchanged.
        const float scaled = 127.0f * static_cast<float>(rand());
        ptr[idx] = static_cast<T>(scaled / static_cast<float>(RAND_MAX));
        ++idx;
    }
}
/// Sets the first `size` elements at `ptr` to T's zero value.
/// Nothing is written when size <= 0.
template<typename T>
void init_zero(int size, T* ptr) {
    for (int remaining = size; remaining > 0; --remaining) {
        *ptr++ = static_cast<T>(0);
    }
}
int main() { | |
int m = 16; | |
int n = 768; | |
int k = 3072; | |
int iters = 1000; | |
//DNNL | |
char trans_a = 'N'; | |
char trans_b = 'N'; | |
char offsetc = 'F'; | |
float alpha = 1.0f; | |
int lda = k; | |
uint8_t ao = 0; | |
int ldb = n; | |
int8_t bo = 0; | |
float beta = 0.0f; | |
int ldc = n; | |
std::array<int32_t, 1> oc = {0}; | |
void* weights_p = std::malloc(k*n*sizeof(int8_t)); | |
int8_t* weights = static_cast<int8_t*>(weights_p); | |
init_random(k*n, weights); | |
void* input_p = std::malloc(m*k*sizeof(uint8_t)); | |
uint8_t* input = static_cast<uint8_t*>(input_p); | |
init_random(m*k, input); | |
void* output_p = std::malloc(m*n*sizeof(int)); | |
int* output = static_cast<int*>(output_p); | |
float total=0.0f; | |
int dryrun=3; | |
for (int i=0; i<iters; ++i) { | |
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); | |
dnnl_gemm_u8s8s32(trans_a, trans_b, offsetc, | |
m, n, k, alpha, const_cast<const uint8_t*>(input), lda, ao, | |
const_cast<const int8_t*>(weights), ldb, bo, | |
beta, output, ldc, oc.data()); | |
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); | |
if (i>=dryrun) | |
total += std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count(); | |
} | |
std::cout << "DNNL Average time = " << total/(iters-dryrun) << std::endl; | |
//FBGEMM | |
int32_t* col_offsets = static_cast<int32_t*>(std::malloc(n*sizeof(int32_t))); | |
init_zero(n, col_offsets); | |
auto packedBN = fbgemm::PackBMatrix<int8_t, int32_t>(fbgemm::matrix_op_t::NoTranspose, | |
k, | |
n, | |
weights, | |
n); | |
float input_scale = 1.0f/255; | |
int32_t input_zeropoint = 0; | |
float* weight_scales = static_cast<float*>(std::malloc(n*sizeof(float))); | |
for (int i=0; i<n; ++i) | |
weight_scales[i] = 1.0f/255; | |
int32_t* weight_zeropoints = static_cast<int32_t*>(std::malloc(n*sizeof(int32_t))); | |
init_zero(n, weight_zeropoints); | |
total=0.0f; | |
for (int i=0; i<iters; ++i) { | |
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now(); | |
fbgemm::PackAWithRowOffset<uint8_t,int32_t> packAN( | |
fbgemm::matrix_op_t::NoTranspose, | |
m, | |
k, | |
input, | |
k, //lda==k for no-transpose | |
nullptr, //buffer for packed matrix | |
1, //groups | |
nullptr); //buffer for packed data | |
fbgemm::DoNothing<float, float> doNothingObj{}; | |
fbgemm::ReQuantizeForFloat<false, fbgemm::QuantizationGranularity::OUT_CHANNEL> outputProcObj( | |
doNothingObj, | |
input_scale, | |
weight_scales, | |
input_zeropoint, | |
weight_zeropoints, | |
packAN.getRowOffsetBuffer(), //row offsets | |
col_offsets, //column offsets | |
nullptr, | |
n); | |
fbgemm::fbgemmPacked( | |
packAN, | |
packedBN, | |
static_cast<float*>(output_p), | |
(int32_t*)output_p, | |
(int32_t)n, //ldc==n | |
outputProcObj, | |
0, //thread id | |
1); //num threads | |
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); | |
if (i>=dryrun) | |
total += std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count(); | |
} | |
std::cout << "FBGEMM Average time = " << total/(iters-dryrun) << std::endl; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.