Skip to content

Instantly share code, notes, and snippets.

@blakeNaccarato
Created July 26, 2024 16:29
Show Gist options
  • Save blakeNaccarato/8cb6a335bc5784eaff487083aa681ff1 to your computer and use it in GitHub Desktop.
Save blakeNaccarato/8cb6a335bc5784eaff487083aa681ff1 to your computer and use it in GitHub Desktop.
Fitting experimental data to model functions using `scipy.optimize.curve_fit`
"""Get fits and errors."""
from functools import partial
from warnings import catch_warnings, simplefilter

from numpy import array, diagonal, full, inf, isinf, linspace, nan, sqrt, where
from scipy.optimize import OptimizeWarning, curve_fit
from scipy.stats import t
# Docs: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html
# Module-level variables denoted by all-caps, which may become function arguments in
# your implementation, if you wrap all this into a proper function for instance
# We want to fit this function; perhaps it even comes from an external library.
def fun(b, c, a, const, x, other_const):
    """Quadratic model plus two additive constants.

    The argument order is deliberately out-of-whack (pretend it comes from an
    external source we can't control); `scipy.optimize.curve_fit` is picky
    about argument order, so this gets re-wrapped elsewhere.
    """
    quadratic = a * x**2
    linear = b * x
    return quadratic + linear + c + const + other_const
# `const` and `other_const` in the external `fun` are not fitted for; they are
# fixed constants supplied by us later via `functools.partial`.
CONST = 0.25
OTHER_CONST = 0.75

# Your experimental data.
EXPERIMENTAL_X = array([0.0, 1.2, 2.5, 3, 4.2, 5.1])
EXPERIMENTAL_Y = array([2.0, 4.5, 10.9, 13.1, 25.6, 31.8])

# Uncertainty in each `y` measurement, computed through your own uncertainty
# analysis; it propagates through to the fit. Optional — absolute when
# `absolute_sigma` is True, but relative sigmas are worth trying as well.
SIGMA_Y = [0.01, 0.003, 0.1, 0.2, 0.01, 0.01]
SIGMAS_ARE_ABSOLUTE = True

# Each `y` value (e.g. 4.5) is itself the mean of this many samples taken at
# that `x` (at experiment time, for instance). Set to `1` if you only have one
# sample per `x` value, which is also common.
NUMBER_OF_SAMPLES_FOR_EACH_Y = 3

# Student-t multiplier for the chosen confidence level (0.95 → 95% CI). Used
# to turn the fit's standard errors into confidence-interval half-widths.
# NOTE(review): the degrees of freedom passed here are the raw sample count;
# the conventional choice for a sample mean is n - 1 — confirm intent.
CONFIDENCE_INTERVAL_THRESH = 0.95
_, CONFIDENCE_INTERVAL_95 = t.interval(
    CONFIDENCE_INTERVAL_THRESH, NUMBER_OF_SAMPLES_FOR_EACH_Y
)
# Redefine `fun` with the argument order `scipy.optimize.curve_fit` expects.
def my_fun(x, a, b, c, const, other_const):
    """Wrap `fun` with arguments reordered for `curve_fit`.

    Order: independent variable first (`x`), then the parameters to fit
    (`a`, `b`, `c`), then the fixed parameters (`const`, `other_const`).
    """
    return fun(b=b, c=c, a=a, const=const, x=x, other_const=other_const)


# Fix the constant parameter(s) with `functools.partial`. They could instead be
# baked directly into `my_fun` above, but `partial` keeps things flexible when
# different runs fit different parameter sets and the constants aren't always
# the same params.
MODEL = partial(my_fun, const=CONST, other_const=OTHER_CONST)
# Perform fit, filling "nan" on failure or when covariance computation fails
with catch_warnings():
    # BUG FIX: promote `OptimizeWarning` to an exception inside this context.
    # Without this filter the warning is merely printed, the
    # `except OptimizeWarning` clause below is dead code, and a failed
    # covariance estimate silently leaves `pcov` full of `inf`.
    simplefilter("error", OptimizeWarning)
    try:
        # Because curve fit takes guesses/bounds just as tuples of values, it's
        # very sensitive to argument order of your model function, and assumes
        # you have complete control over the order of its arguments in the
        # function definition. That's why we have to "wrap" `fun`, because
        # maybe it comes from an external source we don't control.
        fits, pcov = curve_fit(
            f=MODEL,
            p0=[1, 1, 1],  # Optional, guesses for [a, b, c]
            # Expects e.g. ([a_lower, b_lower, c_lower], [a_upper, b_upper, c_upper])
            bounds=([0, -inf, 0], [inf, inf, inf]),  # Optional
            xdata=EXPERIMENTAL_X,  # This should be the same 'x' from your exp data
            ydata=EXPERIMENTAL_Y,  # Experimental `y` data, aka result of `my_fun`
            sigma=SIGMA_Y,  # Optional
            absolute_sigma=SIGMAS_ARE_ABSOLUTE,  # Optional
            method="trf",  # Optional, algo to fit with
        )
    except (RuntimeError, OptimizeWarning):
        # We gotta catch fit errors and just return `nan` if it fails
        dim = 3  # Number of parameters, aka a, b, c is 3 params
        fits = full(dim, nan)  # Fill with "nan" on failure
        pcov = full((dim, dim), nan)  # Fill with "nan" on failure
# Compute confidence-interval half-widths from the covariance diagonal
standard_errors = sqrt(diagonal(pcov))
errors = standard_errors * CONFIDENCE_INTERVAL_95
# Belt and braces: also mask any parameter whose propagated error is infinite
fits = where(isinf(errors), nan, fits)
errors = where(isinf(errors), nan, errors)
# Unpack the `fits`/`errors` arrays into one variable per fit parameter
a_fit, b_fit, c_fit = fits
a_err, b_err, c_err = errors
# Embed the fit parameters into the model, so `fitted_model` varies only in `x`
fitted_model = partial(MODEL, a=a_fit, b=b_fit, c=c_fit)
# We can evaluate `fitted_model` at any `x`
arbitrary_x = linspace(0, 5, 6)
# Assemble the report lines in a list first, then newline-join them. The " = "
# syntax inside the curly braces automatically renders the variable name.
report_lines = [
    "",
    f"fitted_model: {a_fit:.4f} * x**2 + {b_fit:.4f} * x + {c_fit:.4f} + {CONST} + {OTHER_CONST}",
    f"95% CI:\n\t{a_fit = :.4f} ± {a_err:.4f} \n\t{b_fit = :.4f} ± {b_err:.4f} \n\t{c_fit = :.4f} ± {c_err:.4f}",  # noqa: E203
]
print("\n".join(report_lines))  # noqa: T201
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment