Skip to content

Instantly share code, notes, and snippets.

@es3n1n
Created August 20, 2024 20:48
Show Gist options
  • Save es3n1n/5f9ca17c030305f56679afba0543f7cb to your computer and use it in GitHub Desktop.
Save es3n1n/5f9ca17c030305f56679afba0543f7cb to your computer and use it in GitHub Desktop.
Unfinished `uniform_int_distribution` drop-in replacement that uses Lemire's method (https://arxiv.org/pdf/1805.10941)
/// \brief A drop-in replacement for `std::uniform_int_distribution`.
///
/// We use our own implementation because `std::uniform_int_distribution` is implemented differently by libc++,
/// libstdc++ and MSVC's STL, so the same engine state yields different values on different platforms. This class
/// relies only on the engine's raw output and portable integer arithmetic (no compiler intrinsics, no 128-bit
/// integer type), so its results are reproducible across platforms.
///
/// Sampling uses Daniel Lemire's nearly divisionless method ("Fast Random Integer Generation in an Interval",
/// ACM Trans. Model. Comput. Simul. 29 (1), 2019, https://arxiv.org/pdf/1805.10941).
///
/// \tparam IntType The integer type to be used for the distribution. Defaults to `int`.
template <typename IntType = int>
class UniformIntDistribution {
    static_assert(std::is_integral_v<IntType> && !std::is_same_v<IntType, bool>,
                  "UniformIntDistribution requires a non-bool integral type");

public:
    using ResultTy = IntType;
    /// Standard-library spelling of `ResultTy`, so generic code written against
    /// `std::uniform_int_distribution` keeps compiling.
    using result_type = IntType;

    /// \brief Creates a distribution over the closed range [min, max].
    /// \param min Lower bound (inclusive); defaults to 0, like `std::uniform_int_distribution`.
    /// \param max Upper bound (inclusive); defaults to the largest `ResultTy`.
    /// \note `min <= max` is a precondition (as for the standard distribution) and is not checked.
    explicit UniformIntDistribution(ResultTy min = 0, ResultTy max = (std::numeric_limits<ResultTy>::max)())
        : min_(min), max_(max) { }

    /// \brief Generates a random integer within the distribution range.
    /// \param engine The random engine to draw from. It must provide static `min()`/`max()` and `operator()`.
    ///               It may be called more than once per result (rejection sampling).
    /// \return A value in [a(), b()].
    template <typename Engine>
    ResultTy operator()(Engine& engine) {
        return eval(engine, min_, max_);
    }

    /// \brief Returns the inclusive lower bound.
    ResultTy a() const { return min_; }
    /// \brief Returns the inclusive upper bound.
    ResultTy b() const { return max_; }
    /// \brief Returns the smallest value operator() can produce (same as a()).
    ResultTy (min)() const { return min_; }
    /// \brief Returns the largest value operator() can produce (same as b()).
    ResultTy (max)() const { return max_; }
    /// \brief No-op: the distribution carries no state between calls.
    void reset() { }

private:
    using UResultTy = std::make_unsigned_t<ResultTy>;

    /// \brief Wraps a uniform random bit generator as a generator of uniform integers in [0, n).
    ///
    /// \tparam DiffTy Unsigned type of the exclusive upper bound passed to operator().
    /// \tparam URng   Engine type; must provide static `min()`/`max()` and a non-const `operator()`.
    template <class DiffTy, class URng>
    class RngFromURng {
    public:
        using ConvUDiffTy = std::make_unsigned_t<DiffTy>;
        using RngResultTy = std::invoke_result_t<URng&>;
        /// Working type: the wider of the engine's result type and the bound type.
        using UDiffTy = std::conditional_t<(sizeof(RngResultTy) < sizeof(ConvUDiffTy)), ConvUDiffTy, RngResultTy>;
        static constexpr unsigned int kUDiffBits = sizeof(UDiffTy) * CHAR_BIT;
        static_assert(kUDiffBits <= 64, "RngFromURng supports working types of at most 64 bits");

        explicit RngFromURng(URng& engine) : engine_(engine) { }
        RngFromURng(const RngFromURng&) = delete;
        RngFromURng& operator=(const RngFromURng&) = delete;

        /// \brief Returns a uniformly distributed value in [0, index).
        /// \param index Exclusive upper bound; must be non-zero.
        DiffTy operator()(DiffTy index) {
            // Algorithm 5 of the paper <-> this code:
            //   s <-> bound, L <-> generated_bits, 2^L - 1 <-> mask,
            //   m <-> product, l <-> remainder, t <-> threshold
            const UDiffTy bound = static_cast<UDiffTy>(index);

            // Grow the mask one engine draw at a time until 2^L >= s. Shifting past the top of UDiffTy simply
            // saturates the mask to all ones.
            UDiffTy mask = kBitsMask;
            unsigned int draws = 1;
            if constexpr (kBits < kUDiffBits) {
                while (mask < static_cast<UDiffTy>(bound - 1)) {
                    mask = static_cast<UDiffTy>((mask << kBits) | kBitsMask);
                    ++draws;
                }
            }

            // x <- random integer in [0, 2^L); m <- x * s; l <- m mod 2^L
            WideProduct product = random_product(bound, draws);
            UDiffTy remainder = static_cast<UDiffTy>(product.low & mask);
            if (remainder < bound) {
                // t <- (2^L - s) mod s; candidates with l < t would bias the result, so redraw them.
                const UDiffTy threshold = static_cast<UDiffTy>(static_cast<UDiffTy>(mask - bound + 1) % bound);
                while (remainder < threshold) {
                    product = random_product(bound, draws);
                    remainder = static_cast<UDiffTy>(product.low & mask);
                }
            }

            // L = popcount(mask). Since mask == 2^L - 1 (saturating), that is min(draws * kBits, kUDiffBits).
            unsigned int generated_bits = kUDiffBits;
            if constexpr (kBits < kUDiffBits) {
                const unsigned int requested_bits = draws * kBits;
                generated_bits = requested_bits < kUDiffBits ? requested_bits : kUDiffBits;
            }

            // m / 2^L
            return static_cast<DiffTy>(shift_right(product, generated_bits));
        }

        /// \brief Returns a value in which every bit of UDiffTy is uniformly random.
        UDiffTy all_bits() {
            UDiffTy bits = next_bits();
            if constexpr (kBits < kUDiffBits) {
                // Concatenate draws until the whole word is filled; excess high bits fall off the top.
                for (unsigned int filled = kBits; filled < kUDiffBits; filled += kBits) {
                    bits = static_cast<UDiffTy>((bits << kBits) | next_bits());
                }
            }
            return bits;
        }

    private:
        /// Exact double-width product, split into its low and high UDiffTy halves.
        struct WideProduct {
            UDiffTy low;
            UDiffTy high;
        };

        /// \brief Returns one engine draw shifted to start at zero; draws above kBitsMask are rejected.
        UDiffTy next_bits() {
            constexpr auto engine_min = (URng::min)();
            for (;;) {
                const UDiffTy value = static_cast<UDiffTy>(engine_() - engine_min);
                if (value <= kBitsMask) {
                    return value;
                }
            }
        }

        /// \brief Draws x uniformly from [0, 2^L) with `draws` next_bits() calls and returns the product x * bound.
        WideProduct random_product(UDiffTy bound, [[maybe_unused]] unsigned int draws) {
            UDiffTy x = next_bits();
            if constexpr (kBits < kUDiffBits) {
                while (--draws > 0) {
                    x = static_cast<UDiffTy>((x << kBits) | next_bits());
                }
            }
            return wide_multiply(x, bound);
        }

        /// \brief Computes the exact double-width product a * b using only portable fixed-width arithmetic.
        static WideProduct wide_multiply(UDiffTy a, UDiffTy b) {
            if constexpr (kUDiffBits <= 32) {
                const uint64_t full = static_cast<uint64_t>(a) * static_cast<uint64_t>(b);
                return WideProduct{static_cast<UDiffTy>(full), static_cast<UDiffTy>(full >> kUDiffBits)};
            } else {
                // Schoolbook multiplication on 32-bit limbs; every partial product fits in 64 bits.
                constexpr UDiffTy kLow32 = 0xFFFFFFFFu;
                const UDiffTy a_lo = a & kLow32;
                const UDiffTy a_hi = a >> 32;
                const UDiffTy b_lo = b & kLow32;
                const UDiffTy b_hi = b >> 32;
                const UDiffTy lo_lo = a_lo * b_lo;
                const UDiffTy lo_hi = a_lo * b_hi;
                const UDiffTy hi_lo = a_hi * b_lo;
                const UDiffTy hi_hi = a_hi * b_hi;
                // Middle column plus the carry out of the lowest limb: below 3 * 2^32, so it cannot overflow.
                const UDiffTy middle = (lo_lo >> 32) + (lo_hi & kLow32) + (hi_lo & kLow32);
                const UDiffTy low = (middle << 32) | (lo_lo & kLow32);
                const UDiffTy high = hi_hi + (lo_hi >> 32) + (hi_lo >> 32) + (middle >> 32);
                return WideProduct{low, high};
            }
        }

        /// \brief Returns (high:low) >> shift; requires 1 <= shift <= kUDiffBits.
        static UDiffTy shift_right(const WideProduct& product, unsigned int shift) {
            if (shift >= kUDiffBits) {
                return product.high;
            }
            return static_cast<UDiffTy>((product.low >> shift) | (product.high << (kUDiffBits - shift)));
        }

        /// \brief Number of uniformly random low bits that one next_bits() call provides.
        static constexpr unsigned int calc_bits() {
            unsigned int bits = kUDiffBits;
            UDiffTy mask = static_cast<UDiffTy>(-1);
            // Largest 2^bits - 1 that does not exceed the engine's output span (max - min).
            while ((URng::max)() - (URng::min)() < mask) {
                mask = static_cast<UDiffTy>(mask >> 1);
                --bits;
            }
            return bits;
        }

        URng& engine_;  // non-owning; the engine must outlive this adaptor
        static constexpr unsigned int kBits = calc_bits();  // random bits per next_bits() call
        static constexpr UDiffTy kBitsMask =
            static_cast<UDiffTy>(static_cast<UDiffTy>(-1) >> (kUDiffBits - kBits));  // 2^kBits - 1
    };

    /// \brief Evaluates the distribution and generates a random integer.
    /// \param engine The random engine instance to use.
    /// \param lo The minimum value of the distribution (inclusive).
    /// \param hi The maximum value of the distribution (inclusive).
    /// \return A random integer within [lo, hi].
    template <typename Engine>
    ResultTy eval(Engine& engine, ResultTy lo, ResultTy hi) const {
        // Work in the unsigned domain, where [lo, hi] becomes the contiguous range [u_lo, u_lo + span].
        const UResultTy u_lo = adjust(static_cast<UResultTy>(lo));
        const UResultTy u_hi = adjust(static_cast<UResultTy>(hi));
        const UResultTy span = static_cast<UResultTy>(u_hi - u_lo);

        RngFromURng<UResultTy, Engine> generator(engine);
        // A range covering every ResultTy value has span + 1 == 0, which is not a valid bound. In that case every
        // bit of the offset must be random, so it is assembled from as many engine draws as needed (a single
        // engine() call would leave bits unset for engines narrower than ResultTy and ignore Engine::min()).
        const UResultTy offset = (span == static_cast<UResultTy>(-1))
                                     ? static_cast<UResultTy>(generator.all_bits())
                                     : static_cast<UResultTy>(generator(static_cast<UResultTy>(span + 1)));
        return static_cast<ResultTy>(adjust(static_cast<UResultTy>(offset + u_lo)));
    }

    /// \brief Converts between signed and unsigned orderings; the identity for unsigned `ResultTy`.
    ///
    /// For signed `ResultTy` this flips the sign bit, mapping the most negative value to 0 and the most positive
    /// value to all ones, so ranges can be handled with unsigned arithmetic. It is its own inverse.
    static UResultTy adjust(UResultTy val) {
        if constexpr (std::is_signed_v<ResultTy>) {
            constexpr UResultTy kSignBit = static_cast<UResultTy>((static_cast<UResultTy>(-1) >> 1) + 1);  // 2^(N-1)
            return static_cast<UResultTy>(val ^ kSignBit);
        } else {
            return val;
        }
    }

    ResultTy min_;
    ResultTy max_;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment