

@Chillee
Last active July 31, 2024 06:20
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data! https://www.thonking.ai/p/strangely-matrix-multiplications
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)

def get_flops(A, B):
    # do_bench reports the runtime of one call in milliseconds
    ms = do_bench(lambda: torch.mm(A, B))
    # (M x K) @ (K x N) performs M*N*K multiply-adds, i.e. 2*M*N*K FLOPs
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    return (1e3/ms) * flops  # FLOP/s

M = 8192
N = 8192
K = 8192

def get_tensors(f):
    A = f(M, K, dtype=torch.bfloat16)
    B = f(N, K, dtype=torch.bfloat16).t()  # (K x N) view of an (N x K) tensor
    return A, B

def one_bit_random(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    # keep only bit 3 of the raw bfloat16 pattern (a mantissa bit); the integer
    # result (0 or 8) is then converted by value, so every entry is 0.0 or 8.0
    x = (x.view(torch.int16) & 0b1000).to(dtype=dtype)
    return x

def sparse(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where(x < 0, 0, x)  # zero out the negative entries (about half)
    return x

original_setups = [
    ("randn", torch.randn),
    ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
    ("sparse", sparse),
    ("one bit", one_bit_random),
    ("rand", torch.rand),
    ("zeros", torch.zeros),
]

results = defaultdict(list)
setups = list(original_setups)
ITERS = 10
for _ in range(ITERS):
    # randomize the order each round so no setup always runs in the same position
    random.shuffle(setups)
    for name, f in setups:
        results[name].append(get_flops(*get_tensors(f)))

def median(x):
    x = sorted(x)
    if len(x) % 2 == 0:
        return (x[len(x)//2] + x[(len(x) - 1)//2])/2
    else:
        return x[len(x)//2]

for name, _ in original_setups:
    print(f"{name}: {median(results[name])/1e12}")  # median throughput in TFLOPS
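To make the "one bit" setup concrete: a bfloat16 value is the top 16 bits of a float32 (1 sign bit, 8 exponent bits, 7 mantissa bits). Because of that layout, `x.view(torch.int16) & 0b1000` isolates mantissa bit 3 and leaves an integer 0 or 8, which `.to(dtype)` then converts by value to 0.0 or 8.0. The stdlib-only sketch below mimics this on the CPU. It is an approximation under a stated assumption: it truncates float32 to bfloat16, whereas PyTorch's cast rounds to nearest. The helper names `bf16_bits` and `one_bit_value` are hypothetical.

```python
import random
import struct


def bf16_bits(value: float) -> int:
    """Approximate 16-bit bfloat16 pattern of `value`.

    Assumption: this truncates the float32 encoding; PyTorch's cast rounds to
    nearest-even instead, which can occasionally change the low mantissa bits.
    """
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", value))
    return f32_bits >> 16  # sign bit, 8 exponent bits, top 7 mantissa bits


def one_bit_value(value: float) -> float:
    # Mirrors `(x.view(torch.int16) & 0b1000).to(dtype)`: the AND keeps mantissa
    # bit 3, and the cast converts the resulting integer (0 or 8) by value.
    return float(bf16_bits(value) & 0b1000)


random.seed(0)
samples = [one_bit_value(random.gauss(0.0, 1.0)) for _ in range(10_000)]
print(sorted(set(samples)))               # distinct values produced
print(samples.count(8.0) / len(samples))  # fraction of entries equal to 8.0
```

In other words, the "one bit" matrices contain only 0.0 and 8.0, arranged according to a single mantissa bit of Gaussian samples.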
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
import subprocess
random.seed(0)
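def set_gpu_limits(ref_sm_clock=1810, power_limit=330):
    # Lock GPU 0's SM clock to a single frequency (min = max = ref_sm_clock, in MHz)
    subprocess.check_output([
        "sudo",
        "nvidia-smi",
        "-i",
        "0",
        f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
    ])
    # Set GPU 0's power limit (in watts)
    subprocess.check_output([
        "sudo",
        "nvidia-smi",
        "-i",
        "0",
        f"--power-limit={power_limit}",
    ])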

def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    return (1e3/ms) * flops

M = 8192
N = 8192
K = 8192

def get_tensors(f):
    A = f(M, K, dtype=torch.bfloat16)
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B

def one_bit_random(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    x = (x.view(torch.int16) & 0b1000).to(dtype=dtype)
    return x

def sparse(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    # keep roughly 10% of the entries and zero out the rest
    x = torch.where(torch.rand_like(x) > 0.1, 0, x)
    return x

def checkerboard(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    # keep entries where (row - col) is even; indexing rows along dim 0 and columns
    # along dim 1 makes the mask match `shape` even for non-square tensors
    mask = (torch.arange(shape[0]).view(-1, 1) - torch.arange(shape[1]).view(1, -1)) % 2 == 0
    x = torch.where(mask, x, 0)
    return x

def ternary(*shape, dtype=torch.bfloat16):
    # values drawn uniformly from {-1, 0, 1}, in the requested dtype
    x = torch.randint(low=-1, high=2, size=shape, dtype=dtype)
    return x

original_setups = [
    # ("zeros", torch.zeros),
    ("randn", torch.randn),
    # ("checkerboard", checkerboard),
    # ("sparse", sparse),
    # ("rand", torch.rand),
    # ("ternary", ternary),
    # ("one bit", one_bit_random),
    # ("all_pi", lambda *shape, dtype: torch.full(shape, fill_value=3.1415926535897932384626, dtype=dtype)),
    # ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
]

def get_results(clocks, power):
    set_gpu_limits(clocks, power)
    results = defaultdict(list)
    setups = list(original_setups)
    ITERS = 10
    for _ in range(ITERS):
        random.shuffle(setups)
        for name, f in setups:
            results[name].append(get_flops(*get_tensors(f)))
    def median(x):
        x = sorted(x)
        if len(x) % 2 == 0:
            return (x[len(x)//2] + x[(len(x) - 1)//2])/2
        else:
            return x[len(x)//2]
    # for name, _ in original_setups:
    #     print(f"{name}: {median(results[name])/1e12}")
    # print(median(results['zeros']) / median(results["randn"]))
    return median(results['randn'])

start_clocks = 1980  # H100
for power in reversed([150, 200, 250, 300, 350, 400, 450, 500]):
    max_clocks = 1980  # H100
    start_flops = get_results(max_clocks, power)
    for clocks in range(start_clocks, 200, -100):
        # print(power, clocks)
        cur_flops = get_results(clocks, power)
        if cur_flops < start_flops * 0.9:
            print("Done: ", power, clocks)
            start_clocks = clocks
            break
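The outer loop above is a search. For each power limit (highest first), it lowers the locked SM clock in 100 MHz steps until the median `randn` throughput falls below 90% of the throughput measured at 1980 MHz. The next, lower power limit then resumes scanning from that clock. The reported clock presumably approximates the highest clock the GPU can sustain under that power limit, but that reading is an interpretation, not something the script states. Below is a hardware-free sketch of the same search with the benchmark replaced by a caller-supplied `measure(clock, power)` function. Both `find_throttle_points` and the `toy_measure` model (a made-up ceiling of 4 MHz per watt) are hypothetical and are not measurements.

```python
from typing import Callable, Dict, Iterable, Optional


def find_throttle_points(
    measure: Callable[[int, int], float],
    powers: Iterable[int],
    max_clock: int = 1980,
    min_clock: int = 200,
    step: int = 100,
    threshold: float = 0.9,
) -> Dict[int, Optional[int]]:
    """For each power limit, return the first locked clock (scanning downward)
    whose throughput drops below `threshold` times the throughput at `max_clock`.

    `measure(clock, power)` stands in for `get_results(clocks, power)`.
    """
    found: Dict[int, Optional[int]] = {}
    start_clock = max_clock
    for power in sorted(powers, reverse=True):  # highest power limit first
        baseline = measure(max_clock, power)
        found[power] = None  # stays None if throughput never drops far enough
        for clock in range(start_clock, min_clock, -step):  # min_clock is exclusive
            if measure(clock, power) < baseline * threshold:
                found[power] = clock
                start_clock = clock  # the next, lower power limit resumes here
                break
    return found


def toy_measure(clock: int, power: int) -> float:
    # Toy model, not data: throughput tracks the lower of the locked clock
    # and a made-up clock ceiling of 4 MHz per watt.
    return float(min(clock, 4 * power))


print(find_throttle_points(toy_measure, [150, 200, 250, 300, 350, 400, 450, 500]))
```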
@Chillee
Author

Chillee commented Apr 28, 2024

[image]

@TJ-Solergibert

Hi! Thanks for an amazing post! I ran the mm_weird.py benchmark on an H100 and got the following results:

Run 1:
randn: 1024.3100185282144
twos: 803.3550678054524
sparse: 1086.47683669488
one bit: 830.4096678480972
rand: 837.8445385632689
zeros: 810.90379017118
Run 2:
randn: 1020.114801596814
twos: 803.5072206112413
sparse: 1060.8216568964108
one bit: 828.119089454572
rand: 832.9280217508104
zeros: 815.2949820775259
Run 3:
randn: 1015.1157728697485
twos: 808.3138761162128
sparse: 1074.7391939180266
one bit: 835.835139020573
rand: 836.061297508501
zeros: 812.8299565166335

I don't know which is stranger: getting more than 1000 TFLOPS, or getting the opposite results...

Toni

P.S. I changed L28 to x = torch.randn(*shape, dtype=dtype)
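For scale: `get_flops` counts 2 · M · N · K FLOPs per call. For M = N = K = 8192 that is 2 · 8192³ = 2^40 ≈ 1.1 × 10^12 FLOPs, so a reported 1000 TFLOPS corresponds to roughly 1.1 ms per `torch.mm` call. The helpers below convert in both directions using the same formula as `get_flops`. Their names are illustrative and do not come from the gist. Whether a given figure is physically plausible still has to be checked against the specific GPU's datasheet peak, which this sketch does not encode.

```python
def matmul_flops(m: int, n: int, k: int) -> int:
    # Each of the m*n outputs needs k multiply-adds, i.e. 2*k FLOPs.
    return 2 * m * n * k


def tflops_from_ms(ms: float, m: int, n: int, k: int) -> float:
    # Same formula as get_flops: (1e3 / ms) * flops is FLOP/s; divide by 1e12 for TFLOPS.
    return (1e3 / ms) * matmul_flops(m, n, k) / 1e12


def ms_from_tflops(tflops: float, m: int, n: int, k: int) -> float:
    # Inverse: the per-call runtime implied by a reported TFLOPS value.
    return matmul_flops(m, n, k) / (tflops * 1e12) * 1e3


size = 8192
print(matmul_flops(size, size, size))                    # 1099511627776
print(round(ms_from_tflops(1000, size, size, size), 3))  # 1.1
```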

@Chillee
Author

Chillee commented Jun 21, 2024

@TJ-Solergibert That's indeed quite strange 🤔 In particular, you also see very high FLOPs for randn compared to zeros.

@TJ-Solergibert

And I got the same behavior on an 80GB A100…

@Chillee
Author

Chillee commented Jun 21, 2024

@TJ-Solergibert Can you show your nvidia-smi?

@TJ-Solergibert

Running on pretty exotic H100s (my A100 cluster is down):
[image]

$ python3 mm_weird.py 
randn: 1055.2074280203815
twos: 805.9562233451499
sparse: 1085.084700223281
one bit: 897.8750065084391
rand: 906.2403213001224
zeros: 805.7438630174415

Have you tried re-running the script?

@Chillee
Author

Chillee commented Jun 21, 2024

Yeah I reran it on an A100:
[image]

@photomz

photomz commented Jul 31, 2024

Am I being obtuse, or isn't this just GPU speculative execution in action?

@Chillee
Author

Chillee commented Jul 31, 2024

@photomz If you haven't seen it, this is the associated article: https://www.thonking.ai/p/strangely-matrix-multiplications

But speculative execution is an interesting thought: on the surface it definitely looks "fairly similar", and I did think about it a bit.

But:

  1. I'm not actually sure GPUs do speculative execution. Since GPUs are generally focused on parallel execution, they usually have much shallower "pipes" than CPUs do.
  2. Matrix multiplications do not typically branch on data. Where I've seen speculative execution make performance depend on the input data, it was through its interaction with branch prediction. But in this case, there is no branching! The GPU executes the exact same assembly instructions in the exact same order, regardless of what the input data is.
  3. And related to the above, I'd have a hard time seeing how randn and rand would lead to differing speculative execution behavior on the GPUs.
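Point 2 can be illustrated with a branch-free kernel: its sequence of operations is the same whatever the values are, so any data dependence has to come from somewhere other than control flow. One hypothesis, stated here as an assumption rather than something this thread establishes, is that what varies is how many bits toggle as successive operands pass through the hardware. The stdlib-only sketch below counts bit flips between consecutive (truncated) bfloat16 patterns as a crude proxy for that switching activity. It is not a power model. The helpers `bf16_bits`, `toggle_count`, and `branch_free_dot` are hypothetical.

```python
import random
import struct
from typing import Sequence


def bf16_bits(value: float) -> int:
    # Approximate bfloat16 bit pattern by truncating float32 (PyTorch rounds instead).
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", value))
    return f32_bits >> 16


def toggle_count(values: Sequence[float]) -> int:
    """Total number of bits that flip between consecutive bfloat16 patterns."""
    patterns = [bf16_bits(v) for v in values]
    return sum(bin(a ^ b).count("1") for a, b in zip(patterns, patterns[1:]))


def branch_free_dot(a: Sequence[float], b: Sequence[float]) -> float:
    # The same multiply-adds run in the same order whatever the values are.
    acc = 0.0
    for x, y in zip(a, b):
        acc += x * y
    return acc


random.seed(0)
n = 4096
inputs = {
    "zeros": [0.0] * n,
    "twos": [2.0] * n,
    "randn": [random.gauss(0.0, 1.0) for _ in range(n)],
}
for name, vals in inputs.items():
    print(name, toggle_count(vals))
```

Under this proxy, constant inputs such as zeros or twos never toggle, while Gaussian inputs toggle many bits per step. Whether that maps onto real tensor-core behavior is exactly the open question.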
