Skip to content

Instantly share code, notes, and snippets.

@nihui
Last active May 10, 2022 13:33
Show Gist options
  • Save nihui/17a9a03e64730e6d6042d42432654dd5 to your computer and use it in GitHub Desktop.
Save nihui/17a9a03e64730e6d6042d42432654dd5 to your computer and use it in GitHub Desktop.
int8 vector multiplication in loongson mmi and mips msa
// g++ mul.cpp -o mul -mmsa -mloongson-mmi -O3
// https://github.com/Tencent/ncnn/blob/master/src/layer/mips/loongson_mmi.h
// root@ls2k:~/ncnn/build# ./quant
// mul_s8x8 385.743
// mul_s8x8_mmi 611.364
// mul_s8x8_msa 173.241
// -66 2 0 4 10 18 28 40
#include <msa.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "loongson_mmi.h"
#include <sys/time.h>
// Returns the current wall-clock time reported by gettimeofday(), expressed
// in milliseconds as a double.
static double get_current_time()
{
    struct timeval now;
    gettimeofday(&now, NULL);

    // Bring both fields to milliseconds before summing them.
    const double ms_from_seconds = now.tv_sec * 1000.0;
    const double ms_from_micros = now.tv_usec / 1000.0;
    return ms_from_seconds + ms_from_micros;
}
// Scalar reference kernel: multiplies eight signed 8-bit values from vptr with
// the eight values at the same positions in kptr and writes each product to
// out as an int.
//
// vptr, kptr: at least 8 readable elements each.
// out:        at least 8 writable elements.
//
// noinline keeps the call as a real function call at every call site.
__attribute__((noinline))
static void mul_s8x8(const signed char* vptr, const signed char* kptr, int* out)
{
    for (int lane = 0; lane < 8; lane++)
    {
        // Both operands are promoted to int before the multiply.
        const int lhs = vptr[lane];
        const int rhs = kptr[lane];
        out[lane] = lhs * rhs;
    }
}
__attribute__((noinline))
static void mul_s8x8_mmi(const signed char* vptr, const signed char* kptr, int* out)
{
int8x8_t _v = __mmi_pldb_s(vptr);
int8x8_t _k = __mmi_pldb_s(kptr);
int8x8_t _zero = __mmi_pzerob_s();
int8x8_t _extv = __mmi_pcmpgtb_s(_zero, _v);
int8x8_t _extk = __mmi_pcmpgtb_s(_zero, _k);
int16x4_t _v0 = (int16x4_t)__mmi_punpcklbh_s(_v, _extv);
int16x4_t _v1 = (int16x4_t)__mmi_punpckhbh_s(_v, _extv);
int16x4_t _k0 = (int16x4_t)__mmi_punpcklbh_s(_k, _extk);
int16x4_t _k1 = (int16x4_t)__mmi_punpckhbh_s(_k, _extk);
int16x4_t _s0l = __mmi_pmullh(_v0, _k0);
int16x4_t _s0h = __mmi_pmulhh(_v0, _k0);
int16x4_t _s1l = __mmi_pmullh(_v1, _k1);
int16x4_t _s1h = __mmi_pmulhh(_v1, _k1);
int32x2_t _s0 = (int32x2_t)__mmi_punpcklhw_s(_s0l, _s0h);
int32x2_t _s1 = (int32x2_t)__mmi_punpckhhw_s(_s0l, _s0h);
int32x2_t _s2 = (int32x2_t)__mmi_punpcklhw_s(_s1l, _s1h);
int32x2_t _s3 = (int32x2_t)__mmi_punpckhhw_s(_s1l, _s1h);
__mmi_pstw_s(out, _s0);
__mmi_pstw_s(out + 2, _s1);
__mmi_pstw_s(out + 4, _s2);
__mmi_pstw_s(out + 6, _s3);
}
// MIPS MSA variant of the 8-lane int8 multiply: reads exactly eight signed
// bytes from vptr and from kptr and stores eight int products to out[0..7].
//
// vptr, kptr: at least 8 readable elements each.
// out:        at least 8 writable ints (written as two 4-lane vector stores).
__attribute__((noinline))
static void mul_s8x8_msa(const signed char* vptr, const signed char* kptr, int* out)
{
    // __msa_ld_b fetches a whole 16-byte vector. Loading straight from vptr /
    // kptr would read 8 bytes past the 8-element operands, so stage them in
    // zero-padded 16-byte buffers first. Only the low 8 lanes are used below.
    signed char vbuf[16] = {0};
    signed char kbuf[16] = {0};
    memcpy(vbuf, vptr, 8);
    memcpy(kbuf, kptr, 8);

    v16i8 _v = __msa_ld_b(vbuf, 0);
    v16i8 _k = __msa_ld_b(kbuf, 0);

    // Widen the low 8 bytes to 16 bits: the "x < 0" mask (all ones for
    // negative lanes) is interleaved in as the upper byte of each element.
    v8i16 _v01 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_v, 0), _v);
    v8i16 _k01 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_k, 0), _k);

    // An int8 x int8 product has magnitude at most 128 * 128 = 16384, so the
    // 16-bit multiply cannot overflow.
    v8i16 _s01 = __msa_mulv_h(_v01, _k01);

    // Widen the 16-bit products to 32 bits with the same mask-interleave trick.
    v8i16 _exts01 = __msa_clti_s_h(_s01, 0);
    v4i32 _s0 = (v4i32)__msa_ilvr_h(_exts01, _s01);
    v4i32 _s1 = (v4i32)__msa_ilvl_h(_exts01, _s01);

    __msa_st_w(_s0, out, 0);
    __msa_st_w(_s1, out + 4, 0);
}
// Parses text as a base-10 integer that must fit one signed 8-bit lane.
// Returns 0 and stores the result in *value on success; returns -1 if text is
// not entirely a number or the number lies outside [-128, 127].
static int parse_s8(const char* text, signed char* value)
{
    char* end = NULL;
    const long parsed = strtol(text, &end, 10);
    // The explicit range check also rejects the LONG_MIN/LONG_MAX values that
    // strtol returns on overflow.
    if (end == text || *end != '\0' || parsed < -128 || parsed > 127)
    {
        return -1;
    }
    *value = (signed char)parsed;
    return 0;
}

// Benchmark driver: times 10,000,000 calls of each multiply kernel and prints
// the elapsed milliseconds plus the contents of out after the last (MSA) run.
//
// Command line:
//   (no values)        use the built-in sample vectors
//   v0 .. v7           use these eight values for BOTH vectors
//   v0 .. v7 k0 .. k7  use separate values for the two vectors
// Any other argument count falls back to the built-in sample vectors.
// Returns 0 on success, 1 if a supplied value is not an integer in -128..127.
int main(int argc, char** argv)
{
    signed char vptr[8] = {33, -2, 1, 4, 5, 6, 7, 8};
    signed char kptr[8] = {-2, -1, 0, 1, 2, 3, 4, 5};

    if (argc == 9 || argc == 17)
    {
        // argv[0] is the program name, so user values start at argv[1].
        // With only eight values, kptr reuses the same arguments as vptr.
        const int k_first = (argc == 17) ? 9 : 1;
        for (int i = 0; i < 8; i++)
        {
            if (parse_s8(argv[1 + i], &vptr[i]) != 0 || parse_s8(argv[k_first + i], &kptr[i]) != 0)
            {
                fprintf(stderr, "usage: %s [v0..v7 [k0..k7]]  (each value in -128..127)\n", argv[0]);
                return 1;
            }
        }
    }

    const int iterations = 10000000;
    int out[8];

    double t0 = get_current_time();
    for (int i = 0; i < iterations; i++)
    {
        mul_s8x8(vptr, kptr, out);
    }
    double t1 = get_current_time();
    for (int i = 0; i < iterations; i++)
    {
        mul_s8x8_mmi(vptr, kptr, out);
    }
    double t2 = get_current_time();
    for (int i = 0; i < iterations; i++)
    {
        mul_s8x8_msa(vptr, kptr, out);
    }
    double t3 = get_current_time();

    fprintf(stderr, "mul_s8x8 %.3f\n", t1-t0);
    fprintf(stderr, "mul_s8x8_mmi %.3f\n", t2-t1);
    fprintf(stderr, "mul_s8x8_msa %.3f\n", t3-t2);
    fprintf(stderr, "%d %d %d %d %d %d %d %d\n", out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]);
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment