Created
August 16, 2023 07:25
-
-
Save travisstaloch/b377c953c3101249b30405afff4c067d to your computer and use it in GitHub Desktop.
Levenshtein distance implementation in Zig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Levenshtein (edit) distance implementations.
//! Reference: https://en.wikipedia.org/wiki/Levenshtein_distance
const std = @import("std"); | |
/// Naive recursive Levenshtein distance between `a` and `b`.
///
/// When one input is empty the result is the other input's length, truncated
/// to `u16`. Otherwise a matching first byte is skipped at no cost, and a
/// mismatch costs one edit plus the cheapest of the three sub-problems.
/// Every mismatch branches three ways, so keep inputs short.
pub fn levRec(a: []const u8, b: []const u8) u16 {
    if (a.len == 0) return @truncate(b.len);
    if (b.len == 0) return @truncate(a.len);

    const rest_a = a[1..];
    const rest_b = b[1..];

    // Equal leading bytes: no edit needed, recurse on both tails.
    if (a[0] == b[0]) return levRec(rest_a, rest_b);

    const drop_from_a = levRec(rest_a, b);
    const drop_from_b = levRec(a, rest_b);
    const replace_head = levRec(rest_a, rest_b);
    return @min(drop_from_a, drop_from_b, replace_head) + 1;
}
/// Maps the 2-D coordinate (`i`, `j`) to its offset in a flat, row-major
/// buffer that has `cols` entries per row.
inline fn idx(i: usize, j: usize, cols: usize) usize {
    const row_start = cols * i;
    return row_start + j;
}
/// Debug helper: prints a flat `(a.len + 1) x (b.len + 1)` distance table to
/// stderr, with the bytes of `b` as a header line and the bytes of `a` as row
/// labels.
fn dumpTable(a: []const u8, b: []const u8, table: []const u8) void {
    const print = std.debug.print;
    const rows = a.len + 1;
    const cols = b.len + 1;

    // Header: one column label per byte of `b`.
    print(" ", .{});
    for (b) |ch| print("{c} ", .{ch});
    print("\n ", .{});

    var r: usize = 0;
    while (r < rows) : (r += 1) {
        var c: usize = 0;
        while (c < cols) : (c += 1) {
            print("{} ", .{table[idx(r, c, cols)]});
        }
        print("\n", .{});
        // Label for the NEXT row; a blank is emitted after the last row.
        const label: u8 = if (r < a.len) a[r] else ' ';
        print("{c} ", .{label});
    }
}
/// Computes the Levenshtein distance between `a` and `b` by filling a full
/// `(a.len + 1) x (b.len + 1)` dynamic-programming table.
///
/// `allocator` is used for the temporary table only; it is freed before
/// returning, so the caller owns nothing afterwards.
///
/// Errors:
/// - `error.OutOfMemory` if the table cannot be allocated.
/// - `error.Overflow` if either input is longer than `maxInt(u16)` bytes
///   (the distance could not be represented in the `u16` result) or if the
///   table size does not fit in `usize`.
pub fn levenshteinDistance(allocator: std.mem.Allocator, a: []const u8, b: []const u8) !u16 {
    const n = a.len;
    const m = b.len;

    // The distance is bounded by max(n, m); refuse inputs whose answer might
    // not fit in the return type instead of silently wrapping.
    const max_len = std.math.maxInt(u16);
    if (n > max_len or m > max_len) return error.Overflow;

    const cols = m + 1;
    const cell_count = try std.math.mul(usize, n + 1, cols);
    // Cells are u16 (matching the return type) so distances above 255 are
    // stored exactly.
    const table = try allocator.alloc(u16, cell_count);
    defer allocator.free(table);

    // Row 0: turning the empty prefix of `a` into b[0..j] takes j insertions.
    for (table[0..cols], 0..) |*cell, j| cell.* = @intCast(j);

    for (1..n + 1) |i| {
        const prev = table[(i - 1) * cols ..][0..cols];
        const curr = table[i * cols ..][0..cols];

        // Column 0: turning a[0..i] into the empty string takes i deletions.
        curr[0] = @intCast(i);

        for (1..cols) |j| {
            const substitution = prev[j - 1] + @intFromBool(a[i - 1] != b[j - 1]);
            // A neighbour may already hold maxInt(u16); saturating keeps the
            // candidate at the maximum, which can never undercut the true
            // minimum (itself at most maxInt(u16)).
            curr[j] = @min(prev[j] +| 1, curr[j - 1] +| 1, substitution);
        }
    }

    return table[table.len - 1];
}
/// Debug helper: prints a flat `a.len x b.len` distance table to stderr, with
/// the bytes of `b` as a header line and each row prefixed by its byte of `a`.
fn dumpTable2(a: []const u8, b: []const u8, table: []const u8) void {
    const print = std.debug.print;
    const cols = b.len;

    print("", .{});
    for (b) |ch| print("{c} ", .{ch});
    print("\n", .{});

    for (a, 0..) |row_label, r| {
        print("{c} ", .{row_label});
        var c: usize = 0;
        while (c < cols) : (c += 1) {
            print("{} ", .{table[idx(r, c, cols)]});
        }
        print("\n", .{});
    }
}
/// Computes the Levenshtein distance between `a` and `b` using an
/// `a.len x b.len` table. It allocates less memory than a table with an extra
/// border row and column, but executes more instructions: the border values
/// are produced by branches inside the inner loop instead of being stored.
///
/// `allocator` is used for the temporary table only; it is freed before
/// returning. No allocation happens when either input is empty.
///
/// Errors:
/// - `error.OutOfMemory` if the table cannot be allocated.
/// - `error.Overflow` if either input is longer than `maxInt(u16)` bytes
///   (the distance could not be represented in the `u16` result) or if the
///   table size does not fit in `usize`.
pub fn levenshteinDistance2(allocator: std.mem.Allocator, a: []const u8, b: []const u8) !u16 {
    const n = a.len;
    const m = b.len;

    const max_len = std.math.maxInt(u16);
    if (n > max_len or m > max_len) return error.Overflow;

    // With an empty side there is no table to fill (and no last cell to
    // read): the distance is simply the length of the other side.
    if (n == 0) return @as(u16, @intCast(m));
    if (m == 0) return @as(u16, @intCast(n));

    const table = try allocator.alloc(u16, try std.math.mul(usize, n, m));
    defer allocator.free(table);

    for (0..n) |i| {
        const row = table[i * m ..][0..m];
        for (0..m) |j| {
            // Cell (i, j) holds the distance between a[0..i+1] and b[0..j+1].
            // Neighbours that would fall outside the table are the implicit
            // border: distance from an empty prefix equals the other length.
            const above: u16 = if (i == 0)
                @as(u16, @intCast(j + 1))
            else
                table[(i - 1) * m + j];
            const left: u16 = if (j == 0)
                @as(u16, @intCast(i + 1))
            else
                row[j - 1];
            const diagonal: u16 = if (i == 0)
                @as(u16, @intCast(j))
            else if (j == 0)
                @as(u16, @intCast(i))
            else
                table[(i - 1) * m + (j - 1)];

            // Saturating +1: a neighbour may already be maxInt(u16); the
            // saturated candidate can never undercut the true minimum.
            row[j] = @min(
                above +| 1,
                left +| 1,
                diagonal + @intFromBool(a[i] != b[j]),
            );
        }
    }

    return table[table.len - 1];
}
/// Test helper: asserts that both table-based implementations agree with the
/// recursive reference `levRec` for the pair (`a`, `b`).
fn check(a: []const u8, b: []const u8) !void {
    const testing = std.testing;
    const expected = levRec(a, b);

    const full_table = try levenshteinDistance(testing.allocator, a, b);
    try testing.expectEqual(expected, full_table);

    const small_table = try levenshteinDistance2(testing.allocator, a, b);
    try testing.expectEqual(expected, small_table);
}
// Cross-checks every implementation on a few well-known word pairs.
test {
    const pairs = [_][2][]const u8{
        .{ "kitten", "sitting" },
        .{ "flaw", "lawn" },
        .{ "superman", "superwoman" },
    };
    for (pairs) |pair| try check(pair[0], pair[1]);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment