Skip to content

Instantly share code, notes, and snippets.

@cblp
Created September 27, 2022 22:40
Show Gist options
  • Save cblp/7ce1cdfaed13d180ae85a5911ee259a4 to your computer and use it in GitHub Desktop.
Save cblp/7ce1cdfaed13d180ae85a5911ee259a4 to your computer and use it in GitHub Desktop.
#include <cmath>
#include <complex>
#include <functional>
#include <iostream>
#include <memory>
#include <numbers>
#include <string>
#include <utility>

using namespace std;
using namespace std::numbers;
/// Abstract node of a symbolic expression over the single variable `x`.
/// Concrete node types override both pure virtuals below.
struct Expr {
    /// Polymorphic base: the virtual destructor makes destroying a derived node
    /// through an Expr pointer well-defined (e.g. `unique_ptr<Expr>` or
    /// `delete`). Previously only make_shared's type-remembering deleter made
    /// destruction safe.
    virtual ~Expr() = default;

    /// Render the expression as human-readable text, e.g. "sin(x) ^ 4".
    virtual std::string source() const = 0;

    /// Build and return a new expression for the derivative d/dx of this one.
    virtual std::shared_ptr<Expr> derive() const = 0;
};
struct Lit: Expr {
int n;
Lit(int n) : n(n) {}
string source() const { return to_string(n); }
shared_ptr<Expr> derive() const { return make_shared<Lit>(0); }
};
struct Var: Expr {
string source() const { return "x"; }
shared_ptr<Expr> derive() const { return make_shared<Lit>(1); }
};
struct Mul: Expr {
shared_ptr<Expr> a, b;
Mul(shared_ptr<Expr> a, shared_ptr<Expr> b) : a(a), b(b) {}
string source() const { return a->source() + " * " + b->source(); }
shared_ptr<Expr> derive() const { return nullptr; }
};
struct Pow: Expr {
shared_ptr<Expr> a;
int n;
Pow(shared_ptr<Expr> a, int n) : a(a), n(n) {}
string source() const { return a->source() + " ^ " + to_string(n); }
shared_ptr<Expr> derive() const {
return make_shared<Mul>(
make_shared<Lit>(n),
make_shared<Mul>(make_shared<Pow>(a, n - 1), a->derive())
);
}
};
struct Cos;
struct Sin: Expr {
shared_ptr<Expr> a;
Sin(shared_ptr<Expr> a) : a(a) {}
string source() const { return "sin(" + a->source() + ")"; }
shared_ptr<Expr> derive() const {
return make_shared<Mul>(make_shared<Cos>(a), a->derive());
}
};
struct Cos: Expr {
shared_ptr<Expr> a;
Cos(shared_ptr<Expr> a) : a(a) {}
string source() const { return "cos(" + a->source() + ")"; }
shared_ptr<Expr> derive() const { return nullptr; }
};
// Symbolic counterpart of the numeric sin: wraps `a` in a Sin node, so generic
// code calling sin(x) also works when x is a shared_ptr<Expr>.
shared_ptr<Expr> sin(shared_ptr<Expr> a) {
    shared_ptr<Expr> node = make_shared<Sin>(a);
    return node;
}
// Symbolic counterpart of the numeric pow: wraps `a` raised to the integer
// exponent `n` in a Pow node.
shared_ptr<Expr> pow(shared_ptr<Expr> a, int n) {
    shared_ptr<Expr> node = make_shared<Pow>(a, n);
    return node;
}
string source(function<shared_ptr<Expr>(shared_ptr<Expr>)> f) {
return f(make_shared<Var>())->source();
}
shared_ptr<Expr> derive(function<shared_ptr<Expr>(shared_ptr<Expr>)> f) {
return f(make_shared<Var>())->derive();
}
// f(x) = sin(x)^4. Written once: T = double evaluates it numerically, and
// T = shared_ptr<Expr> builds it symbolically through the sin/pow overloads.
template <class T>
T f(T x) {
    auto sine = sin(x);
    return pow(sine, 4);
}
int main() {
cout << f(pi / 2) << endl;
// 1
cout << source(f<shared_ptr<Expr>>) << endl;
// sin(x) ^ 4
cout << derive(f<shared_ptr<Expr>>)->source() << endl;
// 4 * sin(x) ^ 3 * cos(x) * 1
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment