Recent efforts to run LLMs have us searching for element types to quantize weights and activations into: types wide enough to provide sufficient accuracy, yet narrow enough to provide sufficient performance and/or memory compression.
This document is about the "performance" dimension, specifically on x86 and Arm architectures.
This section is meant to provide a mental model that's simple and generally applicable on most CPU architectures, intentionally not getting into micro-architecture details. It should be accurate enough to inform high-level decisions of which element type to quantize to.
Think of CPU arithmetic throughput as the product (literally, multiplication) of several factors:
(scalar ops per second) = (Number of threads) * (Cycles per second) * (Instructions per cycle) * (Scalar ops per instruction)
To put that in terms of familiar units,
GFlop/s = (Number of threads) * GHz * (Instructions per cycle) * (Scalar ops per instruction)
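To make the units concrete, the formula can be turned into a tiny calculator. The numbers below are hypothetical, chosen purely to illustrate how the factors multiply out:

```python
def peak_gflops(threads, ghz, instructions_per_cycle, scalar_ops_per_instruction):
    """Peak scalar-op throughput in Gop/s as the product of the four factors."""
    return threads * ghz * instructions_per_cycle * scalar_ops_per_instruction

# Hypothetical example: 8 threads at 3.2 GHz, 4 SIMD instructions per cycle,
# each instruction a 128-bit f32 FMA = 4 lanes * 2 ops = 8 scalar ops.
print(peak_gflops(threads=8, ghz=3.2,
                  instructions_per_cycle=4,
                  scalar_ops_per_instruction=8))  # ~819.2 GFlop/s
```

The rest of this section is about which of these factors the choice of element type actually affects.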
Now let's simplify this picture.
First, since we are discussing arithmetic details, we don't need to discuss threads here. The choice of element type is generally orthogonal to the choice of how many threads are used. Caveat: that's not always 100% true. There are cases where a given element type can cause thermal throttling, and other cases where certain element types are handled by ALUs that are shared across hardware threads, so there can be an interplay. But at the level of detail we are interested in here, we can ignore that. It suffices to know that this consideration only adds to the reward for favoring narrower element types.
Second, let's discuss the (Instructions per cycle) factor, to understand that it is essentially out of our control and can therefore be factored out of this discussion as well. CPUs have pipelines that allow multiple instructions to execute in parallel. In general, exploiting that parallelism requires careful thinking about each instruction's ability to execute simultaneously with others, which is why CPU instruction scheduling is a delicate matter in general. But we don't need that full generality here, because we are looking specifically at matrix multiplications, and the SIMD arithmetic instructions we need are whichever ones the ISA offers to perform some flavor of small batch matrix multiplication. For these specific instructions, most CPUs have the same (Instructions per cycle) across all of the instructions we might want to use. There are some exceptions, which we will call out.
In case you're curious, here's what that (Instructions per cycle) figure usually looks like on common CPUs. If you look up CPU specs, it also matches what is commonly referred to as the number of SIMD pipelines. For example, if a CPU has 4x128-bit pipelines, you get 4 SIMD instructions per cycle, each of them computing on a 128-bit vector.
CPU | Architecture | SIMD instructions per cycle | SIMD vector length |
---|---|---|---|
High-end Intel | x86_64 | 2 | 512-bit |
AMD Zen4 | x86_64 | 1 | 512-bit |
AMD Zen4 | x86_64 | 2 | 256-bit |
Low-end Intel or AMD | x86_64 | 1 | 256-bit |
Apple performance cores | arm_64 | 4 | 128-bit |
Cortex-X2 (biggest core in recent Android phone) | arm_64 | 4 | 128-bit |
Cortex-A710 (second-tier big core in recent Android phone) | arm_64 | 2 | 128-bit |
Cortex-A510 (little core in recent Android phone) | arm_64 | 1 | 128-bit |
But again, this shouldn't matter, as this can be factored out from element-type decisions as explained above.
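Still, the table is handy for back-of-envelope peaks. As a sketch, per-core f32 FMA throughput per cycle follows directly from the two table columns (the factor of 2 counts an FMA as a multiply plus an add):

```python
def f32_fma_flops_per_cycle(simd_instructions_per_cycle, vector_bits):
    """Peak f32 flops per cycle per core, assuming every SIMD pipeline
    can issue one full-width f32 FMA per cycle."""
    lanes = vector_bits // 32     # f32 lanes per vector register
    return simd_instructions_per_cycle * lanes * 2  # 2 ops per FMA lane

# Rows from the table above:
print(f32_fma_flops_per_cycle(2, 512))  # high-end Intel -> 64 flops/cycle
print(f32_fma_flops_per_cycle(4, 128))  # Apple performance core -> 32 flops/cycle
```

Note how a 4x128-bit design and a 2x256-bit design reach the same peak; the vector width alone tells you little without the pipeline count.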
Anyway, with these two simplifications, we are down to only having to think about the last factor: (Scalar ops per instruction). That is, we just need to understand which instructions each ISA offers for each element type, and how many scalar ops each performs, and then add ad-hoc caveats when some instruction has a lower (Instructions per cycle) than others on some CPU.
Let's now catalogue SIMD instructions that can be useful to implement matrix multiplications on x86 and Arm.
Here's a mental model to categorize them: each of these instructions can be seen as performing a tiny batch-matmul-with-transposed-RHS op by itself. Since batch-matmul ops are characterized by four dimensions (B, M, N, K), where B is the batch size and K is the reduction dimension size, we can describe each of these instructions by such a (B, M, N, K) tuple.
As pseudocode, each such instruction performs:

```
for (b in range(B))
  for (m in range(M))
    for (n in range(N))
      for (k in range(K))
        out[b, m, n] += lhs[b, m, k] * rhs[b, n, k]
```
Any such instruction performs `B*M*N*K` scalar multiply-accumulates. As each multiply-accumulate is conventionally counted as 2 ops, that means any such instruction performs `2*B*M*N*K` ops.

An exception is instructions that do not accumulate into an existing accumulator, i.e. that behave as if the existing accumulator were zero-filled. We will reflect that in the Ops count in the table below and add an asterisk (`*`) to highlight it.
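To pin down the semantics, here is a straightforward Python reference implementation of that pseudocode (pure nested lists, no SIMD, purely illustrative):

```python
def micro_matmul(lhs, rhs, out, B, M, N, K):
    """Reference semantics of a tiny batch-matmul-with-transposed-RHS
    SIMD instruction: out[b, m, n] += lhs[b, m, k] * rhs[b, n, k]."""
    for b in range(B):
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    out[b][m][n] += lhs[b][m][k] * rhs[b][n][k]
    return out

# An SDOT-like shape, (B, M, N, K) = (1, 4, 1, 4): 16 multiply-accumulates,
# conventionally counted as 2 * 1 * 4 * 1 * 4 = 32 ops.
B, M, N, K = 1, 4, 1, 4
lhs = [[[1] * K for _ in range(M)] for _ in range(B)]
rhs = [[[2] * K for _ in range(N)] for _ in range(B)]
out = [[[0] * N for _ in range(M)] for _ in range(B)]
micro_matmul(lhs, rhs, out, B, M, N, K)
print(out)  # each of the 4 accumulators receives K * 1 * 2 = 8
```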
Let's record some observations here before the big table below.
- CPUs have limited support for `f16`, particularly on x86. There's a current trend of taking `f16` models from early GPU experiments and assuming they'll be a decent starting point for CPU as well. Often, they're not. On x86, outside of the very recent Intel Sapphire Rapids, there is no native `f16` arithmetic, so it's slow.
- If you are going to do 16-bit floating point on CPU, do matmuls with `bf16` operands and `f32` accumulators, and keep the rest in `f32`. Indeed, `bf16` has surprisingly good support in recent x86 and Arm CPUs. Even when `f16` is supported, `bf16` is supported with better instructions: `f16` gets at best some plain element-wise arithmetic, while `bf16` gets fancy dot-product / matmul instructions.
- If your workload is small-bit-depth integers, try to run it in integer arithmetic rather than dequantizing to floating point. On Arm, try hard to expand to 8-bit for arithmetic. On x86, the bulk of narrow-integer arithmetic support is for 16-bit operands, although if you can stomach the `signed*unsigned` semantics, there are some 8-bit instructions that you can take further advantage of.
- You should probably not worry about making things fit in a 16-bit accumulator. On both x86 and Arm, you have narrow-integer arithmetic into 32-bit accumulators.
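As an aside on the `bf16` recommendation: `bf16` is simply an `f32` with the low 16 mantissa bits dropped, which is what makes conversion between the two so cheap. A minimal round-to-nearest-even conversion sketch in Python (illustrative bit manipulation, not any particular library's API):

```python
import struct

def f32_to_bf16_bits(x):
    """Round an f32 to bf16 (round to nearest, ties to even) and return the
    16-bit pattern: bf16 is just the top 16 bits of the f32 representation,
    so it keeps the full f32 exponent range but only ~3 decimal digits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to even
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_f32(b):
    """Widening bf16 -> f32 is exact: just append 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

print(bf16_bits_to_f32(f32_to_bf16_bits(1.0)))      # 1.0, exactly representable
print(bf16_bits_to_f32(f32_to_bf16_bits(3.14159)))  # 3.140625, precision loss
```

This also shows why `bf16`-operand / `f32`-accumulator matmuls are natural: the widening direction costs nothing, and accumulating in `f32` avoids compounding the 8-bit-mantissa rounding error across the K dimension.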
Architecture | LHS | RHS | Out | Instruction | B | M | N | K | Ops | CPU feature | Support | Caveats |
---|---|---|---|---|---|---|---|---|---|---|---|---|
arm_64 | f32 | f32 | f32 | FMLA (vector) | 4 | 1 | 1 | 1 | 8 | baseline | universal | |
arm_64 | f32 | f32 | f32 | FMLA (by element) | 1 | 4 | 1 | 1 | 8 | baseline | universal | |
arm_64 | f16 | f16 | f16 | FMLA (vector) | 8 | 1 | 1 | 1 | 16 | +fp16 | ~2018 | |
arm_64 | f16 | f16 | f16 | FMLA (by element) | 1 | 8 | 1 | 1 | 16 | +fp16 | ~2018 | |
arm_64 | f16 | f16 | f32 | FMLAL (vector) | 4 | 1 | 1 | 1 | 8 | +fp16fml | ~2020 | |
arm_64 | f16 | f16 | f32 | FMLAL (by element) | 1 | 4 | 1 | 1 | 8 | +fp16fml | ~2020 | |
arm_64 | bf16 | bf16 | f32 | BFMMLA | 1 | 2 | 2 | 4 | 32 | +bf16 | ~2022 | Slow on Apple M2 |
arm_64 | i16 | i16 | i32 | SMLAL (vector) (variants for all signednesses) | 4 | 1 | 1 | 1 | 8 | baseline | universal | |
arm_64 | i16 | i16 | i32 | SMLAL (by element) (variants for all signednesses) | 1 | 4 | 1 | 1 | 8 | baseline | universal | |
arm_64 | i8 | i8 | i32 | SDOT (vector) (variants for all signedness) | 4 | 1 | 1 | 4 | 32 | +dotprod | ~2018 | |
arm_64 | i8 | i8 | i32 | SDOT (by element) (variants for all signedness) | 1 | 4 | 1 | 4 | 32 | +dotprod | ~2018 | |
arm_64 | i8 | i8 | i32 | SMMLA (variants for all signedness) | 1 | 2 | 2 | 8 | 64 | +i8mm | ~2022 | Slow on Apple M2 |
x86_64 | f32 | f32 | f32 | VFMADD231PS (ymm*ymm) | 8 | 1 | 1 | 1 | 16 | +fma | near-universal | |
x86_64 | f32 | f32 | f32 | VFMADD231PS (zmm*zmm) | 16 | 1 | 1 | 1 | 32 | +avx512f | High-end Intel, AMD Zen4 | Some Intel CPUs throttle on AVX-512. AMD Zen4 only has 256-bit pipelines, so most 512-bit ops are broken into 2 256-bit ops. |
x86_64 | f32 | f32 | f32 | VFMADD231PS (zmm*m32bcst) | 1 | 16 | 1 | 1 | 32 | +avx512f | See above about AVX-512 | See above caveats about AVX-512 |
x86_64 | f16 | f16 | f16 | VFMADD231PH (ymm*ymm) | 16 | 1 | 1 | 1 | 32 | +avx512fp16 | Only Intel Sapphire Rapids | |
x86_64 | f16 | f16 | f16 | VFMADD231PH (zmm*zmm) | 32 | 1 | 1 | 1 | 64 | +avx512fp16 | Only Intel Sapphire Rapids | See above caveats about AVX-512 |
x86_64 | f16 | f16 | f16 | VFMADD231PH (zmm*m16bcst) | 1 | 32 | 1 | 1 | 64 | +avx512fp16 | Only Intel Sapphire Rapids | See above caveats about AVX-512 |
x86_64 | bf16 | bf16 | f32 | VDPBF16PS (ymm*ymm) | 8 | 1 | 1 | 2 | 32 | +avx512bf16 | ~2022 CPUs from both Intel and AMD (Zen4) | |
x86_64 | bf16 | bf16 | f32 | VDPBF16PS (zmm*zmm) | 16 | 1 | 1 | 2 | 64 | +avx512bf16 | ~2022 CPUs from both Intel and AMD (Zen4) | See above caveats about AVX-512 |
x86_64 | bf16 | bf16 | f32 | VDPBF16PS (zmm*m32bcst) | 1 | 16 | 1 | 2 | 64 | +avx512bf16 | See above about AVX-512 | See above caveats about AVX-512 |
x86_64 | s8 | u8 | s16 | VPMADDUBSW (ymm*ymm) | 16 | 1 | 1 | 2 | 48* | +avx2 | near-universal | |
x86_64 | s8 | u8 | s16 | VPMADDUBSW (zmm*zmm) | 32 | 1 | 1 | 2 | 96* | +avx512bw | See above about AVX-512 | |
x86_64 | s8 | u8 | s32 | VPDPBUSD (ymm*ymm) | 8 | 1 | 1 | 4 | 64 | +avx512vnni | ~2022 CPUs from both Intel and AMD (Zen4) | |
x86_64 | s8 | u8 | s32 | VPDPBUSD (zmm*zmm) | 16 | 1 | 1 | 4 | 128 | +avx512vnni | ~2022 CPUs from both Intel and AMD (Zen4) | |
x86_64 | s8 | u8 | s32 | VPDPBUSD (zmm*m32bcst) | 1 | 16 | 1 | 4 | 128 | +avx512vnni | ~2022 CPUs from both Intel and AMD (Zen4) | |
x86_64 | s16 | s16 | s32 | VPMADDWD (ymm*ymm) | 8 | 1 | 1 | 2 | 24* | +avx2 | near-universal | |
x86_64 | s16 | s16 | s32 | VPMADDWD (zmm*zmm) | 16 | 1 | 1 | 2 | 48* | +avx512bw | See above about AVX-512 | |
x86_64 | s16 | s16 | s32 | VPDPWSSD (ymm*ymm) | 8 | 1 | 1 | 2 | 32 | +avx512vnni | ~2022 CPUs from both Intel and AMD (Zen4) | |
x86_64 | s16 | s16 | s32 | VPDPWSSD (zmm*zmm) | 16 | 1 | 1 | 2 | 64 | +avx512vnni | ~2022 CPUs from both Intel and AMD (Zen4) | |
x86_64 | s16 | s16 | s32 | VPDPWSSD (zmm*m32bcst) | 1 | 16 | 1 | 2 | 64 | +avx512vnni | ~2022 CPUs from both Intel and AMD (Zen4) |
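Putting the two tables together: peak ops per cycle per core is just (SIMD instructions per cycle) times the Ops column. For example, on a hypothetical Arm core with 4 128-bit SIMD pipelines (a Cortex-X2-class design), this is what moving from `f32` FMLA to `i8` SDOT buys you:

```python
# Peak ops per cycle per core = (SIMD instructions per cycle) * (Ops per instruction).
pipes = 4           # SIMD pipelines, from the first table
f32_fmla_ops = 8    # FMLA (vector): (B, M, N, K) = (4, 1, 1, 1) -> 2 * 4 = 8 ops
i8_sdot_ops = 32    # SDOT: (B, M, N, K) = (4, 1, 1, 4) -> 2 * 16 = 32 ops

print(pipes * f32_fmla_ops)  # 32 f32 ops/cycle
print(pipes * i8_sdot_ops)   # 128 i8 ops/cycle: a 4x throughput advantage
```

That 4x ratio, before even counting the memory-bandwidth and cache-footprint savings, is the performance case for quantizing to 8-bit integers on Arm.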