Source code for virocon._fitting

"""
Model fitting (parameter estimation).
"""

import numpy as np

from scipy.optimize import minimize, curve_fit


[docs] def fit_function(func, x, y, p0, method, bounds, weights=None): if bounds is not None: bounds = convert_bounds_for_curve_fit(bounds) if method == "lsq": if bounds is None: popt, _ = curve_fit(func, x, y, p0) else: popt, _ = curve_fit(func, x, y, p0, bounds=bounds) elif method == "wlsq": if bounds is None: popt, _ = curve_fit(func, x, y, p0, sigma=weights) else: popt, _ = curve_fit(func, x, y, p0, sigma=weights, bounds=bounds) else: raise ValueError( "method must be either lsq for least squares or" "wlsq for weighted least squares" ) return popt
[docs] def convert_bounds_for_curve_fit(bounds): lower_bounds = [] upper_bounds = [] for lower, upper in bounds: lower_bounds.append(lower if lower is not None else -np.inf) upper_bounds.append(upper if upper is not None else np.inf) return [lower_bounds, upper_bounds]
[docs] def get_least_squares_error_func(func, x, y): def least_squares_error_func(p): return np.sum((func(x, *p) - y) ** 2) return least_squares_error_func
[docs] def bounds_to_constraints(bounds): # https://stackoverflow.com/a/41761740 cons = [] for factor in range(len(bounds)): lower, upper = bounds[factor] lower = -np.inf if lower is None else lower upper = np.inf if upper is None else upper lo = {"type": "ineq", "fun": lambda x, lb=lower, i=factor: x[i] - lb} up = {"type": "ineq", "fun": lambda x, ub=upper, i=factor: ub - x[i]} cons.append(lo) cons.append(up) return cons
[docs] def fit_constrained_function(func, x, y, p0, method, bounds, constraints, weights=None): if method == "lsq": error_func = get_least_squares_error_func(func, x, y) else: # TODO implement WLSQ raise NotImplementedError( "At this time only least squares (lsq) fitting is supported." ) if constraints is None: constraints = [] result = minimize( error_func, p0, method="SLSQP", bounds=bounds, options={"eps": 1e-15}, ) if not result.success: raise RuntimeError( "Error during fitting in scipy.optimize.minimize. " f"Error message was: \n {result.message}." ) popt = result.x return popt