Source code for nnmt.lif._general
"""
Collection of functions used by both `lif.delta` and `lif.exp`.

Static Quantities
*****************

.. autosummary::
    :toctree: _toctree/lif/

    _firing_rate_integration
    _input_calc
    mean_input
    _mean_input
    std_input
    _std_input
    _fit_transfer_function

"""
from functools import partial
import numpy as np
import scipy.integrate as sint
import scipy.optimize as sopt
from .. import ureg
def _firing_rate_integration(firing_rate_func, firing_rate_params,
                             input_params, nu_0=None, fixpoint_method='ODE',
                             eps_tol=1e-7, t_max_ODE=1000, maxiter_ODE=1000):
    """
    Solves the self-consistent eqs for firing rates, mean and std of input.

    Parameters
    ----------
    firing_rate_func : func
        Function to be integrated. It is called as
        ``firing_rate_func(mu=mu, sigma=sigma, **firing_rate_params)``.
    firing_rate_params : dict
        Parameters passed to `firing_rate_func`.
    input_params : dict
        Parameters passed to functions calculating mean and std of input.
        Must contain the key `K`; ``input_params['K'].shape[0]`` is used as
        the number of populations.
    nu_0 : [None | np.ndarray]
        Initial guess for fixed point integration. If `None` the initial guess
        is 0 for all populations. Default is `None`.
    fixpoint_method : str
        Method used for finding the fixed point. Currently, the following
        methods are implemented: `ODE`, `LSTSQ`. ODE is a very good choice,
        which finds stable fixed points even if the initial guess is far from
        the fixed point. LSTSQ also finds unstable fixed points but needs a
        good initial guess. Default is `ODE`.

        ODE :
            Solves the initial value problem
              dnu / ds = - nu + firing_rate_func(nu)
            with initial value `nu_0` on the interval [0, t_max_ODE].
            The final value at `t_max_ODE` is used as a new initial value
            and the initial value problem is solved again. This procedure
            is iterated until the criterion for a self-consistent solution
              max( abs(nu[t_max_ODE-1] - nu[t_max_ODE]) ) < eps_tol
            is fulfilled. Raises an error if this does not happen within
            `maxiter_ODE` iterations.
        LSTSQ :
            Determines the minimum of
              (nu - firing_rate_func(nu))^2
            using least squares with the rates bounded to be non-negative.
            Raises an error if the solution is a local minimum for which
            the cost reported by the optimizer, divided by the number of
            populations, is not below eps_tol.
    eps_tol : float
        Maximal incremental stepsize at which to stop the iteration procedure.
        Default is 1e-7.
    t_max_ODE : int
        Determines the interval [0, t_max_ODE] on which the initial value
        problem for the method `ODE` is solved in a single iteration.
        Default is 1000.
    maxiter_ODE : int
        Determines the maximum number of iterations of the initial value
        problem for the method `ODE`. Default is 1000.

    Returns
    -------
    np.ndarray
        Self-consistent firing rates, one entry per population.

    Raises
    ------
    RuntimeError
        If the ODE solver reports a failure, if the ODE iteration does not
        converge within `maxiter_ODE` iterations, or if the least squares
        search ends in a local minimum.
    NotImplementedError
        If `fixpoint_method` is neither `ODE` nor `LSTSQ`.
    """
    dimension = input_params['K'].shape[0]

    def get_rate_difference(_, nu):
        """
        Calculate difference between new iteration step and previous one.

        The first argument is the integration variable handed over by the
        ODE solver; the right hand side does not depend on it.
        """
        # new mean
        mu = _mean_input(nu=nu, **input_params)
        # new std
        sigma = _std_input(nu=nu, **input_params)
        new_nu = firing_rate_func(mu=mu, sigma=sigma, **firing_rate_params)
        return -nu + new_nu

    if nu_0 is None:
        nu_0 = np.zeros(int(dimension))

    if fixpoint_method == 'ODE':
        # Defined up front so that the error message below is well-formed
        # even if the loop body never runs (maxiter_ODE == 0).
        eps = np.inf
        # do iteration procedure, until stationary firing rates are found
        for _ in range(maxiter_ODE):
            sol = sint.solve_ivp(get_rate_difference, [0, t_max_ODE], nu_0,
                                 t_eval=[t_max_ODE - 1, t_max_ODE],
                                 method='LSODA')
            # Explicit check instead of `assert`, which vanishes under -O.
            if not sol.success:
                raise RuntimeError(
                    f'ODE solver failed to integrate: {sol.message}')
            # largest change of any rate over the last unit of time
            eps = np.max(np.abs(sol.y[:, 1] - sol.y[:, 0]))
            if eps < eps_tol:
                return sol.y[:, 1]
            # not yet stationary: restart from the final state
            nu_0 = sol.y[:, 1]
        msg = f'Iteration failed to converge after {maxiter_ODE} steps. '
        msg += f'Last maximum difference {eps:e}, desired {eps_tol:e}.'
        raise RuntimeError(msg)
    elif fixpoint_method == 'LSTSQ':
        # search roots using least squares; the optimizer only passes nu,
        # so the unused integration variable is fixed to None
        res = sopt.least_squares(
            lambda nu: get_rate_difference(None, nu),
            nu_0, bounds=(0, np.inf))
        if res.cost / dimension < eps_tol:
            return res.x
        msg = 'Least squares converged in a local minimum. '
        msg += f'Mean squared differences: {res.cost/dimension}.'
        raise RuntimeError(msg)
    else:
        msg = f"The method '{fixpoint_method}' to determine the self-"
        msg += "consistent fixpoint is not implemented."
        raise NotImplementedError(msg)
def mean_input(network, prefix):
    """
    Calcs mean input for `network` and stores results using `prefix`.

    Thin wrapper that delegates to the shared input helper with
    :func:`nnmt.lif._general._mean_input` as the quantity to evaluate.
    See that function for full documentation.

    Parameters
    ----------
    network : Network object
        The network model for which the mean input is to be calculated. Needs
        to contain the parameters defined in
        :func:`nnmt.lif._general._mean_input`.
    prefix : str
        The prefix used to store the result (e.g. 'lif.delta.').

    Returns
    -------
    np.array
        Array of mean inputs to each population in V.

    See Also
    --------
    nnmt.lif._general._mean_input : For full documentation of network
                                    parameters.
    """
    return _input_calc(network=network, prefix=prefix,
                       input_func=_mean_input)
def std_input(network, prefix):
    """
    Calcs std of input for `network` and stores results using `prefix`.

    Thin wrapper that delegates to the shared input helper with
    :func:`nnmt.lif._general._std_input` as the quantity to evaluate.
    See that function for full documentation.

    Parameters
    ----------
    network : Network object
        The network model for which the standard deviation of the input is
        to be calculated. Needs to contain the parameters defined in
        :func:`nnmt.lif._general._std_input`.
    prefix : str
        The prefix used to store the result (e.g. 'lif.delta.').

    Returns
    -------
    np.array
        Array of standard deviation of inputs to each population in V.

    See Also
    --------
    nnmt.lif._general._std_input : For full documentation of network
                                   parameters.
    """
    return _input_calc(network=network, prefix=prefix,
                       input_func=_std_input)
def _input_calc(network, prefix, input_func):
    """
    Helper function for input related calculations.

    Looks up the previously computed firing rates and the required network
    parameters, then evaluates `input_func` on them. The result is
    multiplied by ``ureg.V`` before it is returned.

    Parameters
    ----------
    network : nnmt.create.Network object
        The network for which the calculation should be done.
    prefix : str
        The prefix used to store the results (e.g. 'lif.delta.'). The firing
        rates are read from ``network.results[prefix + 'firing_rates']``.
    input_func : function
        The function that should be calculated (either `_mean_input` or
        `_std_input`).

    Raises
    ------
    RuntimeError
        If the firing rates have not been calculated yet or if one of the
        required network parameters is missing.
    """
    rates_key = prefix + 'firing_rates'
    try:
        stored_rates = network.results[rates_key]
        rates = stored_rates.to_base_units().magnitude
    except KeyError as quantity:
        raise RuntimeError(f'You first need to calculate the {quantity}.')

    # collect the keyword arguments expected by the input functions
    required_keys = ('K', 'J', 'tau_m', 'nu_ext', 'K_ext', 'J_ext')
    params = {}
    try:
        for key in required_keys:
            params[key] = network.network_params[key]
    except KeyError as param:
        raise RuntimeError(f'You are missing {param} for this calculation.')

    result = input_func(rates, **params)
    return result * ureg.V
def _mean_input(nu, J, K, tau_m, J_ext, K_ext, nu_ext):
    """
    Calc mean input for lif neurons in fixed in-degree connectivity network.

    Following Eq. 3.4 in :cite:t:`fourcaud2002`. The result is the sum of a
    recurrent term ``tau_m * (K * J) . nu`` and an external term
    ``tau_m * (K_ext * J_ext) . nu_ext``.

    Parameters
    ----------
    nu : np.array
        Firing rates of populations in Hz.
    J : np.array
        Weight matrix in V.
    K : np.array
        In-degree matrix.
    tau_m : [float | 1d array]
        Membrane time constant of post-synatic neuron in s.
    J_ext : np.array
        External weight matrix in V.
    K_ext : np.array
        Numbers of external input neurons to each population.
    nu_ext : 1d array
        Firing rates of external populations in Hz.

    Returns
    -------
    np.array
        Array of mean inputs to each population in V.
    """
    # in-degrees and weights are combined element-wise before the
    # matrix-vector product with the presynaptic rates
    recurrent_drive = np.dot(K * J, nu)
    external_drive = np.dot(K_ext * J_ext, nu_ext)
    return tau_m * recurrent_drive + tau_m * external_drive
def _std_input(nu, J, K, tau_m, J_ext, K_ext, nu_ext):
    """
    Calc std of input for lif neurons in fixed in-degree connectivity network.

    Following Eq. 3.4 in :cite:t:`fourcaud2002`. The variance is the sum of
    a recurrent term ``tau_m * (K * J**2) . nu`` and an external term
    ``tau_m * (K_ext * J_ext**2) . nu_ext``; its square root is returned.

    Parameters
    ----------
    nu : np.array
        Firing rates of populations in Hz.
    J : np.array
        Weight matrix in V.
    K : np.array
        In-degree matrix.
    tau_m : [float | 1d array]
        Membrane time constant of post-synatic neuron in s.
    J_ext : np.array
        External weight matrix in V.
    K_ext : np.array
        Numbers of external input neurons to each population.
    nu_ext : 1d array
        Firing rates of external populations in Hz.

    Returns
    -------
    np.array
        Array of standard deviation of inputs to each population in V.
    """
    # variance contributed by connections within the network
    recurrent_variance = tau_m * np.dot(K * J**2, nu)
    # variance contributed by the external sources
    external_variance = tau_m * np.dot(K_ext * J_ext**2, nu_ext)
    total_variance = recurrent_variance + external_variance
    return np.sqrt(total_variance)
def _fit_transfer_function(transfunc, omegas):
    """
    Fits the transfer function (tf) of a low-pass filter to the passed tf.

    A least-squares fit is used for the fitting procedure. The absolute
    value of ``h0 / (1 + 1j * omega * tau)`` is fitted to the absolute value
    of each column of `transfunc`; afterwards the sign of the imaginary part
    of the fit is matched to the one of `transfunc` at the last frequency.

    For details refer to
    :cite:t:`senk2020`, Sec. F 'Comparison of neural-field and spiking models'.

    Parameters
    ----------
    transfunc : np.array
        Transfer functions for each population with the following shape:
        (number of freqencies, number of populations).
    omegas : [float | np.ndarray]
        Input frequencies to population in Hz.

    Returns
    -------
    transfer_function_fit : np.array
        Fit of transfer functions in Hertz/volt for each population with the
        following shape: (number of freqencies, number of populations).
    tau_rate : np.array
        Fitted time constant of low-pass filter for each population in s.
    h0 : np.array
        Fitted gain of low-pass filter for each population in Hertz/volt.
    fit_error : np.array
        Combined relative fit error for each population.
    """
    def func(omega, tau, h0):
        # transfer function of a first order low-pass filter
        return h0 / (1. + 1j * omega * tau)

    # absolute value for fitting
    def func_abs(omega, tau, h0):
        return np.abs(func(omega, tau, h0))

    # np.complex128 instead of the alias np.complex_, which was removed in
    # NumPy 2.0
    transfunc_fit = np.zeros(np.shape(transfunc), dtype=np.complex128)
    dim = np.shape(transfunc)[1]
    tau_rate = np.zeros(dim)
    h0 = np.zeros(dim)
    fit_error = np.zeros(dim)

    for i in range(dim):
        # fit low-pass filter transfer function (func) to LIF transfer function
        # (transfunc) to obtain parameters of rate model with fit errors
        fit_params, fit_covariances = sopt.curve_fit(
            func_abs, omegas, np.abs(transfunc[:, i]))
        tau_rate[i] = fit_params[0]
        h0[i] = fit_params[1]
        transfunc_fit[:, i] = func(omegas, tau_rate[i], h0[i])

        # adjust sign of imaginary part (just check sign of last value);
        # the fit of the absolute value cannot determine this sign
        sign_imag = 1 if (transfunc[-1, i].imag > 0) else -1
        sign_imag_fit = 1 if (transfunc_fit[-1, i].imag > 0) else -1
        if sign_imag != sign_imag_fit:
            # .imag of the column slice is a view, so this flips in place
            transfunc_fit[:, i].imag *= -1
            tau_rate[i] *= -1

        # standard deviation
        fit_errs = np.sqrt(np.diag(fit_covariances))
        # relative error
        err_tau = fit_errs[0] / tau_rate[i]
        err_h0 = fit_errs[1] / h0[i]
        # combined error
        fit_error[i] = np.sqrt(err_tau**2 + err_h0**2)

    return transfunc_fit, tau_rate, h0, fit_error