Source code for nnmt._solvers

"""
Collection of general functions used to solve mean-field equations.

.. autosummary::
    :toctree: _toctree/lif/

    _firing_rate_integration

"""

from functools import partial
import numpy as np
import scipy.integrate as sint
import scipy.optimize as sopt


[docs]def _firing_rate_integration(firing_rate_func, firing_rate_params, input_dict, nu_0=None, fixpoint_method='ODE', eps_tol=1e-7, t_max_ODE=1000, maxiter_ODE=1000): """ Solves the self-consistent equations for firing rates. Parameters ---------- firing_rate_func : func Function to be integrated. firing_rates_params : dict Parameters passed to firing_rates_func input_dict : dict Dictionary specifying the functions that need to be computed in each iteration step to get the input for the firing rate function. It should specify the names of the respective `firing_rate_func` arguments as keys and provide a dictionary for each input including the function using the key 'func' and the parameters using the key 'params': ``{'arg1': {'func': func1, 'params': params_dict1}, ...}`` All input functions are assumed to take the firing rate as their first argument. nu_0 : [None | np.ndarray] Initial guess for fixed point integration. If `None` the initial guess is 0 for all populations. Default is `None`. fixpoint_method : str Method used for finding the fixed point. Currently, the following method are implemented: `ODE`, `LSQTSQ`. ODE is a very good choice, which finds stable fixed points even if the initial guess is far from the fixed point. LSQTSQ also finds unstable fixed points but needs a good initial guess. Default is `ODE`. ODE : Solves the initial value problem dnu / ds = - nu + firing_rate_func(nu) with initial value `nu_0` on the interval [0, t_max_ODE]. The final value at `t_max_ODE` is used as a new initial value and the initial value problem is solved again. This procedure is iterated until the criterion for a self-consistent solution max( abs(nu[t_max_ODE-1] - nu[t_max_ODE]) ) < eps_tol is fulfilled. Raises an error if this does not happen within `maxiter_ODE` iterations. LSTSQ : Determines the minimum of (nu - firing_rate_func(nu))^2 using least squares. Raises an error if the solution is a local minimum with mean squared differnce above eps_tol. eps_tol : float Maximal incremental stepsize at which to stop the iteration procedure. Default is 1e-7. t_max_ODE : int Determines the interval [0, t_max_ODE] on which the initial value problem for the method `ODE` is solved in a single iteration. Default is 1000. maxiter_ODE : int Determines the maximum number of iterations of the initial value problem for the method `ODE`. Default is 1000. """ def get_rate_difference(_, nu, rate_func): """ Calculate difference between new iteration step and previous one. """ # new inputs inputs = {} for key, input in input_dict.items(): inputs[key] = input['func'](nu, **input['params']) # new rate new_nu = rate_func(**{**firing_rate_params, **inputs}) return -nu + new_nu get_rate_difference = partial(get_rate_difference, rate_func=firing_rate_func) # TODO improve the following way of finding the dimension dimension = input_dict['mu']['params']['K'].shape[0] if nu_0 is None: nu_0 = np.zeros(int(dimension)) if fixpoint_method == 'ODE': # do iteration procedure, until stationary firing rates are found for _ in range(maxiter_ODE): sol = sint.solve_ivp(get_rate_difference, [0, t_max_ODE], nu_0, t_eval=[t_max_ODE - 1, t_max_ODE], method='LSODA') assert sol.success is True eps = max(np.abs(sol.y[:, 1] - sol.y[:, 0])) if eps < eps_tol: return sol.y[:, 1] else: nu_0 = sol.y[:, 1] msg = f'Iteration failed to converge after {maxiter_ODE} steps. ' msg += f'Last maximum difference {eps:e}, desired {eps_tol:e}.' raise RuntimeError(msg) elif fixpoint_method == 'LSTSQ': # search roots using least squares get_rate_difference = partial(get_rate_difference, None) res = sopt.least_squares(get_rate_difference, nu_0, bounds=(0, np.inf)) if res.cost/dimension < eps_tol: return res.x else: msg = 'Least squares converged in a local minimum. ' msg += f'Mean squared differences: {res.cost/dimension}.' raise RuntimeError(msg) else: msg = f"The method '{fixpoint_method}' to determine the self-" msg += "consistent fixpoint is not implemented." raise NotImplementedError(msg)