# Source code for nnmt.lif.delta

"""
Collection of functions for LIF neurons with delta synapses.

Network Functions
*****************

.. autosummary::
    :toctree: _toctree/lif/

    firing_rates
    mean_input
    std_input

Parameter Functions
*******************

.. autosummary::
    :toctree: _toctree/lif/

    _firing_rates
    _firing_rates_for_given_input
    _mean_input
    _std_input
    _derivative_of_firing_rates_wrt_mean_input

"""

import numpy as np
from scipy.special import (
    erf as _erf,
    erfcx as _erfcx,
    dawsn as _dawsn,
    roots_legendre as _roots_legendre
    )
from scipy.integrate import quad as _quad

from . import _general
from ..utils import (_cache,
                     _check_positive_params)


# Common prefix of the result keys used by this module: it is prepended to the
# quantity name (e.g. 'firing_rates') when a result is handed to the caching
# helper and when a previously calculated result is read from
# ``network.results``.
_prefix = 'lif.delta.'


def firing_rates(network, **kwargs):
    """
    Calculates stationary firing rates for delta shaped PSCs.

    Gathers the required entries from ``network.network_params``, merges the
    given keyword arguments into them and passes everything, together with
    :func:`nnmt.lif.delta._firing_rates`, to the caching helper.

    See :func:`nnmt.lif.delta._firing_rates` for full documentation.

    Parameters
    ----------
    network : nnmt.models.Network or child class instance.
        Network with the network parameters listed in the docstring of
        :func:`nnmt.lif.delta._firing_rates`.
    kwargs
        For additional kwargs regarding the fixpoint iteration procedure see
        :func:`nnmt.lif._general._firing_rate_integration`.

    Returns
    -------
    np.array
        Array of firing rates of each population in Hz.

    Raises
    ------
    RuntimeError
        If one of the required network parameters is missing.
    """
    required = ('J', 'K', 'V_0_rel', 'V_th_rel', 'tau_m', 'tau_r',
                'K_ext', 'J_ext', 'nu_ext')

    params = {}
    try:
        for name in required:
            params[name] = network.network_params[name]
    except KeyError as param:
        raise RuntimeError(
            f"You are missing {param} for calculating the firing rate!\n"
            "Have a look into the documentation for more details on 'lif' "
            "parameters.")

    # explicitly passed keyword arguments take precedence / are added on top
    params.update(kwargs)

    result_key = _prefix + 'firing_rates'
    return _cache(network, _firing_rates, params, result_key, 'hertz')
def _firing_rates(J, K, V_0_rel, V_th_rel, tau_m, tau_r, J_ext, K_ext, nu_ext,
                  **kwargs):
    """
    Calculation of firing rates for delta PSCs.

    The parameters are split into those needed to compute the firing rate for
    a given input and those needed to compute the input itself; both sets are
    handed to :func:`nnmt.lif._general._firing_rate_integration` (see there
    for the integration procedure) together with
    :func:`nnmt.lif.delta._firing_rates_for_given_input`.

    Parameters
    ----------
    J : np.array
        Weight matrix in V.
    K : np.array
        Indegree matrix.
    V_0_rel : [float | 1d array]
        Relative reset potential in V.
    V_th_rel : [float | 1d array]
        Relative threshold potential in V.
    tau_m : [float | 1d array]
        Membrane time constant in s.
    tau_r : [float | 1d array]
        Refractory time in s.
    J_ext : np.array
        External weight matrix in V.
    K_ext : np.array
        Numbers of external input neurons to each population.
    nu_ext : 1d array
        Firing rates of external populations in Hz.
    kwargs
        For additional kwargs regarding the fixpoint iteration procedure see
        :func:`nnmt.lif._general._firing_rate_integration`.

    Returns
    -------
    np.array
        Array of firing rates of each population in Hz.
    """
    # parameters of the single neuron transfer function
    firing_rate_params = dict(
        V_0_rel=V_0_rel,
        V_th_rel=V_th_rel,
        tau_m=tau_m,
        tau_r=tau_r,
    )
    # parameters determining the input each population receives
    input_params = dict(
        J=J,
        K=K,
        tau_m=tau_m,
        J_ext=J_ext,
        K_ext=K_ext,
        nu_ext=nu_ext,
    )

    return _general._firing_rate_integration(
        _firing_rates_for_given_input,
        firing_rate_params,
        input_params,
        **kwargs)
@_check_positive_params
def _firing_rates_for_given_input(V_0_rel, V_th_rel, mu, sigma, tau_m, tau_r):
    """
    Calculates stationary firing rate for delta shaped PSCs.

    Implementation of formula by Siegert for the mean-first-passage time
    :cite:p:`siegert1951`, found for example in Appendix A, Eq. A7 of
    :cite:t:`amit1997`.

    The rescaled threshold and reset, ``y = (V - mu) / sigma``, are split into
    three domains (``y_th < 0``, ``0 < y_r`` and ``y_r <= 0 <= y_th``), each of
    which is evaluated by a dedicated helper using a fixed order
    Gauss-Legendre quadrature.

    Parameters
    ----------
    V_0_rel : [float | 1d array]
        Relative reset potential in V.
    V_th_rel : [float | 1d array]
        Relative threshold potential in V.
    mu : [float | 1d array]
        Mean input to population of neurons.
    sigma : [float | 1d array]
        Standard deviation of input to population of neurons.
    tau_m : [float | 1d array]
        Membrane time constant in s.
    tau_r : [float | 1d array]
        Refractory time in s.

    Returns
    -------
    [float | np.array]
        Firing rates in Hz. A plain float is returned if only a single value
        was calculated.

    Raises
    ------
    ValueError
        If the threshold is smaller than the reset potential, or if ``tau_m``
        or ``tau_r`` cannot be broadcast to the shape of the input.
    """
    y_th = np.atleast_1d((V_th_rel - mu) / sigma)
    y_r = np.atleast_1d((V_0_rel - mu) / sigma)

    # Bring tau_m and tau_r into the vectorized form of y_th if they are
    # scalars; arrays of appropriate size are left as they are. An explicit
    # broadcast is used instead of `tau + y_th - y_th` because the latter
    # changes tau by floating point round-off and yields nan for non-finite
    # y_th.
    tau_m = np.broadcast_to(tau_m, y_th.shape)
    tau_r = np.broadcast_to(tau_r, y_th.shape)

    assert y_th.shape == y_r.shape
    assert y_th.ndim == y_r.ndim == 1
    if np.any(V_th_rel - V_0_rel < 0):
        raise ValueError('V_th should be larger than V_0!')

    # determine order of quadrature
    params = {'start_order': 10, 'epsrel': 1e-12, 'maxiter': 10}
    gl_order = _get_erfcx_integral_gl_order(y_th=y_th, y_r=y_r, **params)

    # separate domains
    mask_exc = y_th < 0
    mask_inh = 0 < y_r
    mask_interm = (y_r <= 0) & (0 <= y_th)

    # calculate rescaled siegert; boolean indexing returns copies, so the
    # read-only broadcast views of tau_m and tau_r are never written to
    nu = np.zeros(shape=y_th.shape)
    params = {'tau_m': tau_m[mask_exc], 't_ref': tau_r[mask_exc],
              'gl_order': gl_order}
    nu[mask_exc] = _siegert_exc(y_th=y_th[mask_exc],
                                y_r=y_r[mask_exc], **params)
    params = {'tau_m': tau_m[mask_inh], 't_ref': tau_r[mask_inh],
              'gl_order': gl_order}
    nu[mask_inh] = _siegert_inh(y_th=y_th[mask_inh],
                                y_r=y_r[mask_inh], **params)
    params = {'tau_m': tau_m[mask_interm], 't_ref': tau_r[mask_interm],
              'gl_order': gl_order}
    nu[mask_interm] = _siegert_interm(y_th=y_th[mask_interm],
                                      y_r=y_r[mask_interm], **params)

    # include exponential contributions, which the helpers for these two
    # domains leave out
    nu[mask_inh] *= np.exp(-y_th[mask_inh]**2)
    nu[mask_interm] *= np.exp(-y_th[mask_interm]**2)

    # convert back to scalar if only one value calculated
    if nu.shape == (1,):
        return nu.item(0)
    else:
        return nu
def _get_erfcx_integral_gl_order(y_th, y_r, start_order, epsrel, maxiter):
    """
    Determine order of Gauss-Legendre quadrature for erfcx integral.

    The order is doubled, starting from ``start_order``, until the fixed order
    quadrature over the maximal integration range agrees with an adaptive
    reference quadrature up to the relative tolerance ``epsrel``.

    Parameters
    ----------
    y_th : np.array
        Rescaled thresholds.
    y_r : np.array
        Rescaled reset values.
    start_order : int
        Quadrature order to start the search with.
    epsrel : float
        Desired relative accuracy.
    maxiter : int
        Maximal number of order doublings that are tried.

    Returns
    -------
    int
        Quadrature order that reaches the desired accuracy.

    Raises
    ------
    RuntimeError
        If the desired accuracy is not reached within ``maxiter`` iterations.
    """
    # determine maximal integration range
    a = min(np.abs(y_th).min(), np.abs(y_r).min())
    b = max(np.abs(y_th).max(), np.abs(y_r).max())

    # A range of zero width (e.g. a single population with y_th == -y_r) has
    # a vanishing integral, for which the relative error below would be 0 / 0.
    # Any order integrates it exactly, so the start order is sufficient.
    if a == b:
        return start_order

    # adaptive quadrature from scipy.integrate for comparison
    I_quad = _quad(_erfcx, a, b, epsabs=0, epsrel=epsrel)[0]

    # increase order to reach desired accuracy
    order = start_order
    rel_error = np.inf
    for _ in range(maxiter):
        I_gl = _erfcx_integral(a, b, order=order)[0]
        rel_error = np.abs(I_gl / I_quad - 1)
        if rel_error < epsrel:
            return order
        order *= 2

    msg = f'Quadrature search failed to converge after {maxiter} iterations. '
    msg += f'Last relative error {rel_error:e}, desired {epsrel:e}.'
    raise RuntimeError(msg)


def _erfcx_integral(a, b, order):
    """
    Fixed order Gauss-Legendre quadrature of erfcx from a to b.

    ``a`` and ``b`` may be scalars or arrays of equal shape; both need to be
    non-negative. The nodes are put on a new leading axis which is summed
    over, so one integral per pair of limits is returned as a 1d array.
    """
    assert np.all(a >= 0) and np.all(b >= 0)
    x, w = _roots_legendre(order)
    x = x[:, np.newaxis]
    w = w[:, np.newaxis]
    # map the nodes from [-1, 1] to [a, b]
    nodes = (b - a) * x / 2 + (b + a) / 2
    return (b - a) * np.sum(w * _erfcx(nodes), axis=0) / 2


def _siegert_exc(y_th, y_r, tau_m, t_ref, gl_order):
    """Calculate Siegert for y_th < 0."""
    assert np.all(y_th < 0)
    Int = _erfcx_integral(np.abs(y_th), np.abs(y_r), gl_order)
    return 1 / (t_ref + tau_m * np.sqrt(np.pi) * Int)


def _siegert_inh(y_th, y_r, tau_m, t_ref, gl_order):
    """
    Calculate Siegert without exp(-y_th**2) factor for 0 < y_r.

    The result has to be multiplied by ``exp(-y_th**2)`` to obtain the firing
    rate.
    """
    assert np.all(0 < y_r)
    e_V_th_2 = np.exp(-y_th**2)
    Int = (2 * _dawsn(y_th)
           - 2 * np.exp(y_r**2 - y_th**2) * _dawsn(y_r))
    Int -= e_V_th_2 * _erfcx_integral(y_r, y_th, gl_order)
    return 1 / (e_V_th_2 * t_ref + tau_m * np.sqrt(np.pi) * Int)


def _siegert_interm(y_th, y_r, tau_m, t_ref, gl_order):
    """
    Calculate Siegert without exp(-y_th**2) factor for y_r <= 0 <= y_th.

    The result has to be multiplied by ``exp(-y_th**2)`` to obtain the firing
    rate.
    """
    assert np.all((y_r <= 0) & (0 <= y_th))
    e_V_th_2 = np.exp(-y_th**2)
    Int = 2 * _dawsn(y_th)
    Int += e_V_th_2 * _erfcx_integral(y_th, np.abs(y_r), gl_order)
    return 1 / (e_V_th_2 * t_ref + tau_m * np.sqrt(np.pi) * Int)
def mean_input(network):
    """
    Calc mean inputs to populations as function of firing rates of populations.

    Reads the required parameters from ``network.network_params`` and the
    previously calculated firing rates from ``network.results`` and passes
    them, together with :func:`nnmt.lif.delta._mean_input`, to the caching
    helper.

    See :func:`nnmt.lif._general._mean_input` for full documentation.

    Parameters
    ----------
    network : Network object
        Model with the network parameters and previously calculated results
        listed in :func:`nnmt.lif.delta._mean_input`.

    Returns
    -------
    np.array
        Array of mean inputs to each population in V.

    Raises
    ------
    RuntimeError
        If a required network parameter is missing or the firing rates have
        not been calculated yet.
    """
    required = ('J', 'K', 'tau_m', 'J_ext', 'K_ext', 'nu_ext')

    params = {}
    try:
        for name in required:
            params[name] = network.network_params[name]
    except KeyError as param:
        raise RuntimeError(f'You are missing {param} for this calculation.')

    # the firing rates need to be available as a previous result
    try:
        rates = network.results[_prefix + 'firing_rates']
    except KeyError as quantity:
        raise RuntimeError(f'You first need to calculate the {quantity}.')
    params['nu'] = rates

    result_key = _prefix + 'mean_input'
    return _cache(network, _mean_input, params, result_key, 'volt')
@_check_positive_params
def _mean_input(nu, J, K, tau_m, J_ext, K_ext, nu_ext):
    """
    Calc mean input for lif neurons in fixed in-degree connectivity network.

    Thin wrapper that forwards all arguments unchanged to
    :func:`nnmt.lif._general._mean_input`; see there for full documentation.

    Parameters
    ----------
    nu : np.array
        Firing rates of populations in Hz.
    J : np.array
        Weight matrix in V.
    K : np.array
        In-degree matrix.
    tau_m : [float | 1d array]
        Membrane time constant in s.
    J_ext : np.array
        External weight matrix in V.
    K_ext : np.array
        Numbers of external input neurons to each population.
    nu_ext : 1d array
        Firing rates of external populations in Hz.

    Returns
    -------
    np.array
        Array of mean inputs to each population in V.
    """
    mean = _general._mean_input(nu, J, K, tau_m, J_ext, K_ext, nu_ext)
    return mean
def std_input(network):
    """
    Calculates standard deviation of inputs to populations.

    Reads the required parameters from ``network.network_params`` and the
    previously calculated firing rates from ``network.results`` and passes
    them, together with :func:`nnmt.lif.delta._std_input`, to the caching
    helper.

    See :func:`nnmt.lif._general._std_input` for full documentation.

    Parameters
    ----------
    network : nnmt.models.Network or child class instance.
        Network with the network parameters and previously calculated results
        listed in :func:`nnmt.lif.delta._std_input`.

    Returns
    -------
    np.array
        Array of the standard deviation of the input to each population in V.

    Raises
    ------
    RuntimeError
        If a required network parameter is missing or the firing rates have
        not been calculated yet.
    """
    required = ('J', 'K', 'tau_m', 'J_ext', 'K_ext', 'nu_ext')

    params = {}
    try:
        for name in required:
            params[name] = network.network_params[name]
    except KeyError as param:
        raise RuntimeError(f'You are missing {param} for this calculation.')

    # the firing rates need to be available as a previous result
    try:
        rates = network.results[_prefix + 'firing_rates']
    except KeyError as quantity:
        raise RuntimeError(f'You first need to calculate the {quantity}.')
    params['nu'] = rates

    result_key = _prefix + 'std_input'
    return _cache(network, _std_input, params, result_key, 'volt')
@_check_positive_params
def _std_input(nu, J, K, tau_m, J_ext, K_ext, nu_ext):
    """
    Plain calculation of standard deviation of neuronal input.

    Thin wrapper that forwards all arguments unchanged to
    :func:`nnmt.lif._general._std_input`; see there for full documentation.

    Parameters
    ----------
    nu : np.array
        Firing rates of populations in Hz.
    J : np.array
        Weight matrix in V.
    K : np.array
        In-degree matrix.
    tau_m : [float | 1d array]
        Membrane time constant in s.
    J_ext : np.array
        External weight matrix in V.
    K_ext : np.array
        Numbers of external input neurons to each population.
    nu_ext : 1d array
        Firing rates of external populations in Hz.

    Returns
    -------
    np.array
        Array of the standard deviation of the input to each population in V.
    """
    std = _general._std_input(nu, J, K, tau_m, J_ext, K_ext, nu_ext)
    return std
def _derivative_of_firing_rates_wrt_mean_input(V_0_rel, V_th_rel, mu, sigma,
                                               tau_m, tau_r):
    """
    Derivative of the stationary firing rate with respect to the mean input.

    See Appendix B in :cite:t:`schuecker2014`.

    Parameters
    ----------
    V_0_rel : float
        Relative reset potential in V.
    V_th_rel : float
        Relative threshold potential in V.
    mu : float
        Mean neuron activity in V.
    sigma : float
        Standard deviation of neuron activity in V.
    tau_m : float
        Membrane time constant in s.
    tau_r : float
        Refractory time in s.

    Returns
    -------
    float
        Zero frequency limit of white noise transfer function in Hz/V.

    Raises
    ------
    ZeroDivisionError
        If ``sigma`` is zero.
    """
    if np.any(sigma == 0):
        raise ZeroDivisionError('Phi_prime_mu contains division by sigma!')

    y_th = (V_th_rel - mu) / sigma
    y_r = (V_0_rel - mu) / sigma

    nu0 = _firing_rates_for_given_input(V_0_rel, V_th_rel, mu, sigma, tau_m,
                                        tau_r)

    # exp(y**2) * (1 + erf(y)) is identical to erfcx(-y). The scaled
    # complementary error function is used because the explicit product loses
    # all precision for strongly negative y, where 1 + erf(y) cancels to zero
    # while exp(y**2) grows large.
    boundary_terms = _erfcx(-y_th) - _erfcx(-y_r)

    return (np.sqrt(np.pi) * tau_m * np.power(nu0, 2) / sigma
            * boundary_terms)