Skip to content

Instantly share code, notes, and snippets.

@bredelings
Last active August 20, 2024 15:37
Show Gist options
  • Save bredelings/9a107128a93342cabd37d9dcfee8861c to your computer and use it in GitHub Desktop.
Save bredelings/9a107128a93342cabd37d9dcfee8861c to your computer and use it in GitHub Desktop.
A log-density class that counts the number of zeros.
// Header: LogDensity.h
#pragma once

#include <cassert>
#include <cmath>
#include <iostream>
#include <limits>
#include <stdexcept>
/// A log-density: the natural log of a non-negative density, with the number
/// of zero factors counted separately from the finite part.
///
/// The represented log value is  ones_ + zeros_ * (-Inf),  i.e. the density is
/// 0^zeros_ * exp(ones_).  Keeping zeros_ separate means that products of zero
/// densities still carry information: they can be ordered (more zeros is
/// smaller) and a zero can be divided back out again (operator-=).
///
/// Invariants (see check()): ones_ is never -Inf (a log value of -Inf is stored
/// as one zero instead) and zeros_ is finite.  zeros_ may become fractional or
/// negative through powers, roots, and division.
class LogDensity
{
    // Number of zero factors. May be fractional or negative; must be finite.
    double zeros_ = 0;
    // Log of the non-zero part. Cannot be -Inf. Can be +Inf.
    double ones_ = 0;

public:
    /// Number of zero factors.
    double zeros() const {return zeros_;}
    /// Mutable access to the zero count (callers must keep check() satisfied).
    double& zeros() {return zeros_;}
    /// Log of the non-zero part.
    double ones() const {return ones_;}
    /// Mutable access to the non-zero part (callers must keep check() satisfied).
    double& ones() {return ones_;}

    /// Assert the class invariants. Compiles to nothing when NDEBUG is defined.
    void check() const {
        assert(not std::isinf(ones_) or ones_ > 0); // ones_ != -Inf
        assert(not std::isinf(zeros_));             // only a finite number of zeros.
    }

    /// Raise the density to the finite power y (scale the log-density by y).
    LogDensity& operator *=(double y)
    {
        // No infinite powers
        assert(not std::isinf(y));
        if (y == 0)
        {
            // x^0 == 1 for every x -- including 0^0 and (+Inf)^0 -- matching std::pow.
            // Scaling by 0 would instead give NaN when ones_ is +Inf or NaN.
            zeros_ = 0;
            ones_ = 0;
        }
        else
        {
            zeros_ *= y; // fractional zeros
            ones_ *= y;
        }
        return *this;
    }

    /// Take the y-th root of the density (divide the log-density by y).
    LogDensity& operator /=(double y)
    {
        // No zeroth roots
        assert(y != 0);
        zeros_ /= y; // fractional zeros
        ones_ /= y;
        return *this;
    }

    /// Multiply by the density y (add log-densities, including zero counts).
    /// Returns *this by reference so compound assignments chain like built-ins;
    /// returning by value made `(a += b) += c` modify a temporary.
    LogDensity& operator +=(const LogDensity& y)
    {
        zeros_ += y.zeros_;
        ones_ += y.ones_;
        return *this;
    }

    /// Divide by the density y (subtract log-densities, including zero counts).
    LogDensity& operator -=(const LogDensity& y)
    {
        zeros_ -= y.zeros_;
        ones_ -= y.ones_;
        return *this;
    }

    // Ordering by density: more zeros means smaller; equal zero counts are
    // ordered by the non-zero part. Every comparison involving a NaN value
    // (see is_nan()) is false, except != which is true.

    bool operator<(const LogDensity& y) const
    {
        if (is_nan() or y.is_nan()) return false;
        if (zeros_ == y.zeros_)
            return ones_ < y.ones_;
        return zeros_ > y.zeros_;
    }

    bool operator>(const LogDensity& y) const
    {
        if (is_nan() or y.is_nan()) return false;
        if (zeros_ == y.zeros_)
            return ones_ > y.ones_;
        return zeros_ < y.zeros_;
    }

    bool operator==(const LogDensity& y) const
    {
        if (is_nan() or y.is_nan()) return false;
        return zeros_ == y.zeros_ and ones_ == y.ones_;
    }

    bool operator!=(const LogDensity& y) const
    {
        // operator== is already false for NaN operands, so this is true for them.
        return not operator==(y);
    }

    bool operator<=(const LogDensity& y) const
    {
        return operator<(y) or operator==(y);
    }

    bool operator>=(const LogDensity& y) const
    {
        return operator>(y) or operator==(y);
    }

    /// Convert to a plain log value: NaN when is_nan(), ones_ when there are no
    /// zeros, -Inf for a positive number of zeros, +Inf for a negative number.
    explicit operator double() const
    {
        check();
        // Check NaN first so the result always agrees with is_nan(); otherwise a
        // NaN ones_ (or zeros_) with zeros_ <= 0 would be reported as +Inf.
        if (is_nan())
            return std::nan("1"); // NAN
        if (zeros_ == 0)
            return ones_;
        else if (zeros_ > 0)
            return -std::numeric_limits<double>::infinity();
        else
            return std::numeric_limits<double>::infinity();
    }

    /// The density itself (exp of the log-density), with the same NaN rules
    /// as operator double().
    double exp() const
    {
        if (is_nan())
            return std::nan("1"); // NAN
        if (zeros_ == 0)
            return std::exp(ones_); // std:: -- <cmath> need not declare ::exp
        else if (zeros_ > 0)
            return 0;
        else
            return std::numeric_limits<double>::infinity();
    }

    /// True when the value is undefined: a NaN component, or 0 * (+Inf).
    bool is_nan() const
    {
        return std::isnan(ones_) or std::isnan(zeros_) or (zeros_ > 0 and std::isinf(ones_));
    }

    /// Log-density 0, i.e. density 1.
    LogDensity() = default;

    /// From a plain log value y; y == -Inf is stored as one zero.
    explicit LogDensity(double y)
    {
        if (std::isinf(y) and y<0)
            zeros_ = 1;
        else
            ones_ = y;
    }

    /// From an explicit zero count z and non-zero log part y.
    explicit LogDensity(double z, double y)
        :zeros_(z), ones_(y)
    { }
};
/// Multiply two densities (add their log-densities, including zero counts).
/// `inline` because this function is defined in a header (LogDensity.h): a
/// non-inline definition breaks linking when several files include it.
inline LogDensity operator+(LogDensity x, const LogDensity& y)
{
    x += y;
    return x;
}
/// Divide density x by density y (subtract log-densities, including zero counts).
/// `inline` because this function is defined in a header (LogDensity.h): a
/// non-inline definition breaks linking when several files include it.
inline LogDensity operator-(LogDensity x, const LogDensity& y)
{
    x -= y;
    return x;
}
/// Raise density x to the finite power p (written p * x in log space).
/// `inline` because this function is defined in a header (LogDensity.h): a
/// non-inline definition breaks linking when several files include it.
inline LogDensity operator*(double p, LogDensity x)
{
    x *= p;
    return x;
}
/// Raise density x to the finite power p (written x * p in log space).
/// `inline` because this function is defined in a header (LogDensity.h): a
/// non-inline definition breaks linking when several files include it.
inline LogDensity operator*(LogDensity x, double p)
{
    x *= p;
    return x;
}
/// Take the p-th root of density x (divide the log-density by p; p != 0).
/// `inline` because this function is defined in a header (LogDensity.h): a
/// non-inline definition breaks linking when several files include it.
inline LogDensity operator/(LogDensity x, double p)
{
    x /= p;
    return x;
}
/// Print as "<ones> + <zeros>*-Inf", e.g. "-0.693147 + 1*-Inf".
/// `inline` because this function is defined in a header (LogDensity.h): a
/// non-inline definition breaks linking when several files include it.
inline std::ostream& operator<<(std::ostream& o, const LogDensity& x)
{
    o<<x.ones()<<" + "<<x.zeros()<<"*-Inf";
    return o;
}
double exp(LogDensity y)
{
return y.exp();
}
/// Convert a plain (non-log) density y >= 0 to a LogDensity.
/// y == 0 becomes one zero (log(0) == -Inf); y == +Inf gives ones() == +Inf.
/// Throws std::runtime_error if y is negative.
/// NOTE(review): a NaN y is not rejected (NaN < 0 is false) and yields a NaN
/// log-density -- confirm that is intended.
/// `inline` because this function is defined in a header (LogDensity.h).
inline LogDensity toLogDensity(double y)
{
    if (y < 0)
        throw std::runtime_error("Negative density");
    return LogDensity(std::log(y));
}
#include <iostream>
#include "LogDensity.h"
int main()
{
LogDensity x1 = toLogDensity(0.5);
LogDensity x2 = toLogDensity(0);
LogDensity x3 = toLogDensity(std::numeric_limits<double>::infinity());
LogDensity y1 = x1;
LogDensity y2 = y1 + x2;
LogDensity y3 = x1 + x2 + x2;
LogDensity y4 = y2 + x2;
LogDensity y5 = x3;
LogDensity y6 = x3 + x2;
LogDensity y7 = x3 - x2;
std::cerr<<std::boolalpha;
std::cerr<<"y1 = "<<y1<<" double(y1) = "<<double(y1)<<" exp(y1) = "<<exp(y1)<<"\n";
std::cerr<<"y2 = "<<y2<<" double(y2) = "<<double(y2)<<" exp(y2) = "<<exp(y2)<<"\n";
std::cerr<<"y3 = "<<y3<<" double(y3) = "<<double(y3)<<" exp(y3) = "<<exp(y3)<<"\n";
std::cerr<<"y4 = "<<y4<<" double(y4) = "<<double(y4)<<" exp(y4) = "<<exp(y4)<<"\n";
std::cerr<<"\n";
std::cerr<<"y5 = "<<y5<<" double(y5) = "<<double(y5)<<" exp(y5) = "<<exp(y5)<<"\n";
std::cerr<<"y6 = "<<y6<<" double(y6) = "<<double(y6)<<" exp(y6) = "<<exp(y6)<<"\n";
std::cerr<<"y7 = "<<y7<<" double(y7) = "<<double(y7)<<" exp(y7) = "<<exp(y7)<<"\n";
std::cerr<<"\n";
std::cerr<<"y1 > y2 = "<<(y1 > y2)<<"\n";
std::cerr<<"y2 > y3 = "<<(y2 > y3)<<"\n";
std::cerr<<"y3 > y4 = "<<(y3 > y4)<<"\n";
std::cerr<<"y2 > y4 = "<<(y2 > y4)<<"\n";
std::cerr<<"\n";
std::cerr<<"y1 > y6 = "<<(y1 > y6)<<"\n";
std::cerr<<"y6 == y6 = "<<(y6 == y6)<<"\n";
std::cerr<<"y6 != y6 = "<<(y6 != y6)<<"\n";
std::cerr<<"\n";
std::cerr<<"y1-y2 = "<<y1-y2<<" double(y1-y2) = "<<double(y1-y2)<<" exp(y1-y2) = "<<exp(y1-y2)<<"\n";
std::cerr<<"y2-y2 = "<<y2-y2<<" double(y2-y2) = "<<double(y2-y2)<<" exp(y2-y2) = "<<exp(y2-y2)<<"\n";
std::cerr<<"y2-y1 = "<<y2-y1<<" double(y2-y1) = "<<double(y2-y1)<<" exp(y2-y1) = "<<exp(y2-y1)<<"\n";
return 0;
}
y1 = -0.693147 + 0*-Inf double(y1) = -0.693147 exp(y1) = 0.5
y2 = -0.693147 + 1*-Inf double(y2) = -inf exp(y2) = 0
y3 = -0.693147 + 2*-Inf double(y3) = -inf exp(y3) = 0
y4 = -0.693147 + 2*-Inf double(y4) = -inf exp(y4) = 0
y5 = inf + 0*-Inf double(y5) = inf exp(y5) = inf
y6 = inf + 1*-Inf double(y6) = nan exp(y6) = nan
y7 = inf + -1*-Inf double(y7) = inf exp(y7) = inf
y1 > y2 = true
y2 > y3 = true
y3 > y4 = false
y2 > y4 = true
y1 > y6 = false
y6 == y6 = false
y6 != y6 = true
y1-y2 = 0 + -1*-Inf double(y1-y2) = inf exp(y1-y2) = inf
y2-y2 = 0 + 0*-Inf double(y2-y2) = 0 exp(y2-y2) = 1
y2-y1 = 0 + 1*-Inf double(y2-y1) = -inf exp(y2-y1) = 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment