Last active
January 24, 2024 10:35
-
-
Save Validark/40d2df74b87692fe135bbeac14eed50d to your computer and use it in GitHub Desktop.
comptime pext zig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
const builtin = @import("builtin"); | |
/// Comptime flag: true when the compile target is an x86_64 CPU that advertises
/// BMI2 and is not one of the AMD models excluded below, i.e. when the `pdep` /
/// `pext` instructions are expected to be fast.
///
/// The CPU name is taken from `llvm_name` when present, otherwise from `name`.
/// Models whose name starts with "bdver", and "znver" models whose generation
/// character is below '3', are treated as slow.
const HAS_FAST_PDEP_AND_PEXT = blk: {
    if (builtin.cpu.arch != .x86_64) break :blk false;
    if (!std.Target.x86.featureSetHas(builtin.cpu.features, .bmi2)) break :blk false;

    const cpu_name = builtin.cpu.model.llvm_name orelse builtin.cpu.model.name;

    // pdep is microcoded (slow) on AMD architectures before Zen 3.
    if (std.mem.startsWith(u8, cpu_name, "bdver")) break :blk false;

    const zen_prefix = "znver";
    if (!std.mem.startsWith(u8, cpu_name, zen_prefix)) break :blk true;

    // Check the length before indexing so a name that is exactly "znver" is
    // classified as slow rather than triggering an out-of-bounds index.
    break :blk cpu_name.len > zen_prefix.len and cpu_name[zen_prefix.len] >= '3';
};
/// Parallel bit extract with a comptime-known mask: gathers the bits of `src`
/// selected by `mask` and packs them contiguously into the low bits of the result.
///
/// Parameters:
///   - `src`: the integer to extract bits from (presumably an unsigned integer
///     type — the widening fallback uses `@as(u32/u64, src)`; confirm for callers).
///   - `mask`: comptime mask of the same type as `src`; each set bit selects one bit.
///   - `use_vector_impl`: when true, masks with four or more groups of ones that
///     cannot use the multiply trick are handled with a vector compare instead of
///     the per-group shift/mask loop.
///
/// Returns an unsigned integer that is exactly `@popCount(mask)` bits wide.
///
/// Strategy, selected at comptime:
///   1. `mask == 0` returns 0.
///   2. Not evaluating at comptime, at least 3 groups of ones, at most 64 bits,
///      and `HAS_FAST_PDEP_AND_PEXT`: emit the x86 `pext` instruction (types other
///      than u32/u64 are widened first).
///   3. At least 4 groups of ones: try a single multiply + shift; failing that,
///      optionally the vector implementation.
///   4. Otherwise: one shift-and-mask per group of ones.
fn pext(src: anytype, comptime mask: @TypeOf(src), comptime use_vector_impl: bool) std.meta.Int(.unsigned, @popCount(mask)) {
    if (mask == 0) return 0;

    // A set bit starts a new group of ones when the bit directly below it is clear.
    const num_one_groups = @popCount(mask & ~(mask << 1));

    if (!@inComptime() and comptime num_one_groups >= 3 and @bitSizeOf(@TypeOf(src)) <= 64 and HAS_FAST_PDEP_AND_PEXT) {
        return switch (@TypeOf(src)) {
            u64, u32 => @intCast(asm ("pext %[mask], %[src], %[ret]"
                : [ret] "=r" (-> @TypeOf(src)),
                : [src] "r" (src),
                  [mask] "r" (mask),
            )),
            // Widen to the next register-sized type and retry. All three arguments
            // must be forwarded; omitting `use_vector_impl` is a compile error.
            else => @intCast(pext(@as(if (@bitSizeOf(@TypeOf(src)) <= 32) u32 else u64, src), mask, use_vector_impl)),
        };
    } else if (num_one_groups >= 4) {
        blk: {
            // Attempt to produce a `global_shift` value such that
            // the return statement at the end of this block moves the desired bits into the least significant
            // bit position.
            comptime var global_shift: @TypeOf(src) = 0;
            comptime {
                // `x` holds the mask bits that have not been placed yet.
                var x = mask;
                // `target` is the top `@popCount(mask)` bits: where the multiply must land the selected bits.
                var target = @as(@TypeOf(src), 1) << (@bitSizeOf(@TypeOf(src)) - 1);
                for (0..@popCount(x) - 1) |_| target |= target >> 1;
                // The maximum sum of the garbage data. If this overflows into the target bits,
                // we can't use the global_shift.
                var left_overs: @TypeOf(src) = 0;
                // Number of target bits (counted from the most significant end) already filled.
                var cur_pos: @TypeOf(src) = 0;
                while (true) {
                    // Left shift that moves the highest remaining group up against the bits already placed.
                    const shift = (@clz(x) - cur_pos);
                    global_shift |= @as(@TypeOf(src), 1) << shift;
                    var shifted_mask = x << shift;
                    // Advance `cur_pos` past the group that was just placed.
                    cur_pos = @clz(shifted_mask);
                    cur_pos += @clz(~(shifted_mask << cur_pos));
                    // Drop the placed group; what remains is garbage produced by this partial product.
                    shifted_mask = shifted_mask << cur_pos >> cur_pos;
                    left_overs += shifted_mask;
                    // Give up on the multiply trick if garbage could reach the target bits.
                    if ((target & left_overs) != 0) break :blk;
                    if ((shifted_mask & target) != 0) break :blk;
                    // Move the remaining mask bits back to their original positions.
                    x = shifted_mask >> shift;
                    if (x == 0) break;
                }
            }
            // Multiply to gather the selected bits at the top, then shift them down to bit 0.
            return @intCast(((src & mask) *% global_shift) >> (@bitSizeOf(@TypeOf(src)) - @popCount(mask)));
        }
        // TODO: add heuristics for when this is probably the best option.
        // Most probably, when we can keep inside of the vector widths that the machine actually has
        if (use_vector_impl) {
            // Narrowest unsigned type that can still hold the highest set bit of `mask`.
            comptime var min_int = u0;
            // One single-bit value per set bit of `mask`, lowest bit first.
            const vec2 = comptime relevant_masks: {
                var relevant_indices: []const @TypeOf(src) = &[0]@TypeOf(src){};
                var x = mask;
                for (0..@popCount(mask)) |_| {
                    relevant_indices = relevant_indices ++ [1]@TypeOf(src){1 << @ctz(x)};
                    x &= x -% 1; // clear the lowest set bit
                }
                min_int = std.meta.Int(.unsigned, @ctz(relevant_indices[@popCount(mask) - 1]) + 1);
                break :relevant_masks relevant_indices[0..@popCount(mask)].*;
            };
            // NOTE(review): `vec` has element type `min_int` while `vec2` has element type
            // `@TypeOf(src)`; confirm these combine as intended when the two types differ.
            const vec = @as(@Vector(@popCount(mask), min_int), @splat(@truncate(src)));
            // Lane i is true when mask bit i is set in `src`; the bool vector is reinterpreted as the result.
            return @bitCast((vec & vec2) == vec2);
        }
    }
    {
        // Fallback: handle one group of consecutive ones per unrolled iteration.
        var ans: @TypeOf(src) = 0;
        // Number of result bits produced so far.
        comptime var cur_pos = 0;
        comptime var x = mask;
        inline while (x != 0) {
            // Bit position where the lowest remaining group starts, and its length.
            const mask_ctz = @ctz(x);
            const num_ones = @ctz(~(x >> mask_ctz));
            // Build a value with `num_ones` low bits set.
            comptime var ones = 1;
            inline for (0..num_ones) |_| ones <<= 1;
            ones -%= 1;
            // @compileLog(std.fmt.comptimePrint("ans |= (src >> {}) & 0b{b}", .{ mask_ctz - cur_pos, (ones << cur_pos) }));
            // Slide the group down to `cur_pos` and keep only its bits.
            ans |= (src >> (mask_ctz - cur_pos)) & (ones << cur_pos);
            cur_pos += num_ones;
            // Remove the group that was just handled from the remaining mask.
            inline for (0..num_ones) |_| x &= x - 1;
        }
        return @intCast(ans);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment