Skip to content

Instantly share code, notes, and snippets.

@lithdew
Last active February 6, 2021 17:03
Show Gist options
  • Save lithdew/1b93e41f6a2c6830f42aa4623c04b74e to your computer and use it in GitHub Desktop.
Save lithdew/1b93e41f6a2c6830f42aa4623c04b74e to your computer and use it in GitHub Desktop.
Zig: open-addressing Robin Hood hash map with backward-shift deletion
// zig run benchmark.zig -lc -O ReleaseFast
const std = @import("std");
usingnamespace @import("hashmap.zig");
/// Benchmark driver: seeds the map with one batch of keys, then repeatedly
/// inserts the next batch while removing the previous one, so the live key
/// range slides forward by `batch` on every round.
pub fn main() !void {
    const allocator = std.heap.c_allocator;

    const Map = HashMap(
        u64,
        u64,
        comptime std.hash_map.getAutoHashFn(u64),
        comptime std.hash_map.getAutoEqlFn(u64),
    );

    var map = try Map.initCapacity(allocator, 16);
    defer map.deinit(allocator);

    const batch = 1_000_000;
    const rounds = 10;

    // Seed the map with keys [0, batch).
    var next_key: usize = 0;
    while (next_key < batch) : (next_key += 1) {
        try map.put(allocator, next_key, next_key);
    }

    var round: usize = 0;
    while (round < rounds) : (round += 1) {
        const window_end = next_key + batch;

        // Insert the fresh batch [next_key, window_end).
        var insert_key: usize = next_key;
        while (insert_key < window_end) : (insert_key += 1) {
            try map.put(allocator, insert_key, insert_key);
        }

        // Remove the previous batch [next_key - batch, next_key); every key
        // in it was inserted earlier, so a miss is a bug in the map.
        var remove_key: usize = next_key - batch;
        while (remove_key < next_key) : (remove_key += 1) {
            if (map.remove(allocator, remove_key) == null) unreachable;
        }

        next_key = window_end;
    }
}
const std = @import("std");
const mem = std.mem;
const math = std.math;
const testing = std.testing;
/// Builds an open-addressing hash map type that uses Robin Hood probing on
/// insertion and backward-shift deletion on removal.
///
/// `hashFn` supplies a 64-bit hash of which only the low 48 bits are stored
/// and compared; `eqlFn` decides key equality. The map does not store an
/// allocator: every method that may allocate or free takes one explicitly,
/// and it is used to free or reallocate `entries`, so it has to be the
/// allocator that produced them.
pub fn HashMap(
    comptime K: type,
    comptime V: type,
    comptime hashFn: fn (K) u64,
    comptime eqlFn: fn (K, K) bool,
) type {
    return struct {
        const Self = @This();

        /// Visits every occupied slot of a map in slot order.
        pub const Iterator = struct {
            map: *const Self,
            index: usize = 0,

            /// Returns a copy of the next occupied entry, or null once every
            /// slot has been visited.
            pub fn next(self: *Iterator) ?Entry {
                while (self.index < self.map.entries.len) {
                    // Runs on both the `return` below and the fall-through
                    // path, so the cursor always advances past this slot.
                    defer self.index += 1;
                    if (self.map.entries[self.index].fingerprint.dib != 0) {
                        return self.map.entries[self.index];
                    }
                }
                return null;
            }
        };

        /// Slot array. Its length is a power of two so that `& (len - 1)` can
        /// be used as the modulo when probing. A slot with `dib == 0` is empty.
        entries: []Entry,
        /// Number of key/value pairs currently stored.
        len: usize,
        /// Floor on the slot count: `remove` never shrinks below it and
        /// `clear(..., false)` shrinks the slot array back down to it.
        cap: usize,
        /// `put` doubles the slot array when `len` equals this (75% of slots).
        grow_at: usize,
        /// `remove` halves the slot array when `len` drops to this (10% of
        /// slots) and the slot array is larger than `cap`.
        shrink_at: usize,

        /// Per-slot metadata, packed into 64 bits.
        pub const Fingerprint = packed struct {
            /// Low 48 bits of the key's hash.
            hash: u48 = 0,
            /// One-based probe distance from the key's home slot; 0 marks an
            /// empty slot.
            dib: u16 = 0,
        };

        /// One slot of the table.
        pub const Entry = struct {
            fingerprint: Fingerprint = .{},
            key: K = mem.zeroes(K),
            value: V = mem.zeroes(V),
        };

        /// Returns the low 48 bits of `hashFn(key)`, the part of the hash
        /// that is stored in a fingerprint and used to pick the home slot.
        fn fingerprintHash(key: K) u48 {
            return @truncate(u48, hashFn(key));
        }

        /// Creates an empty map whose slot count is `capacity` rounded up to
        /// a power of two. `capacity` must be non-zero. That slot count also
        /// becomes the floor (`cap`) the map will not shrink below.
        pub fn initCapacity(allocator: *mem.Allocator, capacity: usize) !Self {
            const num_entries = math.ceilPowerOfTwoAssert(usize, capacity);

            const entries = try allocator.alloc(Entry, num_entries);
            errdefer allocator.free(entries);

            // All-zero bytes means `dib == 0` everywhere, i.e. every slot empty.
            mem.set(u8, mem.sliceAsBytes(entries), 0);

            return Self{
                .entries = entries,
                .len = 0,
                .cap = num_entries,
                .grow_at = num_entries * 75 / 100,
                .shrink_at = num_entries * 10 / 100,
            };
        }

        /// Releases the slot array.
        pub fn deinit(self: *const Self, allocator: *mem.Allocator) void {
            allocator.free(self.entries);
        }

        /// Removes every entry.
        ///
        /// With `update_cap == true` the current slot count is kept and
        /// becomes the new floor. Otherwise the slot array is shrunk back to
        /// the existing floor on a best-effort basis: if the allocator
        /// refuses, the larger array is kept as it is.
        pub fn clear(self: *Self, allocator: *mem.Allocator, comptime update_cap: bool) void {
            self.len = 0;

            if (update_cap) {
                self.cap = self.entries.len;
            } else if (self.entries.len != self.cap) {
                const maybe_new_entries = allocator.realloc(self.entries, self.cap) catch null;
                if (maybe_new_entries) |new_entries| {
                    self.entries = new_entries;
                }
                // On failure `self.entries` is left alone on purpose: its
                // length must keep describing the real allocation, otherwise
                // a later free/realloc would be handed a wrong size.
            }

            mem.set(u8, mem.sliceAsBytes(self.entries), 0);
            self.grow_at = self.entries.len * 75 / 100;
            self.shrink_at = self.entries.len * 10 / 100;
        }

        /// Rehashes every entry into a freshly allocated slot array of
        /// `new_cap` slots (rounded up to a power of two) and frees the old
        /// one. `len` and the floor `cap` are left unchanged. On allocation
        /// failure the map is untouched.
        pub fn resize(self: *Self, allocator: *mem.Allocator, new_cap: usize) !void {
            const map = try Self.initCapacity(allocator, new_cap);

            for (self.entries) |*entry| {
                if (entry.fingerprint.dib == 0) {
                    continue;
                }

                // Restart the probe distance for the new table. The old slot
                // doubles as scratch space for whichever entry is in flight.
                entry.fingerprint.dib = 1;

                var j: usize = entry.fingerprint.hash & (map.entries.len - 1);
                while (true) {
                    const bucket = &map.entries[j];
                    if (bucket.fingerprint.dib == 0) {
                        bucket.* = entry.*;
                        break;
                    }
                    // Robin Hood: the entry that has travelled farther takes
                    // the slot and the resident continues probing instead.
                    if (bucket.fingerprint.dib < entry.fingerprint.dib) {
                        mem.swap(Entry, bucket, entry);
                    }
                    j = (j +% 1) & (map.entries.len - 1);
                    entry.fingerprint.dib +%= 1;
                }
            }

            allocator.free(self.entries);
            self.entries = map.entries;
            self.grow_at = map.grow_at;
            self.shrink_at = map.shrink_at;
        }

        /// Returns the value stored under `key`, or null if it is absent.
        pub fn get(self: *const Self, key: K) ?V {
            const hash = fingerprintHash(key);

            var i: usize = hash & (self.entries.len - 1);
            while (true) {
                const bucket = &self.entries[i];
                // An empty slot ends the probe sequence.
                if (bucket.fingerprint.dib == 0) {
                    return null;
                }
                if (hash == bucket.fingerprint.hash and eqlFn(key, bucket.key)) {
                    return bucket.value;
                }
                i = (i +% 1) & (self.entries.len - 1);
            }
        }

        /// Inserts `key` with `value`, or overwrites the value if `key` is
        /// already present. May double the slot array first, which is the
        /// only way this can fail.
        pub fn put(self: *Self, allocator: *mem.Allocator, key: K, value: V) !void {
            if (self.len == self.grow_at) {
                try self.resize(allocator, self.entries.len * 2);
            }

            // The entry in flight; after a swap it holds a displaced resident.
            var entry: Entry = .{
                .fingerprint = .{
                    .hash = fingerprintHash(key),
                    .dib = 1,
                },
                .key = key,
                .value = value,
            };

            var i: usize = entry.fingerprint.hash & (self.entries.len - 1);
            while (true) {
                const bucket = &self.entries[i];
                if (bucket.fingerprint.dib == 0) {
                    bucket.* = entry;
                    self.len += 1;
                    return;
                }
                if (entry.fingerprint.hash == bucket.fingerprint.hash and eqlFn(key, bucket.key)) {
                    bucket.value = value;
                    return;
                }
                // Robin Hood: steal the slot from a resident that sits closer
                // to its home than the entry in flight does.
                if (bucket.fingerprint.dib < entry.fingerprint.dib) {
                    mem.swap(Entry, bucket, &entry);
                }
                i = (i +% 1) & (self.entries.len - 1);
                entry.fingerprint.dib +%= 1;
            }
        }

        /// Removes `key` and returns its value, or returns null if the key is
        /// absent. Afterwards the slot array may be halved (never below the
        /// floor `cap`); a failure to shrink is ignored because the map stays
        /// valid at its current size.
        pub fn remove(self: *Self, allocator: *mem.Allocator, key: K) ?V {
            const hash = fingerprintHash(key);

            var i: usize = hash & (self.entries.len - 1);
            while (true) {
                var bucket = &self.entries[i];
                if (bucket.fingerprint.dib == 0) {
                    return null;
                }

                if (hash == bucket.fingerprint.hash and eqlFn(key, bucket.key)) {
                    const value = bucket.value;
                    bucket.fingerprint.dib = 0;

                    // Backward shift: pull each following entry one slot
                    // closer to its home until reaching an empty slot or an
                    // entry that is already in its home slot (dib == 1).
                    while (true) {
                        var prev = bucket;
                        i = (i +% 1) & (self.entries.len - 1);
                        bucket = &self.entries[i];
                        if (bucket.fingerprint.dib <= 1) {
                            prev.fingerprint.dib = 0;
                            break;
                        }
                        prev.* = bucket.*;
                        prev.fingerprint.dib -%= 1;
                    }

                    self.len -= 1;

                    if (self.entries.len > self.cap and self.len <= self.shrink_at) {
                        // Best effort: shrinking is only an optimisation.
                        self.resize(allocator, self.entries.len / 2) catch {};
                    }

                    return value;
                }

                i = (i +% 1) & (self.entries.len - 1);
            }
        }

        /// Returns an iterator over the occupied slots. It holds a pointer to
        /// the map and reads `entries` on every call to `next`.
        pub fn iterator(self: *const Self) Iterator {
            return Iterator{ .map = self };
        }
    };
}
// Exercises put, iteration, get and remove on a small u64 -> u64 map.
test "HashMap: sanity check" {
    const allocator = testing.allocator;

    var map = try HashMap(
        u64,
        u64,
        comptime std.hash_map.getAutoHashFn(u64),
        comptime std.hash_map.getAutoEqlFn(u64),
    ).initCapacity(allocator, 16);
    defer map.deinit(allocator);

    try map.put(allocator, 0, 0);
    try map.put(allocator, 1, 1);
    try map.put(allocator, 2, 2);
    try map.put(allocator, 3, 3);

    // Expected slot order of the four keys; it depends on the auto hash
    // function and on the 16-slot table size.
    const order = [_]u64{ 3, 1, 2, 0 };

    var it = map.iterator();
    var i: usize = 0;
    while (it.next()) |entry| : (i += 1) {
        testing.expectEqual(order[i], entry.key);
        testing.expectEqual(order[i], entry.value);
    }
    // Without this check an iterator that yields nothing would pass.
    testing.expectEqual(@as(usize, order.len), i);

    testing.expectEqual(@as(?u64, 0), map.get(0));
    testing.expectEqual(@as(?u64, 1), map.get(1));
    testing.expectEqual(@as(?u64, 2), map.get(2));
    testing.expectEqual(@as(?u64, 3), map.get(3));
    testing.expectEqual(@as(?u64, null), map.get(4));

    testing.expectEqual(@as(?u64, 0), map.remove(allocator, 0));
    testing.expectEqual(@as(?u64, 1), map.remove(allocator, 1));
    testing.expectEqual(@as(?u64, 2), map.remove(allocator, 2));
    testing.expectEqual(@as(?u64, 3), map.remove(allocator, 3));
    testing.expectEqual(@as(?u64, null), map.remove(allocator, 4));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment