Source code for nnmt.lif.delta

"""
Collection of functions for LIF neurons with delta synapses.

Network Functions
*****************

.. autosummary::
    :toctree: _toctree/lif/

    firing_rates
    mean_input
    std_input

Parameter Functions
*******************

.. autosummary::
    :toctree: _toctree/lif/

    _firing_rates
    _firing_rates_for_given_input
    _mean_input
    _std_input
    _derivative_of_firing_rates_wrt_mean_input

"""

import numpy as np
from scipy.special import (
    erf as _erf,
    erfcx as _erfcx,
    dawsn as _dawsn,
    roots_legendre as _roots_legendre
    )
from scipy.integrate import quad as _quad

from . import _general
from .. import _solvers
from ..utils import (_cache,
                     _check_positive_params,
                     get_optional_network_params,
                     get_required_network_params,
                     get_required_results)


# Key prefix for results of this module that are stored with `_cache` and
# looked up with `get_required_results` (e.g. 'lif.delta.firing_rates').
_prefix = 'lif.delta.'


def firing_rates(network, **kwargs):
    """
    Compute the stationary firing rates of a network with delta-shaped PSCs.

    Network-level wrapper around :func:`nnmt.lif.delta._firing_rates`; see
    there for the full description of the calculation.

    Parameters
    ----------
    network : nnmt.models.Network or child class instance.
        Network providing the parameters listed in the docstring of
        :func:`nnmt.lif.delta._firing_rates`.
    kwargs
        Extra keyword arguments for the fixpoint iteration, see
        :func:`nnmt._solvers._firing_rate_integration`. They are applied last
        and therefore override network parameters of the same name.

    Returns
    -------
    np.array
        Firing rate of each population in Hz.
    """
    params = get_required_network_params(network, _firing_rates)
    # optional network parameters first, explicit kwargs last (highest
    # priority)
    for overrides in (get_optional_network_params(network, _firing_rates),
                      kwargs):
        params.update(overrides)
    result_key = _prefix + 'firing_rates'
    return _cache(network, _firing_rates, params, result_key, 'hertz')
def _firing_rates(J, K, V_0_rel, V_th_rel, tau_m, tau_r, J_ext, K_ext, nu_ext,
                  I_ext=None, C=None, **kwargs):
    """
    Self-consistent stationary firing rates for delta PSCs.

    The rates are obtained with the integration procedure of
    :func:`nnmt._solvers._firing_rate_integration`. The transfer function is
    :func:`nnmt.lif.delta._firing_rates_for_given_input`. The input statistics
    are computed with :func:`nnmt.lif._general._mean_input` and
    :func:`nnmt.lif._general._std_input`.

    Parameters
    ----------
    J : np.array
        Weight matrix in V.
    K : np.array
        Indegree matrix.
    V_0_rel : [float | 1d array]
        Relative reset potential in V.
    V_th_rel : [float | 1d array]
        Relative threshold potential in V.
    tau_m : [float | 1d array]
        Membrane time constant of post-synaptic neuron in s.
    tau_r : [float | 1d array]
        Refractory time in s.
    J_ext : np.array
        External weight matrix in V.
    K_ext : np.array
        Numbers of external input neurons to each population.
    nu_ext : 1d array
        Firing rates of external populations in Hz.
    I_ext : float, optional
        External d.c. input in A, requires membrane capacitance as well.
    C : float, optional
        Membrane capacitance in F, required if external input is given.
    kwargs
        For additional kwargs regarding the fixpoint iteration procedure see
        :func:`nnmt._solvers._firing_rate_integration`.

    Returns
    -------
    np.array
        Array of firing rates of each population in Hz.
    """
    # arguments of the single-neuron transfer function
    firing_rate_params = dict(V_0_rel=V_0_rel, V_th_rel=V_th_rel,
                              tau_m=tau_m, tau_r=tau_r)
    # arguments shared by the mean and the std input calculation
    input_params = dict(J=J, K=K, tau_m=tau_m, J_ext=J_ext, K_ext=K_ext,
                        nu_ext=nu_ext)
    # only the mean input receives the external d.c. contribution
    mu_input_params = dict(input_params, I_ext=I_ext, C=C)
    input_dict = {
        'mu': {'func': _general._mean_input, 'params': mu_input_params},
        'sigma': {'func': _general._std_input, 'params': input_params},
    }
    return _solvers._firing_rate_integration(
        _firing_rates_for_given_input, firing_rate_params, input_dict,
        **kwargs)
@_check_positive_params
def _firing_rates_for_given_input(mu, sigma, V_0_rel, V_th_rel, tau_m, tau_r):
    """
    Calculates stationary firing rate for delta shaped PSCs.

    Implementation of formula by Siegert for the mean-first-passage time
    :cite:p:`siegert1951`, found for example in Appendix A, Eq. A7 of
    :cite:t:`amit1997`. The implementation is explained in
    :cite:t:`layer2022`, Appendix A.1. An alternative but less stable way of
    implementing the Siegert function can be found in :cite:t:`hahne2017`,
    Appendix A.1.

    Parameters
    ----------
    mu : [float | 1d array]
        Mean input to population of neurons.
    sigma : [float | 1d array]
        Standard deviation of input to population of neurons.
    V_0_rel : [float | 1d array]
        Relative reset potential in V.
    V_th_rel : [float | 1d array]
        Relative threshold potential in V.
    tau_m : [float | 1d array]
        Membrane time constant of post-synaptic neuron in s.
    tau_r : [float | 1d array]
        Refractory time in s.

    Returns
    -------
    [float | np.array]
        Firing rates in Hz. A plain float is returned if only a single rate
        is calculated.

    Raises
    ------
    ValueError
        If any ``V_th_rel`` is smaller than the corresponding ``V_0_rel``.
    """
    # rescaled threshold and reset
    y_th = np.atleast_1d((V_th_rel - mu) / sigma)
    y_r = np.atleast_1d((V_0_rel - mu) / sigma)
    # Bring tau_m and tau_r into the vectorized shape of y_th if they are
    # scalars (no-op for arrays of matching size). Multiplying by ones is
    # exact, whereas ``tau + y_th - y_th`` rounds by ~eps * |y_th| and turns
    # into NaN for infinite y_th.
    ones = np.ones_like(y_th)
    tau_m = tau_m * ones
    tau_r = tau_r * ones
    assert y_th.shape == y_r.shape
    assert y_th.ndim == y_r.ndim == 1

    if np.any(V_th_rel - V_0_rel < 0):
        raise ValueError('V_th should be larger than V_0!')

    # determine order of quadrature
    params = {'start_order': 10, 'epsrel': 1e-12, 'maxiter': 10}
    gl_order = _get_erfcx_integral_gl_order(y_th=y_th, y_r=y_r, **params)

    # separate domains (they partition all entries since y_r <= y_th)
    mask_exc = y_th < 0
    mask_inh = 0 < y_r
    mask_interm = (y_r <= 0) & (0 <= y_th)

    # calculate rescaled siegert on each domain with its stable formula
    nu = np.zeros(shape=y_th.shape)
    for mask, siegert in ((mask_exc, _siegert_exc),
                          (mask_inh, _siegert_inh),
                          (mask_interm, _siegert_interm)):
        nu[mask] = siegert(y_th=y_th[mask], y_r=y_r[mask],
                           tau_m=tau_m[mask], t_ref=tau_r[mask],
                           gl_order=gl_order)

    # include exponential contributions omitted by _siegert_inh and
    # _siegert_interm
    nu[mask_inh] *= np.exp(-y_th[mask_inh]**2)
    nu[mask_interm] *= np.exp(-y_th[mask_interm]**2)

    # convert back to scalar if only one value calculated
    if nu.shape == (1,):
        return nu.item(0)
    return nu
def _get_erfcx_integral_gl_order(y_th, y_r, start_order, epsrel, maxiter):
    """
    Determine order of Gauss-Legendre quadrature for erfcx integral.

    Starting from ``start_order``, the order is doubled until the fixed-order
    quadrature of erfcx over ``[a, b]`` matches :func:`scipy.integrate.quad`
    to relative accuracy ``epsrel``. Here ``a`` and ``b`` are the smallest and
    largest of ``|y_th|`` and ``|y_r|``.

    Raises
    ------
    RuntimeError
        If ``epsrel`` is not reached within ``maxiter`` doublings.
    """
    # determine maximal integration range
    a = min(np.abs(y_th).min(), np.abs(y_r).min())
    b = max(np.abs(y_th).max(), np.abs(y_r).max())
    # A zero-width range (e.g. one population with mu exactly midway between
    # reset and threshold) integrates to exactly zero for any order; the
    # relative error below would be 0/0 and the search could never succeed.
    if a == b:
        return start_order
    # adaptive quadrature from scipy.integrate for comparison
    I_quad = _quad(_erfcx, a, b, epsabs=0, epsrel=epsrel)[0]
    # increase order to reach desired accuracy
    order = start_order
    rel_error = np.nan  # defined even if maxiter == 0
    for _ in range(maxiter):
        I_gl = _erfcx_integral(a, b, order=order)[0]
        rel_error = np.abs(I_gl / I_quad - 1)
        if rel_error < epsrel:
            return order
        order *= 2
    msg = f'Quadrature search failed to converge after {maxiter} iterations. '
    msg += f'Last relative error {rel_error:e}, desired {epsrel:e}.'
    raise RuntimeError(msg)


def _erfcx_integral(a, b, order):
    """
    Fixed order Gauss-Legendre quadrature of erfcx from a to b.

    ``a`` and ``b`` must be non-negative (scalars or 1d arrays of equal
    length); the result is an array with one integral per (a, b) pair.
    """
    assert np.all(a >= 0) and np.all(b >= 0)
    x, w = _roots_legendre(order)
    # column vectors so that the nodes broadcast against array-valued a, b
    x = x[:, np.newaxis]
    w = w[:, np.newaxis]
    return (b - a) * np.sum(w * _erfcx((b - a) * x / 2 + (b + a) / 2),
                            axis=0) / 2


def _siegert_exc(y_th, y_r, tau_m, t_ref, gl_order):
    """Calculate Siegert for y_th < 0."""
    assert np.all(y_th < 0)
    Int = _erfcx_integral(np.abs(y_th), np.abs(y_r), gl_order)
    return 1 / (t_ref + tau_m * np.sqrt(np.pi) * Int)


def _siegert_inh(y_th, y_r, tau_m, t_ref, gl_order):
    """Calculate Siegert without exp(-y_th**2) factor for 0 < y_r (<= y_th)."""
    assert np.all(0 < y_r)
    e_V_th_2 = np.exp(-y_th**2)
    Int = (2 * _dawsn(y_th) - 2 * np.exp(y_r**2 - y_th**2) * _dawsn(y_r))
    Int -= e_V_th_2 * _erfcx_integral(y_r, y_th, gl_order)
    return 1 / (e_V_th_2 * t_ref + tau_m * np.sqrt(np.pi) * Int)


def _siegert_interm(y_th, y_r, tau_m, t_ref, gl_order):
    """Calculate Siegert without exp(-y_th**2) factor for y_r <= 0 <= y_th."""
    assert np.all((y_r <= 0) & (0 <= y_th))
    e_V_th_2 = np.exp(-y_th**2)
    Int = 2 * _dawsn(y_th)
    Int += e_V_th_2 * _erfcx_integral(y_th, np.abs(y_r), gl_order)
    return 1 / (e_V_th_2 * t_ref + tau_m * np.sqrt(np.pi) * Int)
def mean_input(network):
    '''
    Mean input to each population, given previously computed firing rates.

    Fetches the result ``nu`` stored under ``'lif.delta.firing_rates'`` and
    evaluates :func:`nnmt.lif._general._mean_input`; see there for full
    documentation.

    Parameters
    ----------
    network : Network object
        Model with the network parameters and previously calculated results
        listed in :func:`nnmt.lif._general._mean_input`.

    Returns
    -------
    np.array
        Array of mean inputs to each population in V.
    '''
    general_func = _general._mean_input
    params = get_required_network_params(network, general_func,
                                         exclude=['nu'])
    rates = get_required_results(network, ['nu'], [_prefix + 'firing_rates'])
    params.update(rates)
    params.update(get_optional_network_params(network, general_func))
    result_key = _prefix + 'mean_input'
    return _cache(network, _mean_input, params, result_key, 'volt')
def _mean_input(*args, **kwargs):
    """
    Mean input for LIF neurons in a fixed in-degree connectivity network.

    All arguments are forwarded unchanged to
    :func:`nnmt.lif._general._mean_input`; see there for full documentation.
    """
    mean_func = _general._mean_input
    return mean_func(*args, **kwargs)
def std_input(network):
    '''
    Calculates standard deviation of inputs to populations.

    Fetches the result ``nu`` stored under ``'lif.delta.firing_rates'`` and
    evaluates :func:`nnmt.lif._general._std_input`; see there for full
    documentation.

    Parameters
    ----------
    network : nnmt.models.Network or child class instance.
        Network with the network parameters and previously calculated results
        listed in :func:`nnmt.lif._general._std_input`.

    Returns
    -------
    np.array
        Array of standard deviations of inputs to each population in V.
    '''
    params = get_required_network_params(
        network, _general._std_input, exclude=['nu'])
    params.update(
        get_required_results(network, ['nu'], [_prefix + 'firing_rates']))
    params.update(get_optional_network_params(network, _general._std_input))
    return _cache(network, _std_input, params, _prefix + 'std_input', 'volt')
def _std_input(*args, **kwargs):
    """
    Standard deviation of the neuronal input (plain calculation).

    All arguments are forwarded unchanged to
    :func:`nnmt.lif._general._std_input`; see there for full documentation.
    """
    std_func = _general._std_input
    return std_func(*args, **kwargs)
def _derivative_of_firing_rates_wrt_mean_input(mu, sigma, V_0_rel, V_th_rel,
                                               tau_m, tau_r):
    """
    Derivative of the stationary firing rate with respect to the mean input.

    See Appendix B in :cite:t:`schuecker2014`.

    Parameters
    ----------
    mu : [float | 1d array]
        Mean neuron activity in V.
    sigma : [float | 1d array]
        Standard deviation of neuron activity in V.
    V_0_rel : [float | 1d array]
        Relative reset potential in V.
    V_th_rel : [float | 1d array]
        Relative threshold potential in V.
    tau_m : [float | 1d array]
        Membrane time constant of post-synaptic neuron in s.
    tau_r : [float | 1d array]
        Refractory time in s.

    Returns
    -------
    [float | np.array]
        Zero frequency limit of white noise transfer function in Hz/V.

    Raises
    ------
    ZeroDivisionError
        If any ``sigma`` is zero.
    """
    if np.any(sigma == 0):
        raise ZeroDivisionError('Phi_prime_mu contains division by sigma!')
    y_th = (V_th_rel - mu) / sigma
    y_r = (V_0_rel - mu) / sigma
    nu0 = _firing_rates_for_given_input(mu, sigma, V_0_rel, V_th_rel, tau_m,
                                        tau_r)
    # exp(y**2) * (1 + erf(y)) == erfcx(-y). The erfcx form avoids the
    # cancellation in 1 + erf(y) for negative y, which evaluates to exactly 0
    # below y ~ -6 (strongly excitatory regime) and loses digits well before.
    return (np.sqrt(np.pi) * tau_m * np.power(nu0, 2) / sigma
            * (_erfcx(-y_th) - _erfcx(-y_r)))