Last active
January 4, 2024 07:54
-
-
Save travisstaloch/8bb45b3b6a9502f7f83de466974755af to your computer and use it in GitHub Desktop.
Nice, generic n-dimensional matrix code.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Generates a Zig `testMul` test case (ready to paste into matrix.zig) from
# random numpy inputs. The fixed seed keeps the generated data reproducible.
import numpy as np

np.random.seed(42)

# Hand-written alternatives used while developing (3d, (2, 2, 3) @ (2, 3, 2)):
#   A = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
#   B = np.array([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
# or constant fills:
#   A = np.empty((2, 3, 2)); A.fill(1)
#   B = np.empty((2, 2, 3)); B.fill(2)

# Both operands share the leading (batch) dimensions (2, 2, 2); matmul
# multiplies their trailing 3x2 and 2x3 matrices pairwise.
A = np.random.randint(0, 10, size=(2, 2, 2) + (3, 2))
B = np.random.randint(0, 10, size=(2, 2, 2) + (2, 3))
C = A @ B
# Debug output:
# print("A:\n{}, shape={}\nB:\n{}, shape={}".format(A, A.shape, B, B.shape))
# print("Product C:\n{}, shape={}".format(C, C.shape))
import io


def print_to_string(*args, **kwargs):
    """Return exactly the text ``print(*args, **kwargs)`` would have written.

    ``file`` must not be passed in ``kwargs``; output goes to an in-memory buffer.
    """
    with io.StringIO() as buffer:
        print(*args, file=buffer, **kwargs)
        return buffer.getvalue()
def mls(mat):
    r"""Render ``mat`` as the body of a Zig multiline string literal.

    numpy's printout is split into lines, and every line after the first is
    prefixed with ``\\`` (the caller writes the first ``\\`` itself).

    ``threshold=mat.size`` stops numpy from summarizing large arrays with
    ``...``, which the Zig parser cannot read. The printout carries no trailing
    newline, so no empty ``\\`` line is appended to the literal.
    """
    mat = np.asarray(mat)
    text = np.array2string(mat, threshold=mat.size)
    return '\n\\\\'.join(text.split('\n'))
def uwp(tup):
    """Format a shape tuple as a bare comma-separated list, e.g. ``(2, 3)`` -> ``"2, 3"``.

    The result goes inside a Zig ``&.{ ... }`` literal. Joining the elements
    directly, instead of printing the tuple and stripping the parentheses,
    leaves no trailing newline and no dangling comma for 1-tuples such as ``(3,)``.
    """
    return ", ".join(str(dim) for dim in tup)
# Assemble the `testMul` call. Each operand is a Zig shape literal followed by
# a `\\` multiline string holding numpy's printout of that array; the leading
# and trailing empty entries give the same surrounding newlines as before.
s = "\n".join([
    "",
    "try testMul(u32, &.{ " + uwp(A.shape) + " },",
    "\\\\" + mls(A),
    ", &.{ " + uwp(B.shape) + " }",
    ",",
    "\\\\" + mls(B),
    ",",
    "&.{ " + uwp(C.shape) + " },",
    "\\\\" + mls(C),
    ");",
    "",
])
print(s)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! This lib only does multiplication so far, but it can parse numpy output at
//! runtime and at comptime and has lots of nice helpers, including asArray(),
//! which allows multi-indexing, i.e. `mat.asArray()[i][j]`. There are lots of
//! working tests for 1d, 2d, 3d, 4d and 5d matrices with different shapes.
const std = @import("std"); | |
const Allocator = std.mem.Allocator; | |
/// Returns a matrix type with element type `T` and comptime dimensions
/// `shape` (outermost first), e.g. `Matrix(f32, &.{ 2, 3 })` for a 2x3 matrix.
///
/// Elements live contiguously in the flat slice `items`. `asArray()`
/// reinterprets that storage as the nested array type `Array`
/// (`[shape[0]][shape[1]]...T`), so it can be indexed as `m.asArray()[i][j]`.
pub fn Matrix(comptime T: type, comptime shape: []const usize) type {
    return struct {
        const Self = @This();

        /// Flat element storage; exactly `len` elements.
        items: []T,
        /// The dimensions. They are exposed on values so `mul` can check the
        /// shapes of its `anytype` operands at comptime.
        comptime shape: []const usize = shape,

        /// Total number of elements: the product of all dimensions.
        pub const len = @reduce(.Mul, @as(
            @Vector(shape.len, usize),
            shape[0..shape.len].*,
        ));

        /// Nested array type spanning the whole matrix.
        pub const Array = ArrayFromDim(0);

        /// Allocates storage for `len` elements, left uninitialized.
        /// The caller releases it with `deinit`.
        pub fn init(allocator: Allocator) !Self {
            return .{ .items = try allocator.alloc(T, len) };
        }

        /// Allocates a matrix with every element set to `value`.
        /// The caller releases it with `deinit`.
        pub fn initFilled(allocator: Allocator, value: T) !Self {
            // `fill` writes through the `items` slice. The handle itself is
            // never reassigned, so it is `const`.
            const result = try init(allocator);
            result.fill(value);
            return result;
        }

        /// Allocates a matrix and copies `array` into it.
        /// The caller releases it with `deinit`.
        pub fn initArray(allocator: Allocator, array: Array) !Self {
            const result = try init(allocator);
            @memcpy(result.items, @as([*]const T, @ptrCast(&array)));
            return result;
        }

        /// Wraps existing data without copying or allocating.
        /// NOTE: this casts away `const`. Writing through the result (for
        /// example using it as a `mul` destination) is undefined behavior if
        /// `items` points at read-only memory. Do not pass the result to
        /// `deinit`; it was not allocated.
        pub fn initConst(items: *const Array) Self {
            return .{
                .items = @as([*]T, @constCast(@ptrCast(items)))[0..len],
            };
        }

        /// Reports a parse failure. At comptime this is a compile error;
        /// otherwise it logs the message and panics.
        fn parseErr(comptime fmt: []const u8, args: anytype) noreturn {
            if (@inComptime())
                @compileError(std.fmt.comptimePrint(fmt, args))
            else {
                std.log.err(fmt, args);
                @panic("parse error");
            }
        }

        /// Parses numpy's printed form of an array (e.g. `[[1 2] [3 4]]`) at
        /// comptime. Malformed input is a compile error.
        pub inline fn parse(
            comptime input: []const u8,
        ) !Self {
            comptime {
                var items: [len]T = undefined;
                return parseBuf(input, &items);
            }
        }

        /// Parses numpy's printed form of an array into the caller-provided
        /// `items` buffer. The returned matrix aliases `items`, so do not
        /// `deinit` it.
        ///
        /// Numbers are separated by any whitespace, so rows that numpy wrapped
        /// onto several lines are accepted. Every number must sit at nesting
        /// depth `shape.len`, and exactly `len` numbers must be present.
        /// Structural problems are reported through `parseErr`. A malformed
        /// number returns the `std.fmt.parseInt` / `parseFloat` error.
        pub fn parseBuf(
            input: []const u8,
            items: *[len]T,
        ) !Self {
            const self = Self{ .items = items };
            var count: usize = 0; // numbers stored so far, always <= len
            var depth: usize = 0;
            var i: usize = 0;
            while (i < input.len) : (i += 1) {
                switch (input[i]) {
                    ' ', '\n', '\t', '\r' => {},
                    '[' => depth += 1,
                    ']' => {
                        // Guard against underflow on unbalanced input.
                        if (depth == 0)
                            parseErr("unexpected closing bracket at position {}", .{i});
                        depth -= 1;
                    },
                    '0'...'9', '-', '+' => {
                        if (depth != shape.len) parseErr(
                            "number at position {} is nested {} deep, expected {}",
                            .{ i, depth, shape.len },
                        );
                        // Check before writing so malformed input can never
                        // write past the end of `items`.
                        if (count == len)
                            parseErr("too many items at position {}, expected {}", .{ i, len });
                        // A number runs until whitespace or a closing bracket.
                        var end = i + 1;
                        while (end < input.len) : (end += 1) {
                            switch (input[end]) {
                                ' ', '\n', '\t', '\r', ']' => break,
                                else => {},
                            }
                        }
                        const token = input[i..end];
                        self.items[count] = switch (@typeInfo(T)) {
                            .Int => try std.fmt.parseInt(T, token, 10),
                            .Float => try std.fmt.parseFloat(T, token),
                            else => @compileError("Matrix.parseBuf: unsupported element type " ++ @typeName(T)),
                        };
                        count += 1;
                        // Resume on the delimiter so the prongs above handle it.
                        i = end - 1;
                    },
                    else => parseErr(
                        "unexpected character '{c}' at position {}",
                        .{ input[i], i },
                    ),
                }
            }
            if (depth != 0) parseErr("eof and missing closing bracket", .{});
            if (count != len)
                parseErr("eof and either too many or not enough items", .{});
            return self;
        }

        /// Frees storage obtained from `init`, `initFilled` or `initArray`.
        pub fn deinit(self: Self, allocator: Allocator) void {
            if (@sizeOf(T) > 0) allocator.free(self.items);
        }

        /// Sets every element to `value`.
        pub fn fill(self: Self, value: T) void {
            @memset(self.items, value);
        }

        /// Views the flat storage as the nested array `Array`.
        pub inline fn asArray(self: Self) *Array {
            return @ptrCast(self.items.ptr);
        }

        /// Writes each outermost slice of the matrix on its own line (`{any}`).
        pub fn format(self: Self, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
            _ = fmt;
            _ = options;
            for (self.asArray()) |x| {
                try writer.print("{any}\n", .{x});
            }
        }

        /// Debug-prints a `--message:AxBxC--` header followed by the matrix.
        pub fn dump(self: Self, message: []const u8) void {
            std.debug.print("--{s}:", .{message});
            for (shape, 0..) |s, i| {
                std.debug.print("{s}{}", .{ if (i != 0) "x" else "", s });
            }
            std.debug.print("--\n", .{});
            std.debug.print("{}", .{self});
        }

        /// Nested array type for the trailing dimensions `shape[dim..]`.
        /// For shape `{ 2, 3, 4 }`: dim 0 -> `[2][3][4]T`, dim 1 -> `[3][4]T`,
        /// dim 2 -> `[4]T`. Built recursively, so any rank is supported.
        pub fn ArrayFromDim(comptime dim: usize) type {
            if (dim >= shape.len)
                @compileError(std.fmt.comptimePrint(
                    "ArrayFromDim: dim {} is out of range for a {}-dimensional matrix",
                    .{ dim, shape.len },
                ));
            return if (dim + 1 == shape.len)
                [shape[dim]]T
            else
                [shape[dim]]ArrayFromDim(dim + 1);
        }

        /// Views `ptr[0..M.len]` as a matrix of the trailing dimensions
        /// `shape[dim..]` without copying. `ptr` must point at that many elements.
        pub fn subMatrix(ptr: [*]T, comptime dim: usize) Matrix(T, shape[dim..]) {
            const M = Matrix(T, shape[dim..]);
            return M{ .items = ptr[0..M.len] };
        }

        /// `dst = a * b`, dispatched on this matrix's rank:
        /// - 1d: dot product of `a` with the single row of a 1xN `b`, stored
        ///   in `dst.items[0]`;
        /// - 2d: ordinary matrix product (`a` is MxK, `b` is KxN, `dst` is MxN);
        /// - 3d and up: the lower-rank product applied to each index of the
        ///   leading dimension, which all three operands must share.
        /// Shape mismatches are compile errors. `dst` is overwritten
        /// row by row, so it must not alias `a` or `b`.
        pub const mul = switch (shape.len) {
            1 => mul1d,
            2 => mul2d,
            else => mulNd,
        };

        fn mul1d(a: Self, b: anytype, dst: anytype) void {
            comptime std.debug.assert(a.shape.len == 1);
            comptime std.debug.assert(b.shape.len == 2);
            comptime std.debug.assert(a.shape[0] == b.shape[1]);
            comptime std.debug.assert(b.shape[0] == 1);
            dst.items[0] = 0;
            for (0..a.shape[0]) |i| {
                dst.items[0] += a.items[i] * b.asArray()[0][i];
            }
        }

        fn mul2d(a: Self, b: anytype, dst: anytype) void {
            comptime std.debug.assert(a.shape.len == 2);
            comptime std.debug.assert(b.shape.len == 2);
            comptime std.debug.assert(a.shape[1] == b.shape[0]);
            // dst must be exactly (rows of a) x (columns of b).
            comptime std.debug.assert(dst.shape.len == 2);
            comptime std.debug.assert(dst.shape[0] == a.shape[0]);
            comptime std.debug.assert(dst.shape[1] == b.shape[1]);
            for (0..shape[0]) |i| {
                dst.asArray()[i] = [1]T{0} ** b.shape[1];
                for (0..b.shape[1]) |j| {
                    for (0..shape[1]) |k| {
                        dst.asArray()[i][j] +=
                            a.asArray()[i][k] * b.asArray()[k][j];
                    }
                }
            }
        }

        fn mulNd(a: Self, b: anytype, dst: anytype) void {
            comptime std.debug.assert(a.shape.len >= 3);
            // Every operand is indexed by the same leading (batch) index, so
            // ranks and leading dimensions must agree, or indexing goes out of bounds.
            comptime std.debug.assert(b.shape.len == a.shape.len and dst.shape.len == a.shape.len);
            comptime std.debug.assert(b.shape[0] == a.shape[0] and dst.shape[0] == a.shape[0]);
            for (0..dst.shape[0]) |i| {
                const asubm = subMatrix(@ptrCast(&a.asArray()[i]), 1);
                const bsubm = @TypeOf(b).subMatrix(@ptrCast(&b.asArray()[i]), 1);
                const dsubm = @TypeOf(dst).subMatrix(@ptrCast(&dst.asArray()[i]), 1);
                switch (shape.len) {
                    3 => asubm.mul2d(bsubm, dsubm),
                    else => asubm.mulNd(bsubm, dsubm),
                }
            }
        }
    };
}
/// 1d multiply: `[1 2]` dotted with the single row `[3 4]` gives 1*3 + 2*4 = 11.
fn testMul1d(comptime T: type) !void {
    const gpa = std.testing.allocator;
    const lhs = Matrix(T, &.{2}).initConst(&.{ 1, 2 });
    const rhs = Matrix(T, &.{ 1, 2 }).initConst(&.{.{ 3, 4 }});
    const Out = Matrix(T, &.{1});
    const out = try Out.init(gpa);
    defer out.deinit(gpa);

    lhs.mul(rhs, out);

    const want = Out.initConst(&.{11});
    try std.testing.expectEqualSlices(T, want.items, out.items);
}
test "1d mul" {
    // Exercise unsigned, signed and floating-point element types.
    inline for (.{ u8, i8, f32 }) |T| try testMul1d(T);
}
/// 2d multiply: a 3x2 matrix of ones times a 2x3 matrix of twos gives a 3x3
/// matrix of fours (each cell is 1*2 + 1*2).
fn testMul2d(comptime T: type) !void {
    const gpa = std.testing.allocator;

    const ones = try Matrix(T, &.{ 3, 2 }).init(gpa);
    defer ones.deinit(gpa);
    ones.fill(1);

    const twos = Matrix(T, &.{ 2, 3 }).initConst(&.{
        .{ 2, 2, 2 },
        .{ 2, 2, 2 },
    });

    const Product = Matrix(T, &.{ 3, 3 });
    const product = try Product.init(gpa);
    defer product.deinit(gpa);

    ones.mul(twos, product);

    const want = Product.initConst(&.{
        .{ 4, 4, 4 },
        .{ 4, 4, 4 },
        .{ 4, 4, 4 },
    });
    try std.testing.expectEqualSlices(T, want.items, product.items);
}
test "2d mul" {
    inline for (.{ u8, i8, f32 }) |T| try testMul2d(T);
}
/// Batched 3d multiply: two independent 2x3 @ 3x2 products.
fn testMul3d(comptime T: type) !void {
    const gpa = std.testing.allocator;

    const lhs = try Matrix(T, &.{ 2, 2, 3 }).initArray(gpa, .{
        .{ .{ 1, 2, 3 }, .{ 4, 5, 6 } },
        .{ .{ 7, 8, 9 }, .{ 10, 11, 12 } },
    });
    defer lhs.deinit(gpa);

    const rhs = try Matrix(T, &.{ 2, 3, 2 }).initArray(gpa, .{
        .{ .{ 1, 2 }, .{ 3, 4 }, .{ 5, 6 } },
        .{ .{ 7, 8 }, .{ 9, 10 }, .{ 11, 12 } },
    });
    defer rhs.deinit(gpa);

    const Batch = Matrix(T, &.{ 2, 2, 2 });
    const product = try Batch.init(gpa);
    defer product.deinit(gpa);

    lhs.mul(rhs, product);

    const want = Batch.initConst(&.{
        .{ .{ 22, 28 }, .{ 49, 64 } },
        .{ .{ 220, 244 }, .{ 301, 334 } },
    });
    try std.testing.expectEqualSlices(T, want.items, product.items);
}
test "3d mul" {
    // 16-bit types: the largest product element (334) does not fit in 8 bits.
    inline for (.{ u16, i16, f32 }) |T| try testMul3d(T);
}
/// Parses 1d, 2d and 3d numpy printouts both at comptime (`parse`) and at
/// runtime into a caller buffer (`parseBuf`), checking the flat items each time.
fn testParse(comptime T: type) !void {
    @setEvalBranchQuota(2000);
    const expectItems = std.testing.expectEqualSlices;

    // 1d
    const Vec = Matrix(T, &.{2});
    const vec_text = "[1 2]";
    const vec = try Vec.parse(vec_text);
    try expectItems(T, &.{ 1, 2 }, vec.items);
    var vec_buf: [2]T = undefined;
    const vec_rt = try Vec.parseBuf(vec_text, &vec_buf);
    try expectItems(T, &.{ 1, 2 }, vec_rt.items);

    // 2d
    const Mat = Matrix(T, &.{ 2, 2 });
    const mat_text = "[[1 2] [3 4]]";
    const mat = try Mat.parse(mat_text);
    try expectItems(T, &.{ 1, 2, 3, 4 }, mat.items);
    var mat_buf: [4]T = undefined;
    const mat_rt = try Mat.parseBuf(mat_text, &mat_buf);
    try expectItems(T, &.{ 1, 2, 3, 4 }, mat_rt.items);

    // 3d, with a blank line and extra padding like numpy prints
    const Cube = Matrix(T, &.{ 2, 2, 2 });
    const cube_text =
        \\[[[1 2] [3 4]]
        \\
        \\ [[5 6] [ 7 8]]]
    ;
    const cube = try Cube.parse(cube_text);
    try expectItems(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, cube.items);
    var cube_buf: [8]T = undefined;
    const cube_rt = try Cube.parseBuf(cube_text, &cube_buf);
    try expectItems(T, &.{ 1, 2, 3, 4, 5, 6, 7, 8 }, cube_rt.items);
}
test "parse" {
    inline for (.{ u32, i32, f32 }) |T| try testParse(T);
}
/// Parses `a_in` and `b_in` (numpy printouts) at comptime, multiplies them
/// into a freshly allocated matrix, and compares the result with `c_in`.
fn testMul(
    comptime T: type,
    comptime a_shape: []const usize,
    comptime a_in: []const u8,
    comptime b_shape: []const usize,
    comptime b_in: []const u8,
    comptime c_shape: []const usize,
    comptime c_in: []const u8,
) !void {
    // Comptime parsing of the larger (4d/5d) inputs needs more than the
    // default branch quota.
    @setEvalBranchQuota(10_000);
    const gpa = std.testing.allocator;

    const lhs = try Matrix(T, a_shape).parse(a_in);
    const rhs = try Matrix(T, b_shape).parse(b_in);

    const Out = Matrix(T, c_shape);
    const got = try Out.init(gpa);
    defer got.deinit(gpa);

    lhs.mul(rhs, got);

    const want = try Out.parse(c_in);
    try std.testing.expectEqualSlices(T, want.items, got.items);
}
/// Runs the 3d, 4d and 5d multiplication cases with element type `T`.
/// Cases are written in numpy's print format. All input values are 0-9 and
/// all products are small non-negative integers, so every case is valid for
/// u32, i32 and f32. The 4d and 5d cases previously hard-coded `u32`, which
/// meant the i32/f32 runs never exercised them; they now use `T`.
fn testMuls(comptime T: type) !void {
    // 3d
    try testMul(T, &.{ 3, 3, 2 },
        \\[[[6 3]
        \\ [7 4]
        \\ [6 9]]
        \\
        \\ [[2 6]
        \\ [7 4]
        \\ [3 7]]
        \\
        \\ [[7 2]
        \\ [5 4]
        \\ [1 7]]]
    , &.{ 3, 2, 4 },
        \\[[[5 1 4 0]
        \\ [9 5 8 0]]
        \\
        \\ [[9 2 6 3]
        \\ [8 2 4 2]]
        \\
        \\ [[6 4 8 6]
        \\ [1 3 8 1]]]
    , &.{ 3, 3, 4 },
        \\[[[ 57 21 48 0]
        \\ [ 71 27 60 0]
        \\ [111 51 96 0]]
        \\
        \\ [[ 66 16 36 18]
        \\ [ 95 22 58 29]
        \\ [ 83 20 46 23]]
        \\
        \\ [[ 44 34 72 44]
        \\ [ 34 32 72 34]
        \\ [ 13 25 64 13]]]
    );
    try testMul(T, &.{ 1, 3, 2 },
        \\[[[6 3]
        \\ [7 4]
        \\ [6 9]]]
        \\
    , &.{ 1, 2, 4 },
        \\[[[2 6 7 4]
        \\ [3 7 7 2]]]
        \\
    , &.{ 1, 3, 4 },
        \\[[[ 21 57 63 30]
        \\ [ 26 70 77 36]
        \\ [ 39 99 105 42]]]
        \\
    );
    try testMul(T, &.{ 1, 3, 1 },
        \\[[[6]
        \\ [3]
        \\ [7]]]
        \\
    , &.{ 1, 1, 3 },
        \\[[[4 6 9]]]
        \\
    , &.{ 1, 3, 3 },
        \\[[[24 36 54]
        \\ [12 18 27]
        \\ [28 42 63]]]
        \\
    );
    // 4d
    try testMul(T, &.{ 1, 1, 3, 1 },
        \\[[[[6]
        \\ [3]
        \\ [7]]]]
        \\
    , &.{ 1, 1, 1, 3 },
        \\[[[[4 6 9]]]]
        \\
    , &.{ 1, 1, 3, 3 },
        \\[[[[24 36 54]
        \\ [12 18 27]
        \\ [28 42 63]]]]
        \\
    );
    try testMul(T, &.{ 2, 2, 3, 2 },
        \\[[[[6 3]
        \\ [7 4]
        \\ [6 9]]
        \\
        \\ [[2 6]
        \\ [7 4]
        \\ [3 7]]]
        \\
        \\
        \\ [[[7 2]
        \\ [5 4]
        \\ [1 7]]
        \\
        \\ [[5 1]
        \\ [4 0]
        \\ [9 5]]]]
        \\
    , &.{ 2, 2, 2, 3 },
        \\[[[[8 0 9]
        \\ [2 6 3]]
        \\
        \\ [[8 2 4]
        \\ [2 6 4]]]
        \\
        \\
        \\ [[[8 6 1]
        \\ [3 8 1]]
        \\
        \\ [[9 8 9]
        \\ [4 1 3]]]]
        \\
    , &.{ 2, 2, 3, 3 },
        \\[[[[ 54 18 63]
        \\ [ 64 24 75]
        \\ [ 66 54 81]]
        \\
        \\ [[ 28 40 32]
        \\ [ 64 38 44]
        \\ [ 38 48 40]]]
        \\
        \\
        \\ [[[ 62 58 9]
        \\ [ 52 62 9]
        \\ [ 29 62 8]]
        \\
        \\ [[ 49 41 48]
        \\ [ 36 32 36]
        \\ [101 77 96]]]]
        \\
    );
    // 5d
    try testMul(T, &.{ 2, 2, 2, 3, 2 },
        \\[[[[[6 3]
        \\ [7 4]
        \\ [6 9]]
        \\
        \\ [[2 6]
        \\ [7 4]
        \\ [3 7]]]
        \\
        \\
        \\ [[[7 2]
        \\ [5 4]
        \\ [1 7]]
        \\
        \\ [[5 1]
        \\ [4 0]
        \\ [9 5]]]]
        \\
        \\
        \\
        \\ [[[[8 0]
        \\ [9 2]
        \\ [6 3]]
        \\
        \\ [[8 2]
        \\ [4 2]
        \\ [6 4]]]
        \\
        \\
        \\ [[[8 6]
        \\ [1 3]
        \\ [8 1]]
        \\
        \\ [[9 8]
        \\ [9 4]
        \\ [1 3]]]]]
        \\
    , &.{ 2, 2, 2, 2, 3 },
        \\[[[[[6 7 2]
        \\ [0 3 1]]
        \\
        \\ [[7 3 1]
        \\ [5 5 9]]]
        \\
        \\
        \\ [[[3 5 1]
        \\ [9 1 9]]
        \\
        \\ [[3 7 6]
        \\ [8 7 4]]]]
        \\
        \\
        \\
        \\ [[[[1 4 7]
        \\ [9 8 8]]
        \\
        \\ [[0 8 6]
        \\ [8 7 0]]]
        \\
        \\
        \\ [[[7 7 2]
        \\ [0 7 2]]
        \\
        \\ [[2 0 4]
        \\ [9 6 9]]]]]
        \\
    , &.{ 2, 2, 2, 3, 3 },
        \\[[[[[ 36 51 15]
        \\ [ 42 61 18]
        \\ [ 36 69 21]]
        \\
        \\ [[ 44 36 56]
        \\ [ 69 41 43]
        \\ [ 56 44 66]]]
        \\
        \\
        \\ [[[ 39 37 25]
        \\ [ 51 29 41]
        \\ [ 66 12 64]]
        \\
        \\ [[ 23 42 34]
        \\ [ 12 28 24]
        \\ [ 67 98 74]]]]
        \\
        \\
        \\
        \\ [[[[ 8 32 56]
        \\ [ 27 52 79]
        \\ [ 33 48 66]]
        \\
        \\ [[ 16 78 48]
        \\ [ 16 46 24]
        \\ [ 32 76 36]]]
        \\
        \\
        \\ [[[ 56 98 28]
        \\ [ 7 28 8]
        \\ [ 56 63 18]]
        \\
        \\ [[ 90 48 108]
        \\ [ 54 24 72]
        \\ [ 29 18 31]]]]]
        \\
    );
}
test {
    // Run every numpy-generated multiplication case for each element type.
    inline for (.{ u32, i32, f32 }) |T| try testMuls(T);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment