Skip to content

Instantly share code, notes, and snippets.

@memononen
Last active December 4, 2021 19:19
Show Gist options
  • Save memononen/b1b2eb0ff1077b629efb6c0ef472fe71 to your computer and use it in GitHub Desktop.
Save memononen/b1b2eb0ff1077b629efb6c0ef472fe71 to your computer and use it in GitHub Desktop.
3-way merge based on O(NP) Myers diff and diff3
#include <stdio.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <span>
#include <vector>
// Based on
// "An O(NP) Sequence Comparison Algorithm" by Sun Wu, Udi Manber and Gene Myers
// - https://publications.mpi-cbg.de/Wu_1990_6334.pdf
// - Good article visualizing Myers' older algorithm: https://epxx.co/artigos/diff_en.html
//
// "A Formal Investigation of Diff3"
// - https://www.cis.upenn.edu/~bcpierce/papers/diff3-short.pdf
//
// Half-open index interval [start, end) into a sequence.
struct Range
{
    int start = 0; // first index covered by the range
    int end = 0;   // one past the last index covered by the range

    Range() = default;
    Range(const int _start, const int _end) : start{_start}, end{_end} {}

    // Number of indices covered by the range (end - start).
    int size() const
    {
        return end - start;
    }
};
// Describes a mismatch between sequences
struct Hunk {
Hunk() = default;
Hunk(Range _base, Range _change) : base(_base), change(_change) {}
Range base = {};
Range change = {};
};
// A run of matching elements (a diagonal in the edit graph), chained towards
// the start of the sequences through `next`.
struct Diag
{
    int x = 0;      // start position (x)
    int y = 0;      // start position (y)
    int length = 0; // diagonal length
    int next = -1;  // index of the next diagonal towards the start, -1 if none

    Diag() = default;
    Diag(int _x, int _y, int _len, int _next)
        : x{_x}, y{_y}, length{_len}, next{_next}
    {
    }
};
// Furthest point reached on one diagonal of the edit graph, together with
// the diagonal run that is nearest to it on the path leading there.
struct Point
{
    int y = 0;     // furthest y
    int diag = -1; // index of the nearest diagonal, -1 if none

    Point() = default;
    Point(int _y, int _diag) : y{_y}, diag{_diag} {}
};
// Tells which sequence a merged chunk should be taken from.
enum class ChunkType
{
    Conflict, // more than one hunk contributed, the changes conflict
    Base,     // common sequence, unchanged
    Left,     // change came from the left sequence
    Right,    // change came from the right sequence
};
// Describes a change, type defines from which sequence should be used for merged.
struct Chunk {
Chunk() = default;
Chunk(ChunkType _type, const Range _base, const Range _left, const Range _right)
: type(_type), base(_base), left(_left), right(_right) {}
ChunkType type = ChunkType::Base;
Range base = {};
Range left = {};
Range right = {};
};
static void snake(const int k, std::span<const char> left, std::span<const char> right,
std::span<Point> fp, const int fp0, std::vector<Diag>& diags)
{
const Point& belowPt = fp[fp0 + k-1];
const Point& rightPt = fp[fp0 + k+1];
const bool below = (belowPt.y+1) > rightPt.y;
const int prevDiag = below ? belowPt.diag : rightPt.diag;
int y = below ? (belowPt.y+1) : rightPt.y;
int x = y - k;
int length = 0;
const int N = left.size();
const int M = right.size();
while (x < N && y < M && left[x] == right[y])
{
x++; y++; length++;
}
if (length > 0)
{
diags.emplace_back(x - length, y - length, length, prevDiag);
fp[fp0 + k] = Point(y, diags.size() - 1);
}
else
{
fp[fp0 + k] = Point(y, prevDiag);
}
}
// Computes the mismatching regions (hunks) between `left` and `right`.
//
// Each returned hunk holds the mismatching range of `left` in `base` and the
// corresponding range of `right` in `change`. Hunks are returned in ascending
// order. Returns an empty vector when no mismatch is found.
std::vector<Hunk> diff(std::span<const char> left, std::span<const char> right)
{
    // The search below expects the first sequence to be the shorter one;
    // swap if needed and undo the swap when the hunks are emitted.
    const bool reverse = left.size() > right.size();
    if (reverse)
        std::swap(left, right);

    const int N = left.size();
    const int M = right.size();
    const int delta = M - N;
    const int fp0 = N + 1; // zero offset for furthest point indexing, indexing can go negative.

    // Furthest points start before the sequences, pointing at the empty diagonal at zero.
    std::vector<Point> fp((N + 1) + (M + 1) + 1, Point(-1, 0));

    // All paths will lead to the empty diagonal at zero.
    std::vector<Diag> diags;
    diags.emplace_back(0, 0, 0, -1);

    // Calculate common diagonal sequences, until diagonal `delta` reaches the end.
    for (int p = 0; fp[fp0 + delta].y != M; ++p)
    {
        for (int k = -p; k < delta; ++k)
            snake(k, left, right, fp, fp0, diags);
        for (int k = delta + p; k > delta; --k)
            snake(k, left, right, fp, fp0, diags);
        snake(delta, left, right, fp, fp0, diags);
    }

    // Backtrace the shortest edit script, walking the diagonals from the end to the start.
    std::vector<Hunk> hunks;
    Diag later(N, M, 0, -1);
    int i = fp[fp0 + delta].diag;
    while (i != -1)
    {
        const Diag& current = diags[i];
        const int endX = current.x + current.length;
        const int endY = current.y + current.length;

        // A gap between the end of this diagonal and the start of the later one is a hunk.
        if (later.x > endX || later.y > endY)
        {
            const Range gapX(endX, later.x);
            const Range gapY(endY, later.y);
            if (reverse)
                hunks.emplace_back(gapY, gapX);
            else
                hunks.emplace_back(gapX, gapY);
        }

        later = current;
        i = current.next;
    }

    // The backtrace produced the hunks last to first, flip them.
    std::reverse(hunks.begin(), hunks.end());
    return hunks;
}
// Returns true if `index` refers to an existing hunk of `diff`,
// i.e. it lies in [0, diff.size()).
static bool isValidIndex(const int index, std::span<const Hunk> diff)
{
    // Reject negative indices too, so a "valid" index is always safe to dereference.
    return index >= 0 && index < static_cast<int>(diff.size());
}
// Returns the base range of hunk `index` in `diff`, or an "infinity" range
// (both ends at the largest int) if the index is out of bounds, i.e. the
// iterator has finished. The infinity range sorts after every real hunk.
static Range getBase(const int index, std::span<const Hunk> diff)
{
    // Range stores plain ints, so use the limit of int itself as the sentinel.
    constexpr int kInfinity = std::numeric_limits<int>::max();
    if (index < 0 || index >= static_cast<int>(diff.size()))
        return Range(kInfinity, kInfinity);
    return diff[index].base;
}
// Returns the range in the changed sequence which corresponds to `combinedBase`.
//
//   combinedBase  range in the base sequence to map
//   hunks         range of hunk indices in `diff` that fall inside combinedBase;
//                 when empty, hunks.start is the index of the next unvisited hunk
//                 (which may be diff.size() if the diff is exhausted or empty)
//   diff          hunks between the base sequence and the changed sequence
static Range getCombinedChange(const Range combinedBase, const Range hunks, std::span<const Hunk> diff)
{
    const int hunkCount = static_cast<int>(diff.size());

    // Anchor: a base range and the change range known to correspond to it.
    Range base, change;
    if (hunks.size() > 0)
    {
        // Get the range covered by the hunks.
        base = { diff[hunks.start].base.start, diff[hunks.end - 1].base.end };
        change = { diff[hunks.start].change.start, diff[hunks.end - 1].change.end };
    }
    else if (hunks.start < hunkCount)
    {
        // No hunks, anchor to an empty range at the start of the next hunk.
        // The sequences match between here and that hunk, so the offset is the same.
        const Hunk& next = diff[hunks.start];
        base = { next.base.start, next.base.start };
        change = { next.change.start, next.change.start };
    }
    else if (hunkCount > 0)
    {
        // No hunks and the diff is exhausted: there is no next hunk to read,
        // anchor to an empty range at the end of the last hunk instead.
        const Hunk& last = diff[hunkCount - 1];
        base = { last.base.end, last.base.end };
        change = { last.change.end, last.change.end };
    }
    // Otherwise the diff has no hunks at all, the sequences are identical and
    // the default (0, 0) anchors give the identity mapping.

    // Expand the change to cover all of the base range.
    change.start += (combinedBase.start - base.start);
    change.end += (combinedBase.end - base.end);
    return change;
}
// Returns next run of overlapping hunks, and updates iterator indices.
// If there is just one hunk, the change is obvious. If there are many hunks, the changes conflict.
static void getNextHunks(int index[2], std::span<const Hunk> leftDiff, std::span<const Hunk> rightDiff,
Range& combinedBase, Range hunks[2])
{
// Init to empty range start at current index, this will be later used to match
// base range in case of no conflict.
hunks[0] = {index[0], index[0]};
hunks[1] = {index[1], index[1]};
// Get first hunk. We advance the hunks in both diffs side by side,
// always picking the next to advance based on the order in the base sequence.
Range base[2];
base[0] = getBase(index[0], leftDiff);
base[1] = getBase(index[1], rightDiff);
int side = base[0].start <= base[1].start ? 0 : 1;
combinedBase = base[side];
hunks[side] = Range(index[side], index[side] + 1);
index[side]++;
// Combine all consequtive overlapping hunks.
while (isValidIndex(index[0], leftDiff) || isValidIndex(index[1], rightDiff))
{
base[0] = getBase(index[0], leftDiff);
base[1] = getBase(index[1], rightDiff);
side = base[0].start <= base[1].start ? 0 : 1;
// If the next hunk does not touch the combined hunk so far, stop.
if (base[side].start > combinedBase.end)
break;
// Extend the region to contain the next hunk.
combinedBase.end = std::max(combinedBase.end, base[side].end);
// Keep track of the visited range on each diff.
if (hunks[side].size() == 0)
{
hunks[side].start = index[side];
hunks[side].end = index[side] + 1;
}
else
{
hunks[side].end = index[side] + 1;
}
index[side]++;
}
}
// Combines two diffs against a common base into chunks. Each chunk describes matching
// ranges of base/left/right and which one of them should be used for the merged result.
//
//   base, left, right  the three sequences (only their sizes are used here)
//   leftDiff           hunks between base and left
//   rightDiff          hunks between base and right
//
// Returns the chunks in ascending order.
std::vector<Chunk> merge3(std::span<const char> base, std::span<const char> left, std::span<const char> right,
                          std::span<const Hunk> leftDiff, std::span<const Hunk> rightDiff)
{
    std::vector<Chunk> chunks;

    // How far each sequence has been covered by chunks so far.
    int baseDone = 0;
    int leftDone = 0;
    int rightDone = 0;

    int cursor[2] = { 0, 0 };
    while (isValidIndex(cursor[0], leftDiff) || isValidIndex(cursor[1], rightDiff))
    {
        // Get the next run of overlapping hunks from the diffs.
        Range changedBase;
        Range visited[2];
        getNextHunks(cursor, leftDiff, rightDiff, changedBase, visited);

        // Map the changed base range to left and right.
        const Range changedLeft = getCombinedChange(changedBase, visited[0], leftDiff);
        const Range changedRight = getCombinedChange(changedBase, visited[1], rightDiff);

        // More than one contributing hunk is a conflict, otherwise
        // pick left or right depending on where the single hunk came from.
        const int contributing = visited[0].size() + visited[1].size();
        const ChunkType type = contributing > 1
            ? ChunkType::Conflict
            : (visited[0].size() > 0 ? ChunkType::Left : ChunkType::Right);

        // Commit the common sequence preceding the change, if any.
        if (changedBase.start > baseDone)
        {
            chunks.emplace_back(ChunkType::Base,
                                Range(baseDone, changedBase.start),
                                Range(leftDone, changedLeft.start),
                                Range(rightDone, changedRight.start));
        }

        // Commit the changed section.
        chunks.emplace_back(type, changedBase, changedLeft, changedRight);

        baseDone = changedBase.end;
        leftDone = changedLeft.end;
        rightDone = changedRight.end;
    }

    // Commit the remaining common sequence.
    const int baseSize = base.size();
    const int leftSize = left.size();
    const int rightSize = right.size();
    if (baseSize > baseDone)
    {
        chunks.emplace_back(ChunkType::Base,
                            Range(baseDone, baseSize),
                            Range(leftDone, leftSize),
                            Range(rightDone, rightSize));
    }

    return chunks;
}
// Applies the merge to `right` in place, so that `right` becomes the merged result.
//
//   merge  chunks describing the merge, in ascending order
//   base   base sequence (not used)
//   left   left sequence, source of the data merged into right
//   right  [in/out] right sequence, modified in place
//
// Left chunks replace the matching range of `right` with the data from `left`.
// Conflict chunks are arbitrarily resolved the same way (could be any side).
// Right and Base chunks keep what `right` already has.
void resolveToRight(std::span<Chunk> merge, std::span<char> base, std::span<char> left, std::vector<char>& right)
{
    // Chunk ranges refer to the original `right`; offset tracks how much
    // the edits made so far have shifted the remaining data.
    int offset = 0;
    for (const Chunk& c : merge)
    {
        if (c.type != ChunkType::Conflict && c.type != ChunkType::Left)
            continue; // Right or common, keep.

        const int rem = c.right.size(); // elements to remove from right
        const int add = c.left.size();  // elements to bring in from left
        const int start = c.right.start + offset;
        const int end = c.right.end + offset;

        // Overwrite as much as possible in place...
        const int overlap = std::min(rem, add);
        const auto src = left.begin() + c.left.start;
        std::copy(src, src + overlap, right.begin() + start);

        // ...then drop the surplus of right or insert the remainder of left.
        // At most one of these two ranges is non-empty.
        right.erase(right.begin() + start + overlap, right.begin() + end);
        right.insert(right.begin() + start + overlap, src + overlap, left.begin() + c.left.end);

        offset += add - rem;
    }
}
// Prints the two sequences of a diff on separate lines. Mismatching sections are
// highlighted with a red background and the shorter side of each hunk is padded
// with spaces up to the length of the longer side.
//
//   left   sequence the `base` ranges of the hunks refer to
//   right  sequence the `change` ranges of the hunks refer to
//   diff   hunks between left and right
void printDiff(std::span<char> left, std::span<char> right, std::span<Hunk> diff)
{
    // Prints seq[from, to).
    const auto emit = [](std::span<char> seq, const int from, const int to)
    {
        for (int i = from; i < to; i++)
            printf("%c", seq[i]);
    };
    // Prints everything in seq starting at `from`.
    const auto emitRest = [](std::span<char> seq, const int from)
    {
        for (std::size_t i = static_cast<std::size_t>(from); i < seq.size(); i++)
            printf("%c", seq[i]);
    };
    // Prints spaces to pad a section of `have` characters to `want` characters.
    const auto pad = [](const int have, const int want)
    {
        for (int i = have; i < want; i++)
            printf(" ");
    };

    // Left
    printf(" Left: ");
    int cursor = 0;
    for (const Hunk& hunk : diff)
    {
        // Common sequence
        emit(left, cursor, hunk.base.start);
        // Mismatching sequence
        printf("\u001b[41m");
        emit(left, hunk.base.start, hunk.base.end);
        printf("\u001b[0m");
        pad(hunk.base.size(), hunk.change.size());
        cursor = hunk.base.end;
    }
    // Last common chunk
    emitRest(left, cursor);
    printf("\n");

    // Right
    printf(" Right: ");
    cursor = 0;
    for (const Hunk& hunk : diff)
    {
        // Common sequence
        emit(right, cursor, hunk.change.start);
        // Mismatching sequence
        printf("\u001b[41m");
        emit(right, hunk.change.start, hunk.change.end);
        printf("\u001b[0m");
        pad(hunk.change.size(), hunk.base.size());
        cursor = hunk.change.end;
    }
    // Last common chunk
    emitRest(right, cursor);
    printf("\n");
}
// Prints the merged result on one line, coloring each chunk by its origin:
// conflicts on red background (arbitrarily showing left), left changes on blue,
// right changes on cyan, common sequences without color.
//
//   base, left, right  the three sequences the chunk ranges refer to
//   merge              chunks describing the merge
void printMerge(std::span<char> base, std::span<char> left, std::span<char> right, std::span<Chunk> merge)
{
    printf(" ");
    for (const Chunk& c : merge)
    {
        // data() is used instead of &seq[0], which is undefined for an empty
        // sequence. The pointer is only read when the range is non-empty.
        const char* s = base.data();
        Range range = c.base;
        switch (c.type)
        {
        case ChunkType::Conflict:
            // Conflict, arbitrarily choose left
            s = left.data();
            range = c.left;
            printf("\u001b[41m");
            break;
        case ChunkType::Left:
            s = left.data();
            range = c.left;
            printf("\u001b[44m");
            break;
        case ChunkType::Right:
            s = right.data();
            range = c.right;
            printf("\u001b[46m");
            break;
        default:
            // Common, print from base
            break;
        }
        for (int k = range.start; k < range.end; k++)
            printf("%c", s[k]);
        printf("\u001b[0m");
    }
    printf("\n");
}
// Prints the characters of `arr` on one line, preceded by a space.
void printv(std::span<char> arr)
{
    printf(" ");
    for (std::size_t i = 0; i < arr.size(); ++i)
        printf("%c", arr[i]);
    printf("\n");
}
// Copies the characters of the null-terminated string `s` into a vector.
// The terminating null is not included.
std::vector<char> makev(const char* s)
{
    std::vector<char> chars;
    for (const char* p = s; *p != '\0'; ++p)
        chars.push_back(*p);
    return chars;
}
int main()
{
std::vector<char> base = makev("the quick fox jumps ovre some lazy dog");
std::vector<char> left = makev("the quick brown fox jumped ovre a dog");
std::vector<char> right = makev("the quick brown fox jumps over some record dog");
std::vector<Hunk> leftDiff = diff(base, left);
std::vector<Hunk> rightDiff = diff(base, right);
std::vector<Chunk> merged = merge3(base, left, right, leftDiff, rightDiff);
printf("Diff Base > \u001b[44mLeft\u001b[0m\n");
printDiff(base, left, leftDiff);
printf("Diff Base > \u001b[46mRight\u001b[0m\n");
printDiff(base, right, rightDiff);
printf("Merged\n");
printMerge(base, left, right, merged);
resolveToRight(merged, base, left, right);
printv(right);
return 0;
}
@memononen
Copy link
Author

memononen commented Dec 3, 2021

g++ -std=c++2a merge.cpp -o merge
./merge
Diff Base > Left
  Left:  the quick       fox jumps  ovre some lazy dog
  Right: the quick brown fox jumped ovre       a   dog
Diff Base > Right
  Left:  the quick       fox jumps ov re some lazy   dog
  Right: the quick brown fox jumps over  some record dog
Merged
  the quick brown fox jumped over a dog
  the quick brown fox jumped over a dog

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment