Skip to content

Instantly share code, notes, and snippets.

@ram-nad
Last active March 27, 2020 05:12
Show Gist options
  • Save ram-nad/049a75e96a129c3461286f2cb9792543 to your computer and use it in GitHub Desktop.
Save ram-nad/049a75e96a129c3461286f2cb9792543 to your computer and use it in GitHub Desktop.
Karatsuba Multiplication
#include <chrono>
#include <cstddef>
#include <cstring> // For memcpy, memset
#include <exception>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
// Limb-count threshold, cast to `type`: karatsuba_multiply falls back to
// base_multiply when either operand has fewer limbs than this.
#define KARATSUBA_CUTOFF(type) static_cast<type>(50)
// Sign-magnitude big integer stored in a caller-supplied array of limbs.
//
// The object never allocates or frees memory: every constructor receives the
// limb buffer (`mem`) and the number of limbs in it (`count`). Limbs are
// stored least-significant first. Copying and moving are disabled because a
// copy would share the same buffer pointer.
class NumberBase {
  using size_t = std::size_t;

 public:
  using limb_t = unsigned int;
  using double_limb_t = unsigned long long;
  static constexpr limb_t limb_bytes = sizeof(limb_t);
  static constexpr limb_t limb_bits = limb_bytes * 8;

 private:
  limb_t* data;       // caller-supplied storage, least-significant limb first
  size_t data_count;  // number of limbs reachable through `data`
  bool negative;      // sign flag; the limbs hold the magnitude

 public:
  // Wraps `count` limbs at `mem` as a non-negative number.
  // When `init` is true the limbs are cleared; when false they are left
  // untouched, which lets the object act as a view over existing limbs.
  NumberBase(size_t count, limb_t* mem, bool init = true)
      : data(mem), data_count(count), negative(false) {
    if (init) {
      set_zero();
    }
  }

  // Wraps `count` limbs at `mem` and fills them from `other`: the sign is
  // copied, the low limbs are copied, and the value is truncated or
  // zero-extended to `count` limbs.
  NumberBase(size_t count, limb_t* mem, const NumberBase& other)
      : data(mem), data_count(count), negative(other.negative) {
    const size_t copied =
        other.data_count < data_count ? other.data_count : data_count;
    if (copied != 0) {
      std::memmove(data, other.data, copied * sizeof(limb_t));
    }
    if (copied < data_count) {
      // `data + copied` is already measured in limbs; only the byte length
      // passed to memset is scaled by the limb size.
      std::memset(data + copied, 0, (data_count - copied) * sizeof(limb_t));
    }
  }

  // Wraps `count` limbs at `mem` and parses the decimal string `val` into
  // them. A leading '-' marks the number negative.
  //
  // The magnitude is converted by repeatedly halving the decimal digit string
  // and recording the remainder as the next bit. Bits that do not fit into
  // `count` limbs are dropped, so an oversized literal is stored modulo
  // 2^(count * limb_bits) instead of writing past the buffer.
  //
  // Throws std::invalid_argument if any character after the optional sign is
  // not a decimal digit.
  explicit NumberBase(size_t count, limb_t* mem, const std::string& val)
      : data(mem), data_count(count), negative(false) {
    set_zero();
    std::string digits = val;
    const size_t len = digits.size();
    size_t start = 0;
    if (len != 0 && digits[0] == '-') {
      negative = true;
      start = 1;
    }
    // Validate and turn ASCII digits into their numeric values in place.
    for (size_t i = start; i < len; i++) {
      if (digits[i] < '0' || digits[i] > '9') {
        throw std::invalid_argument("Unexpected Character.");
      }
      digits[i] -= '0';
    }
    size_t pos = 0;    // limb currently being filled
    size_t shift = 0;  // bit position inside that limb
    while (start < len && pos < data_count) {
      // Divide the remaining decimal number by two; `carry` ends up non-zero
      // exactly when the number was odd.
      int carry = 0;
      for (size_t i = start; i < len; i++) {
        const int current = digits[i] + carry;
        carry = (current & 1) * 10;
        digits[i] = static_cast<char>(current >> 1);
      }
      if (digits[start] == 0) {
        start++;
      }
      if (carry) {
        // Unsigned shift: shifting a signed 1 into the top bit is avoided.
        data[pos] |= static_cast<limb_t>(1) << shift;
      }
      shift++;
      if (shift == limb_bits) {
        pos++;
        shift = 0;
      }
    }
  }

  // Read-only access to limb `index` (no bounds check).
  const limb_t& operator()(size_t index) const { return data[index]; }

  // True when the sign flag is set.
  bool is_negative() const { return negative; }

  // Number of limbs in the underlying buffer.
  std::size_t count() const { return data_count; }

  // Mutable access to limb `index` (no bounds check).
  limb_t& operator()(size_t index) { return data[index]; }

  // Flips the sign flag.
  void negate() { negative = !negative; }

  // Sets the sign flag; `true` means negative.
  void set_sign(bool sign) { negative = sign; }

  // Clears every limb; the sign flag is left unchanged.
  void set_zero() {
    if (data_count != 0) {
      std::memset(data, 0, data_count * sizeof(limb_t));
    }
  }

  NumberBase(NumberBase&&) = delete;
  NumberBase(const NumberBase&) = delete;
};
// Output the magnitude of the number in hexadecimal, most-significant limb
// first, each limb zero-padded to its full width. The sign is not printed.
// The stream's format flags and fill character are restored before returning.
std::ostream& operator<<(std::ostream& out, const NumberBase& num) {
  // flags() does not cover the fill character, so both are saved.
  const std::ios_base::fmtflags saved_flags = out.flags();
  const char saved_fill = out.fill();
  out << std::hex << std::setfill('0');
  for (std::size_t i = num.count(); i-- > 0;) {
    // The field width resets after every insertion, so it is set per limb.
    out << std::setw(static_cast<int>(NumberBase::limb_bytes * 2)) << num(i);
  }
  out.flags(saved_flags);
  out.fill(saved_fill);
  return out;
}
// Smaller of two sizes; returns `a` when they are equal.
inline std::size_t min(std::size_t a, std::size_t b) {
  if (b < a) {
    return b;
  }
  return a;
}
// Larger of two sizes; returns `b` when they are equal.
inline std::size_t max(std::size_t a, std::size_t b) {
  if (b < a) {
    return a;
  }
  return b;
}
int cmp_abs(const NumberBase& first, const NumberBase& second) {
std::size_t f_count = first.count();
std::size_t s_count = second.count();
int i = max(f_count, s_count) - 1;
// Make sure that i == min(s_count-1, f_count-1)
for (; i >= s_count; i--) {
// If any of extra bits in first is greater than 0, return 1
if (first(i)) {
return 1;
}
}
// Make sure that i == min(s_count-1, f_count-1)
for (; i >= f_count; i--) {
// If any of extra bits in second is greater than 0, return 1
if (second(i)) {
return -1;
}
}
// For common bits compare for unequal limbs
// return on mismatch
for (; i >= 0; i--) {
if (first(i) > second(i)) {
return 1;
} else if (first(i) < second(i)) {
return -1;
}
}
// Finally if no cases work out return 0
return 0;
}
// Signed comparison of `first` and `second`.
// Returns a negative value, zero, or a positive value when `first` is
// respectively less than, equal to, or greater than `second`.
int cmp(const NumberBase& first, const NumberBase& second) {
  const bool first_negative = first.is_negative();
  const bool second_negative = second.is_negative();
  // Opposite sign flags: the operand flagged negative is the smaller one.
  if (first_negative != second_negative) {
    return first_negative ? -1 : 1;
  }
  // Same sign: compare magnitudes and flip the answer for negatives.
  const int magnitude_order = cmp_abs(first, second);
  if (first_negative) {
    return -1 * magnitude_order;
  }
  return magnitude_order;
}
// result = |first| + |second|, truncated to result.count() limbs.
// Signs are ignored and the sign of `result` is not modified.
// Works even when result == first.
void eval_add_abs(NumberBase& result, const NumberBase& first,
                  const NumberBase& second) {
  const std::size_t first_limbs = first.count();
  const std::size_t second_limbs = second.count();
  const std::size_t result_limbs = result.count();
  NumberBase::double_limb_t carry = 0;
  for (std::size_t idx = 0; idx < result_limbs; ++idx) {
    // Operand limbs are read before the result limb is written, so the
    // result may share storage with an operand. An operand that has run out
    // of limbs contributes zero.
    NumberBase::double_limb_t sum = carry;
    if (idx < first_limbs) {
      sum += static_cast<NumberBase::double_limb_t>(first(idx));
    }
    if (idx < second_limbs) {
      sum += static_cast<NumberBase::double_limb_t>(second(idx));
    }
    // The low half is the result limb, the high half carries to the next one.
    result(idx) = static_cast<NumberBase::limb_t>(sum);
    carry = sum >> NumberBase::limb_bits;
  }
}
// result = |first| - |second|, computed limb by limb with a borrow and
// truncated to result.count() limbs. Signs are ignored and the sign of
// `result` is not modified.
// Works even when result == first (or result == second).
// Callers are expected to pass |first| >= |second|; otherwise the limbs hold
// the difference modulo 2^(result.count() * limb_bits).
void eval_sub_abs(NumberBase& result, const NumberBase& first,
                  const NumberBase& second) {
  const std::size_t f_count = first.count();
  const std::size_t s_count = second.count();
  const std::size_t data_count = result.count();
  NumberBase::limb_t borrow = 0;
  for (std::size_t i = 0; i < data_count; i++) {
    // Read both operand limbs before writing, so aliasing is safe. An operand
    // that has run out of limbs contributes zero.
    const NumberBase::limb_t minuend =
        i < f_count ? first(i) : static_cast<NumberBase::limb_t>(0);
    const NumberBase::limb_t subtrahend =
        i < s_count ? second(i) : static_cast<NumberBase::limb_t>(0);
    // Unsigned arithmetic wraps, which yields the correct limb in every case.
    result(i) = minuend - subtrahend - borrow;
    // A borrow leaves this limb when the subtrahend alone exceeds the
    // minuend, or when they are equal and a borrow came in. This rule is
    // applied uniformly, including for limbs that only `second` has.
    borrow = (minuend < subtrahend || (minuend == subtrahend && borrow != 0))
                 ? static_cast<NumberBase::limb_t>(1)
                 : static_cast<NumberBase::limb_t>(0);
  }
}
// Signed addition: result = first + second, truncated to result.count()
// limbs. `result` may be the same object as `first`.
void eval_add(NumberBase& result, const NumberBase& first,
              const NumberBase& second) {
  // Read the signs up front: `result` may alias an operand.
  const bool f_neg = first.is_negative();
  const bool s_neg = second.is_negative();
  if (f_neg == s_neg) {
    // Same sign: add the magnitudes and keep the common sign.
    eval_add_abs(result, first, second);
    result.set_sign(f_neg);
    return;
  }
  // Opposite signs: the magnitude is the difference of the magnitudes and
  // the sign is that of the operand with the larger magnitude.
  const int order = cmp_abs(first, second);
  if (order == 0) {
    result.set_zero();
    // set_zero() leaves the sign flag alone; clear it so that an exact
    // cancellation does not produce a negative zero.
    result.set_sign(false);
  } else if (order > 0) {
    eval_sub_abs(result, first, second);
    result.set_sign(f_neg);
  } else {
    eval_sub_abs(result, second, first);
    result.set_sign(s_neg);
  }
}
// Signed subtraction: result = first - second, truncated to result.count()
// limbs. `result` may be the same object as `first`.
void eval_sub(NumberBase& result, const NumberBase& first,
              const NumberBase& second) {
  // Read the signs up front: `result` may alias an operand.
  const bool f_neg = first.is_negative();
  const bool s_neg = second.is_negative();
  if (f_neg != s_neg) {
    // Opposite signs: the magnitudes add up and the sign of `first` wins.
    eval_add_abs(result, first, second);
    result.set_sign(f_neg);
    return;
  }
  // Same sign: the magnitude is the difference of the magnitudes.
  const int order = cmp_abs(first, second);
  if (order == 0) {
    result.set_zero();
    // set_zero() leaves the sign flag alone; clear it so that x - x does not
    // produce a negative zero.
    result.set_sign(false);
  } else if (order > 0) {
    // |first| is larger: the result keeps the sign of `first`.
    eval_sub_abs(result, first, second);
    result.set_sign(f_neg);
  } else {
    // |second| is larger: the result takes the opposite of the common sign.
    eval_sub_abs(result, second, first);
    result.set_sign(!s_neg);
  }
}
// Schoolbook multiplication: r = |a| * |b|, truncated to r.count() limbs.
// `r` is cleared first, so it must not share storage with `a` or `b`.
// The sign of `r` is not modified.
void base_multiply(NumberBase& r, const NumberBase& a, const NumberBase& b) {
  r.set_zero();
  const std::size_t a_limbs = a.count();
  const std::size_t b_limbs = b.count();
  const std::size_t r_limbs = r.count();
  for (std::size_t row = 0; row < a_limbs; ++row) {
    // Limb `row` of a times limb `col` of b lands at limb `row + col` of r.
    const NumberBase::double_limb_t multiplier =
        static_cast<NumberBase::double_limb_t>(a(row));
    NumberBase::double_limb_t acc = 0;
    for (std::size_t col = 0; col < b_limbs && row + col < r_limbs; ++col) {
      const std::size_t target = row + col;
      acc += multiplier * static_cast<NumberBase::double_limb_t>(b(col));
      acc += r(target);
      // Keep the low limb, carry the high limb into the next column.
      r(target) = static_cast<NumberBase::limb_t>(acc);
      acc >>= NumberBase::limb_bits;
    }
    // Deposit the final carry of this row if it still fits into r.
    const std::size_t top = row + b_limbs;
    if (acc && top < r_limbs) {
      acc += r(top);
      r(top) = static_cast<NumberBase::limb_t>(acc);
    }
  }
}
// r = |a| * |b| using the subtractive form of Karatsuba multiplication.
//
// Falls back to base_multiply when either operand is below KARATSUBA_CUTOFF
// limbs, or when r has fewer than a.count() + b.count() limbs (the recursive
// path writes the full-width product into r). The sign of r is not modified.
// r must not share storage with a or b.
//
// `storage` is scratch space. Each recursion level uses 4 * n + 1 limbs,
// where n = ceil(max(a.count(), b.count()) / 2), and hands the rest to the
// next level; a buffer of 8 * max(a.count(), b.count()) limbs is sufficient.
void karatsuba_multiply(NumberBase& r, NumberBase& a, NumberBase& b,
                        NumberBase::limb_t* storage) {
  const std::size_t a_count = a.count();
  const std::size_t b_count = b.count();
  const std::size_t r_count = r.count();
  if (a_count < KARATSUBA_CUTOFF(std::size_t) ||
      b_count < KARATSUBA_CUTOFF(std::size_t) ||
      r_count < a_count + b_count) {
    base_multiply(r, a, b);
    return;
  }
  // The views below only cover the low a_count + b_count limbs of r; clear
  // whatever lies above so r holds exactly the product.
  for (std::size_t i = a_count + b_count; i < r_count; i++) {
    r(i) = static_cast<NumberBase::limb_t>(0);
  }
  // Split point in limbs: a = a1 * B^n + a0 and b = b1 * B^n + b0, where B is
  // 2^limb_bits. A high half may be empty when the operands are unbalanced.
  const std::size_t n_size = max((a_count + 1) / 2, (b_count + 1) / 2);
  const std::size_t h_a_size = a_count > n_size ? a_count - n_size : 0;
  const std::size_t h_b_size = b_count > n_size ? b_count - n_size : 0;
  const std::size_t l_a_size = a_count > n_size ? n_size : a_count;
  const std::size_t l_b_size = b_count > n_size ? n_size : b_count;
  // Views into r: h_r receives a1 * b1, l_r receives a0 * b0 (both cleared on
  // construction), and m_r is the window at limb n where the middle term is
  // added (not cleared: it overlaps l_r and h_r).
  NumberBase h_r(a_count + b_count - 2 * n_size, &r(2 * n_size));
  NumberBase m_r(a_count + b_count - n_size, &r(n_size), false);
  NumberBase l_r(2 * n_size, &r(0));
  // Views into the operands (never cleared).
  NumberBase h_a(h_a_size, &a(l_a_size), false);  // a1
  NumberBase l_a(l_a_size, &a(0), false);         // a0
  NumberBase h_b(h_b_size, &b(l_b_size), false);  // b1
  NumberBase l_b(l_b_size, &b(0), false);         // b0
  // These two calls finish before the scratch views below are created, so
  // they may use the same scratch region.
  karatsuba_multiply(h_r, h_a, h_b, storage);
  karatsuba_multiply(l_r, l_a, l_b, storage);
  const int cmp_a = cmp_abs(l_a, h_a);
  const int cmp_b = cmp_abs(l_b, h_b);
  // The middle term a0 * b1 + a1 * b0 can be as large as twice a 2n-limb
  // product, so it gets one extra limb to keep the top carry.
  NumberBase mid(2 * n_size + 1, storage);
  NumberBase mid_a(n_size, storage + 2 * n_size + 1);  // |a0 - a1|
  NumberBase mid_b(n_size, storage + 3 * n_size + 1);  // |b0 - b1|
  // When the halves are equal the difference stays at its initial zero.
  if (cmp_a > 0) {
    eval_sub_abs(mid_a, l_a, h_a);
  } else if (cmp_a < 0) {
    eval_sub_abs(mid_a, h_a, l_a);
  }
  if (cmp_b > 0) {
    eval_sub_abs(mid_b, l_b, h_b);
  } else if (cmp_b < 0) {
    eval_sub_abs(mid_b, h_b, l_b);
  }
  // mid = |a0 - a1| * |b0 - b1|
  karatsuba_multiply(mid, mid_a, mid_b, storage + 4 * n_size + 1);
  // a0 * b1 + a1 * b0 = a1 * b1 + a0 * b0 - (a0 - a1) * (b0 - b1), so the
  // product enters negated: negative exactly when both differences have the
  // same non-zero sign.
  mid.set_sign(cmp_a * cmp_b > 0);
  eval_add(mid, mid, h_r);
  eval_add(mid, mid, l_r);
  // r += mid * B^n
  eval_add_abs(m_r, m_r, mid);
}
// Signed multiplication: result = first * second, truncated to
// result.count() limbs. `result` must not share storage with an operand.
void eval_multiply(NumberBase& result, NumberBase& first, NumberBase& second) {
  const std::size_t r_count = first.count() + second.count();
  if (result.count() < r_count) {
    // karatsuba_multiply's recursive path writes a full-width product, so a
    // narrower result uses the truncating schoolbook routine instead.
    base_multiply(result, first, second);
  } else {
    // Clear every limb so that limbs above the product width end up zero.
    result.set_zero();
    // Scratch space for karatsuba_multiply, released automatically on return.
    std::vector<NumberBase::limb_t> scratch(
        8 * max(first.count(), second.count()));
    karatsuba_multiply(result, first, second, scratch.data());
  }
  // The product is negative exactly when the operand signs differ.
  result.set_sign(first.is_negative() != second.is_negative());
}
// Big integer with inline storage for at least `precision` bits, rounded up
// to a whole number of limbs.
template <std::size_t precision>
class Number : public NumberBase {
 public:
  // Number of limbs needed to hold `precision` bits.
  static constexpr std::size_t data_count =
      (precision + NumberBase::limb_bits - 1) / NumberBase::limb_bits;
  // Number of bits actually available after rounding up to whole limbs.
  static constexpr std::size_t actual_precision =
      data_count * NumberBase::limb_bits;

  // Zero.
  Number() : NumberBase(data_count, this->data) {}

  // Parses a decimal string, see NumberBase's string constructor.
  Number(const std::string& val) : NumberBase(data_count, this->data, val) {}

  // Copies the limbs and sign of `num` into this object's own storage.
  Number(const Number& num) : NumberBase(data_count, this->data, num) {}

 private:
  // The NumberBase constructor fills this array before the member itself is
  // initialized, so it must stay without an initializer: adding one would
  // overwrite the value written by the base class.
  limb_t data[data_count];
};
// Demo: multiplies two fixed operands with karatsuba_multiply and with
// base_multiply, prints both products in hexadecimal and the time each
// multiplication took in microseconds.
int main() {
  using namespace std::chrono;
  Number<3000> a(
      "232924386438745725427542745242425253632636363672475245275428682649269625"
      "936592659265926926593463846389463894689264893642659326432617838960385965"
      "936740674985683548354692725035784538724589198126849572145783426346128371"
      "232924386438745725427542745242425253632636363672475245275428682649269625"
      "936592659265926926593463846389463894689264893642659326432617838960385965"
      "936740674985683548354692725035784538724589198126849572145783426346128371"
      "232924386438745725427542745242425253632636363672475245275428682649269625"
      "936592659265926926593463846389463894689264893642659326432617838960385965"
      "936740674985683548354692725035784538724589198126849572145783426346128371"
      "936592659265926926593463846389463894689264893642659326432617838960385965"
      "936740674985683548354692725035784538724589198126849572145783426346128371"
      "2329243864387457254275427452424252536326363636724752452754286826492696");
  // This literal has 337 decimal digits (about 1119 bits), so it needs more
  // than 1024 bits of storage.
  Number<1200> b(
      "464646458976412132165484653132156487879416163132123156864894791313213216"
      "548645612316584894512302131231564684897498748451321321231684894654321324"
      "464646458976412132165484653132156487879416163132123156864894791313213216"
      "548645612316584894512302131231564684897498748451321321231684894654321324"
      "0231561648948941651310354189487894658415612165165");
  Number<4000> c;
  Number<4000> d;
  // Scratch space for karatsuba_multiply, sized from the larger operand.
  NumberBase::limb_t st[8 * Number<3000>::data_count];
  // std::cout << a << std::endl;
  // std::cout << b << std::endl;
  // Only the multiplication is timed; printing happens after the clock stops.
  // Note: b has 38 limbs, below the 50-limb KARATSUBA_CUTOFF, so this call
  // currently ends up in base_multiply as well.
  auto start = high_resolution_clock::now();
  karatsuba_multiply(c, b, a, st);
  auto stop = high_resolution_clock::now();
  auto duration = duration_cast<microseconds>(stop - start);
  std::cout << c << '\n';
  std::cout << std::dec << "KaratSuba Time: " << duration.count() << std::endl;
  start = high_resolution_clock::now();
  base_multiply(d, b, a);
  stop = high_resolution_clock::now();
  duration = duration_cast<microseconds>(stop - start);
  std::cout << d << '\n';
  std::cout << std::dec << "Base Time: " << duration.count() << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment