Source code for nnmt._solvers
"""
Collection of general functions used to solve mean-field equations.
.. autosummary::
:toctree: _toctree/lif/
_firing_rate_integration
"""
from functools import partial
import numpy as np
import scipy.integrate as sint
import scipy.optimize as sopt
[docs]def _firing_rate_integration(firing_rate_func, firing_rate_params,
input_funcs, input_params, nu_0=None,
fixpoint_method='ODE', eps_tol=1e-7,
t_max_ODE=1000, maxiter_ODE=1000):
"""
Solves the self-consistent eqs for firing rates, mean, and std of input.
Parameters
----------
firing_rate_func : func
Function to be integrated.
firing_rates_params : dict
Parameters passed to firing_rates_func
input_funcs : list
List of functions needed to be run to calculate input to
firing_rate_func. They need to be in the order they are passed to the
firing_rate_func, and they need to be the first arguments of
firing_rate_func.
input_params : dict
Parameters passed to functions calculating mean and std of input.
nu_0 : [None | np.ndarray]
Initial guess for fixed point integration. If `None` the initial guess
is 0 for all populations. Default is `None`.
fixpoint_method : str
Method used for finding the fixed point. Currently, the following
method are implemented: `ODE`, `LSQTSQ`. ODE is a very good choice,
which finds stable fixed points even if the initial guess is far from
the fixed point. LSQTSQ also finds unstable fixed points but needs a
good initial guess. Default is `ODE`.
ODE :
Solves the initial value problem
dnu / ds = - nu + firing_rate_func(nu)
with initial value `nu_0` on the interval [0, t_max_ODE].
The final value at `t_max_ODE` is used as a new initial value
and the initial value problem is solved again. This procedure
is iterated until the criterion for a self-consistent solution
max( abs(nu[t_max_ODE-1] - nu[t_max_ODE]) ) < eps_tol
is fulfilled. Raises an error if this does not happen within
`maxiter_ODE` iterations.
LSTSQ :
Determines the minimum of
(nu - firing_rate_func(nu))^2
using least squares. Raises an error if the solution is a local
minimum with mean squared differnce above eps_tol.
eps_tol : float
Maximal incremental stepsize at which to stop the iteration procedure.
Default is 1e-7.
t_max_ODE : int
Determines the interval [0, t_max_ODE] on which the initial value
problem for the method `ODE` is solved in a single iteration.
Default is 1000.
maxiter_ODE : int
Determines the maximum number of iterations of the initial value
problem for the method `ODE`. Default is 1000.
"""
dimension = input_params['K'].shape[0]
def get_rate_difference(_, nu, rate_func):
"""
Calculate difference between new iteration step and previous one.
"""
# new inputs
inputs = []
for func in input_funcs:
inputs.append(func(nu, **input_params))
# new rate
new_nu = rate_func(*inputs, **firing_rate_params)
return -nu + new_nu
get_rate_difference = partial(get_rate_difference,
rate_func=firing_rate_func)
if nu_0 is None:
nu_0 = np.zeros(int(dimension))
if fixpoint_method == 'ODE':
# do iteration procedure, until stationary firing rates are found
for _ in range(maxiter_ODE):
sol = sint.solve_ivp(get_rate_difference, [0, t_max_ODE], nu_0,
t_eval=[t_max_ODE - 1, t_max_ODE],
method='LSODA')
assert sol.success is True
eps = max(np.abs(sol.y[:, 1] - sol.y[:, 0]))
if eps < eps_tol:
return sol.y[:, 1]
else:
nu_0 = sol.y[:, 1]
msg = f'Iteration failed to converge after {maxiter_ODE} steps. '
msg += f'Last maximum difference {eps:e}, desired {eps_tol:e}.'
raise RuntimeError(msg)
elif fixpoint_method == 'LSTSQ':
# search roots using least squares
get_rate_difference = partial(get_rate_difference, None)
res = sopt.least_squares(get_rate_difference, nu_0, bounds=(0, np.inf))
if res.cost/dimension < eps_tol:
return res.x
else:
msg = 'Least squares converged in a local minimum. '
msg += f'Mean squared differences: {res.cost/dimension}.'
raise RuntimeError(msg)
else:
msg = f"The method '{fixpoint_method}' to determine the self-"
msg += "consistent fixpoint is not implemented."
raise NotImplementedError(msg)