Last active
March 27, 2020 05:12
-
-
Save ram-nad/049a75e96a129c3461286f2cb9792543 to your computer and use it in GitHub Desktop.
Karatsuba Multiplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <chrono>
#include <cstddef>
#include <cstring>  // For memcpy, memset
#include <exception>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#define KARATSUBA_CUTOFF(type) static_cast<type>(50) | |
// Non-owning, fixed-width, sign-magnitude integer.
//
// The magnitude is stored as `count` little-endian limbs (index 0 is the least
// significant) in memory supplied by the caller; this class never allocates or
// frees that memory, so the buffer must outlive the object. The sign is kept
// separately in a flag. Objects are neither copyable nor movable because they
// are only views over external storage.
class NumberBase {
  using size_t = std::size_t;

 public:
  using limb_t = unsigned int;
  using double_limb_t = unsigned long long;
  static constexpr limb_t limb_bytes = sizeof(limb_t);
  static constexpr limb_t limb_bits = limb_bytes * 8;

 private:
  limb_t* data;
  size_t data_count;
  bool negative;

 public:
  // Wraps `count` limbs at `mem`. When `init` is true the limbs are cleared,
  // otherwise their current contents are kept (used to build sub-views).
  NumberBase(size_t count, limb_t* mem, bool init = true)
      : data(mem), data_count(count), negative(false) {
    if (init) {
      set_zero();
    }
  }

  // Wraps `count` limbs at `mem` and copies the value (and sign) of `other`.
  // A wider `other` is truncated to the low `count` limbs; a narrower one is
  // zero-extended.
  NumberBase(size_t count, limb_t* mem, const NumberBase& other)
      : data(mem), data_count(count), negative(other.negative) {
    if (other.data_count >= data_count) {
      if (data_count != 0) {
        std::memcpy(data, other.data, data_count * limb_bytes);
      }
    } else {
      if (other.data_count != 0) {
        std::memcpy(data, other.data, other.data_count * limb_bytes);
      }
      // Pointer arithmetic on limb_t* is already in limbs, so the offset must
      // not be scaled by limb_bytes.
      std::memset(data + other.data_count, 0,
                  (data_count - other.data_count) * limb_bytes);
    }
  }

  // Wraps `count` limbs at `mem` and parses the decimal string `val`
  // (optional leading '-', then digits only).
  //
  // Throws std::invalid_argument on any non-digit character. Bits that do not
  // fit in `count` limbs are discarded, i.e. the stored magnitude is the value
  // modulo 2^(count * limb_bits). A zero result is never marked negative.
  explicit NumberBase(size_t count, limb_t* mem, const std::string& val)
      : data(mem), data_count(count), negative(false) {
    set_zero();
    std::string digits = val;
    const size_t len = digits.size();
    const bool minus = len > 0 && digits[0] == '-';
    size_t start = minus ? 1 : 0;

    // Validate and convert ASCII digits to their numeric values in place.
    for (size_t i = start; i < len; i++) {
      if (digits[i] < '0' || digits[i] > '9') {
        throw std::invalid_argument("Unexpected Character.");
      }
      digits[i] -= '0';
    }

    // Repeatedly halve the decimal number; the remainder of each halving is
    // the next binary digit, least significant first.
    size_t pos = 0;
    size_t shift = 0;
    bool any_bit = false;
    while (start < len && pos < data_count) {
      int carry = 0;
      for (size_t i = start; i < len; i++) {
        const int digit = digits[i] + carry;
        carry = (digit & 1) * 10;
        digits[i] = static_cast<char>(digit >> 1);
      }
      if (digits[start] == 0) {
        start++;
      }
      if (carry) {
        // Shift an unsigned limb: `1 << 31` on a signed int would overflow.
        data[pos] |= static_cast<limb_t>(1) << shift;
        any_bit = true;
      }
      shift++;
      if (shift == limb_bits) {
        pos++;
        shift = 0;
      }
    }
    negative = minus && any_bit;
  }

  // Read-only access to limb `index` (no bounds checking).
  const limb_t& operator()(size_t index) const { return this->data[index]; }
  bool is_negative() const { return negative; }
  // Number of limbs in the view.
  std::size_t count() const { return this->data_count; }
  // Mutable access to limb `index` (no bounds checking).
  limb_t& operator()(size_t index) { return this->data[index]; }
  void negate() { this->negative = !this->negative; }
  // `sign == true` marks the number as negative.
  void set_sign(bool sign) { this->negative = sign; }
  // Clears every limb; the sign flag is left unchanged.
  void set_zero() {
    if (this->data_count != 0) {
      std::memset(this->data, 0, this->data_count * NumberBase::limb_bytes);
    }
  }

  NumberBase(NumberBase&&) = delete;
  NumberBase(const NumberBase&) = delete;
};
// Ouput the Number in hexadecimal format | |
std::ostream& operator<<(std::ostream& out, const NumberBase& num) { | |
// Store the original flags | |
std::ios_base::fmtflags original(out.flags()); | |
for (int i = num.count() - 1; i >= 0; i--) { | |
out << std::setfill('0') << std::setw(NumberBase::limb_bytes * 2) | |
<< std::hex << static_cast<int>(num(i)); | |
} | |
// Restore the flags | |
out.flags(original); | |
return out; | |
} | |
// Smaller of two sizes.
inline std::size_t min(std::size_t a, std::size_t b) {
  if (b < a) {
    return b;
  }
  return a;
}
// Larger of two sizes.
inline std::size_t max(std::size_t a, std::size_t b) {
  if (a < b) {
    return b;
  }
  return a;
}
int cmp_abs(const NumberBase& first, const NumberBase& second) { | |
std::size_t f_count = first.count(); | |
std::size_t s_count = second.count(); | |
int i = max(f_count, s_count) - 1; | |
// Make sure that i == min(s_count-1, f_count-1) | |
for (; i >= s_count; i--) { | |
// If any of extra bits in first is greater than 0, return 1 | |
if (first(i)) { | |
return 1; | |
} | |
} | |
// Make sure that i == min(s_count-1, f_count-1) | |
for (; i >= f_count; i--) { | |
// If any of extra bits in second is greater than 0, return 1 | |
if (second(i)) { | |
return -1; | |
} | |
} | |
// For common bits compare for unequal limbs | |
// return on mismatch | |
for (; i >= 0; i--) { | |
if (first(i) > second(i)) { | |
return 1; | |
} else if (first(i) < second(i)) { | |
return -1; | |
} | |
} | |
// Finally if no cases work out return 0 | |
return 0; | |
} | |
int cmp(const NumberBase& first, const NumberBase& second) { | |
bool f_neg = first.is_negative(); | |
bool s_neg = second.is_negative(); | |
if (f_neg != s_neg) { | |
// If bith have different signs | |
// return on basis of it | |
return f_neg ? -1 : 1; | |
} else { | |
int cmp = cmp_abs(first, second); | |
// Change sign of absolute comparision | |
// if any is negative | |
return f_neg ? -1 * cmp : cmp; | |
} | |
} | |
// result = |first| + |second|, ignoring all sign flags.
//
// Operands shorter than `result` are zero-extended. The sum is truncated to
// result.count() limbs, so a carry out of the top limb is discarded. `result`
// may be the same object as `first` or `second`: each limb is read before the
// corresponding result limb is written.
void eval_add_abs(NumberBase& result, const NumberBase& first,
                  const NumberBase& second) {
  using limb_t = NumberBase::limb_t;
  using wide_t = NumberBase::double_limb_t;

  const std::size_t f_count = first.count();
  const std::size_t s_count = second.count();
  const std::size_t data_count = result.count();

  wide_t carry = 0;
  for (std::size_t i = 0; i < data_count; ++i) {
    if (i < f_count) {
      carry += static_cast<wide_t>(first(i));
    }
    if (i < s_count) {
      carry += static_cast<wide_t>(second(i));
    }
    // Low half is the result limb, high half carries into the next limb.
    result(i) = static_cast<limb_t>(carry);
    carry >>= NumberBase::limb_bits;
  }
}
// result = |first| - |second|, ignoring all sign flags.
//
// Operands shorter than `result` are zero-extended and the difference is
// truncated to result.count() limbs. Callers in this file pass the larger
// magnitude as `first`; if |first| < |second| the stored value is the
// difference modulo 2^(result.count() * limb_bits), i.e. two's complement
// wrap-around. `result` may alias `first` or `second`: each limb is read
// before the corresponding result limb is written.
void eval_sub_abs(NumberBase& result, const NumberBase& first,
                  const NumberBase& second) {
  using limb_t = NumberBase::limb_t;
  using wide_t = NumberBase::double_limb_t;

  const std::size_t f_count = first.count();
  const std::size_t s_count = second.count();
  const std::size_t data_count = result.count();

  wide_t borrow = 0;
  for (std::size_t i = 0; i < data_count; ++i) {
    const wide_t minuend = i < f_count ? static_cast<wide_t>(first(i)) : 0;
    const wide_t subtrahend = i < s_count ? static_cast<wide_t>(second(i)) : 0;
    // Computed in the double-width unsigned type: when the true difference is
    // negative the value wraps and the bit just above the limb becomes 1,
    // which is exactly the borrow for the next limb.
    const wide_t diff = minuend - subtrahend - borrow;
    result(i) = static_cast<limb_t>(diff);
    borrow = (diff >> NumberBase::limb_bits) & 1u;
  }
}
// result = first + second, honouring the sign flags of both operands.
// The magnitude is truncated to result.count() limbs. `result` may alias
// either operand; the operand signs are read before anything is written.
void eval_add(NumberBase& result, const NumberBase& first,
              const NumberBase& second) {
  const bool f_neg = first.is_negative();
  const bool s_neg = second.is_negative();

  if (f_neg == s_neg) {
    // Same sign: add magnitudes and keep the common sign.
    eval_add_abs(result, first, second);
    result.set_sign(f_neg);
    return;
  }

  // Opposite signs: subtract the smaller magnitude from the larger one and
  // take the sign of the operand with the larger magnitude.
  const int order = cmp_abs(first, second);
  if (order == 0) {
    result.set_zero();
    // set_zero() leaves the sign flag alone; a zero result must not stay
    // negative just because `result` held a negative value before.
    result.set_sign(false);
  } else if (order > 0) {
    eval_sub_abs(result, first, second);
    result.set_sign(f_neg);
  } else {
    eval_sub_abs(result, second, first);
    result.set_sign(s_neg);
  }
}
// result = first - second, honouring the sign flags of both operands.
// The magnitude is truncated to result.count() limbs. `result` may alias
// either operand; the operand signs are read before anything is written.
void eval_sub(NumberBase& result, const NumberBase& first,
              const NumberBase& second) {
  const bool f_neg = first.is_negative();
  const bool s_neg = second.is_negative();

  if (f_neg != s_neg) {
    // Opposite signs: the magnitudes add up and the sign of `first` wins.
    eval_add_abs(result, first, second);
    result.set_sign(f_neg);
    return;
  }

  // Same sign: the magnitude is the absolute difference. The result keeps the
  // common sign when |first| is larger and flips it when |second| is larger.
  const int order = cmp_abs(first, second);
  if (order == 0) {
    result.set_zero();
    // set_zero() leaves the sign flag alone; make the zero result positive.
    result.set_sign(false);
  } else if (order > 0) {
    eval_sub_abs(result, first, second);
    result.set_sign(f_neg);
  } else {
    eval_sub_abs(result, second, first);
    result.set_sign(!s_neg);
  }
}
// Schoolbook multiplication: r = |a| * |b|, truncated to r.count() limbs.
//
// Sign flags are ignored and the sign of `r` is left unchanged. `r` is
// cleared first, so it must not share storage with `a` or `b`.
void base_multiply(NumberBase& r, const NumberBase& a, const NumberBase& b) {
  using limb_t = NumberBase::limb_t;
  using wide_t = NumberBase::double_limb_t;

  r.set_zero();
  const std::size_t a_count = a.count();
  const std::size_t b_count = b.count();
  const std::size_t data_count = r.count();

  // Rows at or beyond data_count cannot contribute to the truncated result.
  for (std::size_t i = 0; i < a_count && i < data_count; ++i) {
    // a(i) * B^i times b(j) * B^j lands in limb i + j.
    const wide_t a_limb = static_cast<wide_t>(a(i));
    wide_t carry = 0;
    for (std::size_t j = 0; j < b_count && (i + j) < data_count; ++j) {
      // limb * limb + limb + limb never overflows the double-width type.
      carry += a_limb * static_cast<wide_t>(b(j));
      carry += r(i + j);
      r(i + j) = static_cast<limb_t>(carry);
      carry >>= NumberBase::limb_bits;
    }
    // Carry out of the row goes into the next limb if there is one.
    if (carry && (i + b_count) < data_count) {
      carry += r(i + b_count);
      r(i + b_count) = static_cast<limb_t>(carry);
    }
  }
}
// Karatsuba multiplication (subtractive form): r = |a| * |b|.
//
// Sign flags of `a` and `b` are ignored and the sign of `r` is not set.
// `r` must not share storage with `a` or `b`. `a` and `b` are taken by
// non-const reference only because sub-views are built over their limbs; they
// are not modified.
//
// `storage` is scratch space. Each recursion level uses 4 * n_size limbs of
// it (n_size is about half the longer operand) and hands the remainder to the
// nested call; the original note assumes 4 * max(a.count(), b.count()) limbs
// in total.
//
// Falls back to base_multiply when either operand is below KARATSUBA_CUTOFF
// limbs, or when `r` is too short to hold the full a.count() + b.count()
// limb product (base_multiply truncates safely in that case).
void karatsuba_multiply(NumberBase& r, NumberBase& a, NumberBase& b,
                        NumberBase::limb_t* storage) {
  using limb_t = NumberBase::limb_t;

  const std::size_t a_count = a.count();
  const std::size_t b_count = b.count();
  const std::size_t r_count = r.count();

  if (a_count < KARATSUBA_CUTOFF(std::size_t) ||
      b_count < KARATSUBA_CUTOFF(std::size_t) ||
      r_count < a_count + b_count) {
    base_multiply(r, a, b);
    return;
  }

  // Split point: half of the longer operand, rounded up.
  const std::size_t longest = a_count > b_count ? a_count : b_count;
  const std::size_t n_size = (longest + 1) / 2;
  const std::size_t h_a_size = a_count > n_size ? a_count - n_size : 0;
  const std::size_t h_b_size = b_count > n_size ? b_count - n_size : 0;
  const std::size_t l_a_size = a_count > n_size ? n_size : a_count;
  const std::size_t l_b_size = b_count > n_size ? n_size : b_count;

  // Only the low a_count + b_count limbs of r are written below; clear any
  // limbs above them so the result matches what base_multiply would leave.
  for (std::size_t i = a_count + b_count; i < r_count; ++i) {
    r(i) = static_cast<limb_t>(0);
  }

  limb_t* const r_mem = &r(0);
  limb_t* const a_mem = &a(0);
  limb_t* const b_mem = &b(0);

  // Views over r: high product, middle window (not cleared, it overlaps the
  // other two), and low product.
  NumberBase h_r(a_count + b_count - 2 * n_size, r_mem + 2 * n_size);
  NumberBase m_r(a_count + b_count - n_size, r_mem + n_size, false);
  NumberBase l_r(2 * n_size, r_mem);

  // a1 = a / B^n_size, a0 = a % B^n_size (B = 2^limb_bits), same for b.
  NumberBase h_a(h_a_size, a_mem + l_a_size, false);
  NumberBase l_a(l_a_size, a_mem, false);
  NumberBase h_b(h_b_size, b_mem + l_b_size, false);
  NumberBase l_b(l_b_size, b_mem, false);

  // r = a1 * b1 * B^(2 * n_size) + a0 * b0
  karatsuba_multiply(h_r, h_a, h_b, storage);
  karatsuba_multiply(l_r, l_a, l_b, storage);

  // cmp_a = sign(a0 - a1), cmp_b = sign(b0 - b1)
  const int cmp_a = cmp_abs(l_a, h_a);
  const int cmp_b = cmp_abs(l_b, h_b);

  // Scratch layout: [mid: 2n][mid_a: n][mid_b: n][nested scratch ...]
  NumberBase mid(2 * n_size, storage);
  NumberBase mid_a(n_size, storage + 2 * n_size);  // |a1 - a0|
  NumberBase mid_b(n_size, storage + 3 * n_size);  // |b1 - b0|
  if (cmp_a > 0) {
    eval_sub_abs(mid_a, l_a, h_a);
  } else if (cmp_a < 0) {
    eval_sub_abs(mid_a, h_a, l_a);
  }
  if (cmp_b > 0) {
    eval_sub_abs(mid_b, l_b, h_b);
  } else if (cmp_b < 0) {
    eval_sub_abs(mid_b, h_b, l_b);
  }

  // mid = |a1 - a0| * |b1 - b0|
  karatsuba_multiply(mid, mid_a, mid_b, storage + 4 * n_size);

  // The middle coefficient a1*b0 + a0*b1 = a1*b1 + a0*b0 - (a1-a0)*(b1-b0)
  // can need one bit more than 2 * n_size limbs. mid_a is no longer needed,
  // so its first limb is reused as the extra top limb of the accumulator.
  NumberBase mid_sum(2 * n_size + 1, storage, false);
  mid_sum(2 * n_size) = static_cast<limb_t>(0);
  // (a1-a0)*(b1-b0) is positive exactly when cmp_a and cmp_b have the same
  // non-zero sign, and it enters the sum negated.
  mid_sum.set_sign(cmp_a * cmp_b > 0);
  eval_add(mid_sum, mid_sum, h_r);
  eval_add(mid_sum, mid_sum, l_r);

  // r += (a1*b0 + a0*b1) * B^n_size
  eval_add_abs(m_r, m_r, mid_sum);
}
// result = first * second, including the sign.
//
// Allocates the scratch space needed by karatsuba_multiply and releases it on
// return. `result` must not share storage with either operand.
void eval_multiply(NumberBase& result, NumberBase& first, NumberBase& second) {
  const std::size_t longest =
      first.count() > second.count() ? first.count() : second.count();
  // Owned by the vector so it is freed on every exit path (the previous raw
  // new[] was never deleted).
  std::vector<NumberBase::limb_t> storage(8 * longest);
  karatsuba_multiply(result, first, second, storage.data());
  // The product is negative exactly when the operand signs differ.
  result.set_sign(first.is_negative() != second.is_negative());
}
template <std::size_t precision> | |
class Number : public NumberBase { | |
public: | |
static constexpr std::size_t data_count = | |
(precision + NumberBase::limb_bits - 1) / NumberBase::limb_bits; | |
static constexpr std::size_t actual_precision = | |
data_count * NumberBase::limb_bits; | |
Number() : NumberBase(data_count, this->data){}; | |
Number(const std::string& val) : NumberBase(data_count, this->data, val){}; | |
Number(const Number& num) : NumberBase(this->data, num){}; | |
private: | |
limb_t data[data_count]; | |
}; | |
int main() { | |
using namespace std::chrono; | |
Number<3000> a( | |
"232924386438745725427542745242425253632636363672475245275428682649269625" | |
"936592659265926926593463846389463894689264893642659326432617838960385965" | |
"936740674985683548354692725035784538724589198126849572145783426346128371" | |
"232924386438745725427542745242425253632636363672475245275428682649269625" | |
"936592659265926926593463846389463894689264893642659326432617838960385965" | |
"936740674985683548354692725035784538724589198126849572145783426346128371" | |
"232924386438745725427542745242425253632636363672475245275428682649269625" | |
"936592659265926926593463846389463894689264893642659326432617838960385965" | |
"936740674985683548354692725035784538724589198126849572145783426346128371" | |
"936592659265926926593463846389463894689264893642659326432617838960385965" | |
"936740674985683548354692725035784538724589198126849572145783426346128371" | |
"2329243864387457254275427452424252536326363636724752452754286826492696"); | |
Number<1000> b( | |
"464646458976412132165484653132156487879416163132123156864894791313213216" | |
"548645612316584894512302131231564684897498748451321321231684894654321324" | |
"464646458976412132165484653132156487879416163132123156864894791313213216" | |
"548645612316584894512302131231564684897498748451321321231684894654321324" | |
"0231561648948941651310354189487894658415612165165"); | |
Number<4000> c; | |
Number<4000> d; | |
NumberBase::limb_t st[375]; | |
// std::cout << a << std::endl; | |
// std::cout << b << std::endl; | |
auto start = high_resolution_clock::now(); | |
karatsuba_multiply(c, b, a, st); | |
std::cout << c << std::endl; | |
auto stop = high_resolution_clock::now(); | |
auto duration = duration_cast<microseconds>(stop - start); | |
std::cout << std::dec << "KaratSuba Time: " << duration.count() << std::endl; | |
start = high_resolution_clock::now(); | |
base_multiply(d, b, a); | |
std::cout << d << std::endl; | |
stop = high_resolution_clock::now(); | |
duration = duration_cast<microseconds>(stop - start); | |
std::cout << std::dec << "Base Time: " << duration.count() << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment