Last active
November 26, 2015 15:56
-
-
Save hovren/0fede832741e856fd2d5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import numpy as np | |
import numpy.testing as nt | |
import scipy.misc | |
def multinomial(n, K):
    """Return the multinomial coefficient n! / (k_1! * k_2! * ... * k_m!).

    Parameters
    ----------
    n : int
        Total order (callers in this module pass multi-indices whose
        entries sum to n).
    K : iterable of int
        The individual orders k_1, ..., k_m.

    Returns
    -------
    float
        The coefficient. Both factorial terms are exact Python ints and
        the final true division of two ints is correctly rounded, so no
        precision is lost to intermediate float factorials.

    Raises
    ------
    ValueError
        If n or any entry of K is negative (raised by ``math.factorial``).
    """
    # scipy.misc.factorial no longer exists (removed in SciPy 1.3), and
    # np.product was removed in NumPy 2.0; the stdlib covers both needs.
    import math

    denominator = 1
    for k in K:
        denominator *= math.factorial(k)
    # int / int is true division in Python 3 and correctly rounded.
    return math.factorial(n) / denominator
def powers(n, m):
    """Yield every m-tuple of non-negative ints whose entries sum to n.

    Tuples are produced in ascending lexicographic order. This is the same
    order as filtering ``itertools.product(range(n + 1), repeat=m)``, but
    only valid tuples are generated instead of all (n + 1) ** m candidates.

    Parameters
    ----------
    n : int
        Required sum of the entries.
    m : int
        Length of each tuple.

    Yields
    ------
    tuple of int
        A multi-index (k_1, ..., k_m) with k_1 + ... + k_m == n.
        For m == 0 the only candidate is the empty tuple, which is yielded
        only when n == 0. Nothing is yielded when n is negative.
    """
    if m == 0:
        if n == 0:
            yield ()
        return
    # Fix the first entry in ascending order, then distribute the remainder
    # over the other m - 1 slots; this preserves lexicographic order.
    for first in range(n + 1):
        for rest in powers(n - first, m - 1):
            yield (first,) + rest
def general_leibniz(funcs, derivative, x):
    """General Leibniz rule for differentiating a product of functions.

    Computes the n-th derivative of f_1 * f_2 * ... * f_m as

        sum over k_1 + ... + k_m == n of
            n! / (k_1! ... k_m!) * f_1^(k_1)(x) * ... * f_m^(k_m)(x)

    Parameters
    ----------
    funcs : sequence of callables
        Each has the signature ``f(x, derivative)`` and returns the
        ``derivative``-th derivative of that factor evaluated at ``x``.
    derivative : int
        Order n of the derivative of the product.
    x : scalar or array_like
        Evaluation point(s); passed unchanged to every function in funcs.

    Returns
    -------
    The n-th derivative of the product evaluated at x. The accumulator
    starts at 0, so 0 is returned if no multi-index is generated.
    """
    m = len(funcs)
    n = derivative
    Y = 0
    for K in powers(n, m):
        # One derivative order per factor.
        assert len(K) == m
        mcoeff = multinomial(n, K)
        yk = [f(x, k) for f, k in zip(funcs, K)]
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        # axis=0 multiplies the m factor values elementwise.
        ykprod = np.prod(yk, axis=0)
        Y = Y + mcoeff * ykprod
    return Y
def test_general_product_derivative():
    """Check general_leibniz against the analytical derivatives of
    f(x) = x**2 * 3x * sin(x) for orders 0, 1 and 2."""
    def func_A(x, derivative):
        # x**2 and its first two derivatives
        f = (lambda x: x**2,
             lambda x: 2*x,
             lambda x: 2*np.ones_like(x))[derivative]
        return f(x)

    def func_B(x, derivative):
        # 3x and its first two derivatives
        f = (lambda x: 3*x,
             lambda x: 3*np.ones_like(x),
             lambda x: np.zeros_like(x))[derivative]
        return f(x)

    def func_C(x, derivative):
        # sin(x) and its first two derivatives
        f = (lambda x: np.sin(x),
             lambda x: np.cos(x),
             lambda x: -np.sin(x)
             )[derivative]
        return f(x)

    funcs = [func_A, func_B, func_C]

    # Analytical versions of above product f = f_A * f_B * f_C
    f = lambda x: (x ** 2) * (3 * x) * np.sin(x)
    f_prim = lambda x: 9 * x**2 * np.sin(x) + 3*x**3 * np.cos(x)
    f_bis = lambda x: np.sin(x) * (18*x - 3*x**3) + np.cos(x)*18*x**2

    x = np.linspace(-5, 5)
    y = f(x)
    yp1 = f_prim(x)
    yp2 = f_bis(x)

    # Calculate derivatives with the function defined in this module
    # (the previous `splines.general_product_derivative` was never imported).
    y_hat = general_leibniz(funcs, 0, x)
    yp1_hat = general_leibniz(funcs, 1, x)
    yp2_hat = general_leibniz(funcs, 2, x)

    nt.assert_almost_equal(y_hat, y)
    nt.assert_almost_equal(yp1_hat, yp1)
    nt.assert_almost_equal(yp2_hat, yp2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment