Created
April 14, 2024 15:20
-
-
Save skeeto/92622272d75f8e872876f3dfb8c34dc0 to your computer and use it in GitHub Desktop.
"Two Sum" benchmark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// "Two Sum" benchmark | |
// $ cc -O1 -nostartfiles -o twoSum.exe twoSum.c | |
// $ cl /O1 twoSum.c /link /subsystem:console kernel32.lib libvcruntime.lib
// Ref: https://old.reddit.com/r/C_Programming/comments/1c36391 | |
#include <stddef.h> | |
#include <stdint.h> | |
#include <string.h> | |
// Halt on a failed check: keep storing through a null pointer for as long
// as the condition is false. The volatile qualifier keeps the store from
// being optimized away.
#define assert(c) while (!(c)) *(volatile int *)0 = 0

// Allocate n zero-filled objects of type t from arena a (wraps alloc).
#define new(a, t, n) (t *)alloc((a), sizeof(t), _Alignof(t), (n))
#define W32(r) __declspec(dllimport) r __stdcall | |
W32(void) ExitProcess(int32_t); | |
W32(uintptr_t) GetStdHandle(int32_t); | |
W32(int32_t) QueryPerformanceCounter(uint64_t *); | |
W32(int32_t) QueryPerformanceFrequency(uint64_t *); | |
W32(void *) VirtualAlloc(uintptr_t, size_t, int32_t, int32_t); | |
W32(int32_t) WriteFile(uintptr_t, uint8_t *, int32_t, int32_t *, uintptr_t); | |
// Linear allocator region. Free space is [beg, end); alloc carves objects
// off the high end by moving end downward.
typedef struct {
    char *beg;
    char *end;
} arena;

// Allocate count zero-filled objects of size bytes each from arena a.
//
// align must be a power of two (it is used as a bit mask), and size is
// expected to be a multiple of align so the returned pointer stays aligned.
// A request that does not fit, a negative count, or a non-positive size
// trips assert rather than returning null.
static void *alloc(arena *a, ptrdiff_t size, ptrdiff_t align, ptrdiff_t count)
{
    // Bytes to drop from the top so that end lands on an align boundary.
    ptrdiff_t pad = (ptrdiff_t)((uintptr_t)a->end & (uintptr_t)(align - 1));
    ptrdiff_t avail = a->end - a->beg - pad;
    assert(size > 0 && count >= 0);
    // Dividing instead of multiplying keeps the check free of overflow;
    // "<=" lets a request that exactly fills the arena succeed.
    assert(avail >= 0 && count <= avail/size);
    ptrdiff_t total = size*count;
    a->end -= total + pad;
    return memset(a->end, 0, (size_t)total);
}
// A solution: the two array values (not indices) that add up to the target.
typedef struct {
    int32_t x;
    int32_t y;
} result;

// Brute-force "Two Sum": return the first pair, in (x, y) scan order, of
// two different values from nums[0..len) whose sum equals target.
//
// Pairs of equal values are never reported. If no pair exists the trailing
// assert fires; callers are expected to pass inputs that have a solution.
//
// NOTE(review): each pair is visited in both orders. That is left as-is so
// the baseline keeps the cost the benchmark was written to measure.
static result twoSum1(int32_t target, int32_t *nums, int32_t len)
{
    for (int32_t x = 0; x < len; x++) {
        for (int32_t y = 0; y < len; y++) {
            // Add in 64 bits: a 32-bit signed sum can overflow, which is
            // undefined and may wrap around to a false match.
            int64_t sum = (int64_t)nums[x] + nums[y];
            if (sum == target && nums[x] != nums[y]) {
                return (result){nums[x], nums[y]};
            }
        }
    }
    assert(0);
}
// Four-way tree node holding one stored value. upsert picks which child to
// follow from two bits of the key's hash per level.
typedef struct map {
    struct map *child[4];
    int32_t     value;
} map;
// Hash an unordered pair of integers into 64 bits.
//
// NOTE: the result must not depend on argument order. upsert hashes
// (value, target - value), so a number and its complement have to follow
// the same path through the tree.
static uint64_t hash(int32_t x, int32_t y)
{
    uint64_t h = 1111111111111111111u;
    h *= (uint64_t)x;  // negative inputs wrap modulo 2^64
    h *= (uint64_t)y;
    return h;
}
// Find the stored number that pairs with value to reach target, or store
// value if there is none.
//
// Returns a pointer to the matching node's value when some stored number
// plus value equals target. Otherwise a new zeroed node is allocated from
// perm, set to value, linked into the tree, and a pointer to its value is
// returned. The caller tells the two cases apart by comparing the result
// with value.
static int32_t *upsert(map **m, int32_t target, int32_t value, arena *perm)
{
    // Subtract with unsigned wrap-around: plain int32_t subtraction is
    // undefined on overflow. When the true complement does not fit in
    // int32_t no stored number can match, so the wrapped value only needs
    // to be deterministic.
    int32_t complement = (int32_t)((uint32_t)target - (uint32_t)value);
    // Two hash bits, taken from the top, choose the child at each level.
    for (uint64_t h = hash(value, complement); *m; h <<= 2) {
        // Compare in 64 bits so an overflowing sum cannot fake a match.
        if ((int64_t)(*m)->value + value == target) {
            return &(*m)->value;
        }
        m = &(*m)->child[h>>62];
    }
    *m = new(perm, map, 1);
    (*m)->value = value;
    return &(*m)->value;
}
// Tree-based "Two Sum": scan nums once, using upsert to either find the
// stored complement of each number or remember the number for later.
//
// scratch is taken by value, so the nodes allocated here are released
// when the function returns. Like twoSum1, the trailing assert fires if no
// pair is found.
static result twoSum2(int32_t target, int32_t *nums, int32_t len, arena scratch)
{
    map *root = 0;
    int32_t i = 0;
    while (i < len) {
        int32_t candidate = nums[i++];
        int32_t *stored = upsert(&root, target, candidate, &scratch);
        if (*stored == candidate) {
            // Freshly inserted, or an equal value was already stored.
            continue;
        }
        return (result){*stored, candidate};
    }
    assert(0);
}
// Advance the generator state *rng and return a pseudo-random integer in
// [min, max). Returns min when max == min.
//
// The state update is a 64-bit multiply-and-add. The upper 32 bits of the
// new state are scaled into the requested range by a multiply and shift.
static int32_t randint(uint64_t *rng, int32_t min, int32_t max)
{
    *rng = *rng*0x3243f6a8885a308du + 1;
    // Width of the range in 64 bits: max - min can exceed INT32_MAX, for
    // example for (INT32_MIN, INT32_MAX), and overflow 32-bit arithmetic.
    uint64_t span = (uint64_t)((int64_t)max - min);
    uint64_t scaled = ((*rng>>32) * span) >> 32;
    // Add the offset in 64 bits too; the sum lies in [min, max).
    return (int32_t)((int64_t)scaled + min);
}
// Fill nums[0..len) with pseudo-random values in [-1000000000, 1000000000]
// drawn from seed, then return a target that is guaranteed to be solvable:
// the sum of two different values taken from the array, itself within
// [-1000000000, 1000000000].
//
// len must be at least 2. The search for a target does not terminate if
// every generated value is equal.
static int32_t generate(int32_t *nums, int32_t len, uint64_t seed)
{
    // With no elements nums[0] would be read out of bounds, and with one
    // element two different values can never be found.
    assert(len >= 2);
    for (int32_t i = 0; i < len; i++) {
        nums[i] = randint(&seed, -1000000000, +1000000001);
    }
    for (;;) {
        int32_t i = randint(&seed, 0, len);
        int32_t j = randint(&seed, 0, len);
        // Both operands are within +/-1000000000, so the sum fits in
        // int32_t.
        int32_t target = nums[i] + nums[j];
        if (nums[i]!=nums[j] && target>=-1000000000 && target<=+1000000000) {
            return target;
        }
    }
}
// Buffered output state used by print, printint and flush.
typedef struct {
    uint8_t buf[4096];  // pending bytes
    int32_t len;        // number of bytes of buf currently in use
    int32_t err;        // nonzero once a write has failed
} bufout;
// Write any buffered bytes to standard output and empty the buffer.
// Does nothing if the buffer is empty or an earlier write already failed.
// A failed WriteFile sets b->err.
static void flush(bufout *b)
{
    if (b->err || !b->len) {
        return;
    }
    uintptr_t out = GetStdHandle(-11);  // -11 is STD_OUTPUT_HANDLE
    // NOTE(review): the byte count WriteFile stores in b->len is thrown
    // away just below, so a short write is not retried.
    int32_t ok = WriteFile(out, b->buf, b->len, &b->len, 0);
    b->err = !ok;
    b->len = 0;
}
// Append len bytes from buf to the output buffer, flushing each time the
// buffer fills up. Stops early once b->err is set.
static void print(bufout *b, uint8_t *buf, ptrdiff_t len)
{
    ptrdiff_t done = 0;
    while (!b->err && done < len) {
        int32_t   room = (int32_t)sizeof(b->buf) - b->len;
        ptrdiff_t left = len - done;
        // Copy whichever is smaller: the free room or what remains.
        int32_t n = room < left ? room : (int32_t)left;
        memcpy(b->buf + b->len, buf + done, n);
        done   += n;
        b->len += n;
        if (b->len == (int32_t)sizeof(b->buf)) {
            flush(b);
        }
    }
}
// Append the decimal representation of x to the output buffer.
static void printint(bufout *b, int32_t x)
{
    uint8_t  digits[16];
    uint8_t *stop = digits + sizeof(digits);
    uint8_t *cur  = stop;

    // Work with the negative magnitude: every int32_t, including the most
    // negative one, has a representable negative counterpart.
    int32_t neg = x;
    if (x >= 0) {
        neg = -x;
    }

    // Emit digits right to left. neg%10 is in [-9, 0], so subtracting it
    // from '0' gives the digit character.
    do {
        *--cur = (uint8_t)('0' - neg%10);
        neg /= 10;
    } while (neg);

    if (x < 0) {
        *--cur = '-';
    }
    print(b, cur, stop - cur);
}
// Search seeds 0, 1, 2, ... without end for inputs of length len that are
// slow for twoSum2. Each seed is timed eight times and its fastest run is
// kept. Whenever that time exceeds the slowest seen so far, the seed and
// the time in microseconds are printed.
static void findslowest(int32_t len, arena scratch)
{
    bufout *out = new(&scratch, bufout, 1);
    uint64_t freq;
    QueryPerformanceFrequency(&freq);

    double slowest = 0;
    uint64_t seed = 0;
    for (;;) {
        // Per-seed copy of the arena: everything allocated from it is
        // dropped at the next iteration.
        arena loop = scratch;
        int32_t *nums = new(&loop, int32_t, len);
        int32_t target = generate(nums, len, seed);

        double best = 1e100;
        for (int trial = 0; trial < 8; trial++) {
            uint64_t start, stop;
            QueryPerformanceCounter(&start);
            // To time the brute-force version, call
            // twoSum1(target, nums, len) here instead.
            twoSum2(target, nums, len, loop);
            QueryPerformanceCounter(&stop);
            double t = (double)(stop - start)/(double)freq;
            if (t < best) {
                best = t;
            }
        }

        if (best > slowest) {
            slowest = best;
            printint(out, (int32_t)(seed));
            print   (out, (uint8_t *)"\t", 1);
            printint(out, (int32_t)(best*1000000.0 + 0.5));
            print   (out, (uint8_t *)" us\n", 4);
            flush   (out);
        }
        seed++;
    }
}
void mainCRTStartup(void) | |
{ | |
ptrdiff_t cap = (ptrdiff_t)1<<24; | |
arena scratch = {0}; | |
scratch.beg = VirtualAlloc(0, cap, 0x3000, 4); | |
scratch.end = scratch.beg + cap; | |
memset(scratch.beg, -1, cap); // realize entire arena | |
//findslowest(100000, scratch); | |
int32_t len = 100000; | |
int32_t *nums = new(&scratch, int32_t, len); | |
int32_t target = generate(nums, len, 384); | |
uint64_t freq; | |
QueryPerformanceFrequency(&freq); | |
bufout *stdout = new(&scratch, bufout, 1); | |
for (int32_t i = 0; i < 2; i++) { | |
uint64_t start, stop; | |
QueryPerformanceCounter(&start); | |
result r; | |
switch (i) { | |
case 0: r = twoSum2(target, nums, len, scratch); break; | |
case 1: r = twoSum1(target, nums, len); break; | |
} | |
QueryPerformanceCounter(&stop); | |
double t = (double)(stop - start)/(double)freq; | |
printint(stdout, r.x); | |
print (stdout, (uint8_t *)" ", 1); | |
printint(stdout, r.y); | |
print (stdout, (uint8_t *)" [", 2); | |
printint(stdout, (int32_t)(t*1000.0 + 0.5)); | |
print (stdout, (uint8_t *)" ms]\n", 5); | |
flush (stdout); | |
} | |
ExitProcess(stdout->err); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment