Source code for nnmt.binary
"""
Collection of functions for binarys neurons.
Network Functions
*****************
.. autosummary::
:toctree: _toctree/lif/
mean_activity
mean_input
std_input
working_point
balanced_threshold
Parameter Functions
*******************
.. autosummary::
:toctree: _toctree/lif/
_mean_activity
_mean_activity_for_given_input
_mean_input
_std_input
_balanced_threshold
"""
import numpy as np
from scipy.special import erfc as _erfc
from . import _solvers
from .utils import _cache
_prefix = 'binary.'
[docs]def _mean_activity_for_given_input(mu, sigma, theta):
"""
Calcs the firing rates of binary neurons for given input statistics.
Parameters
----------
mu : array
Mean inputs for each population.
sigma : array
Standard deviation of inputs of each population.
theta : [array | float]
Firing thresholds for each population.
Returns
-------
array
Mean firing rates for each population.
"""
return 0.5 * _erfc(-(mu - theta) / (np.sqrt(2) * sigma))
[docs]def mean_activity(network, **kwargs):
"""
Calculates stationary firing rates for a network of binary neurons.
See :func:`nnmt.binary._mean_activity` for full documentation.
Parameters
----------
network : nnmt.models.Network or child class instance.
Network with the network parameters listed in the docstring of
:func:`nnmt.binary._mean_activity`.
kwargs
For additional kwargs regarding the fixpoint iteration procedure see
:func:`nnmt._solvers._firing_rate_integration`.
Returns
-------
array
Mean firing rates for each population.
"""
required_params = ['J', 'K', 'theta']
optional_params = ['J_ext', 'K_ext', 'm_ext']
try:
params = {key: network.network_params[key] for key in required_params}
except KeyError as param:
raise RuntimeError(
f"You are missing {param} for calculating the firing rate!\n"
"Have a look into the documentation for more details on 'binary' "
"parameters.")
try:
params = {key: network.network_params[key] for key in optional_params}
except KeyError as param:
pass
params.update(kwargs)
return _cache(network, _mean_activity, params, _prefix + 'mean_activity')
[docs]def _mean_activity(J, K, theta, **kwargs):
"""
Calcs firing rates for each population in a network of binary neurons.
See :func:`nnmt._solvers._firing_rate_integration` for integration
procedure.
Uses :func:`nnmt.binary._mean_activity_for_given_input`.
Parameters
----------
J : array
Weight matrix.
K : array
Connectivity matrix.
theta : [array | float]
Firing threshold.
Returns
-------
array
Mean firing rates for each population.
"""
firing_rate_params = {
'theta': theta
}
input_funcs = [_mean_input, _std_input]
input_params = {
'J': J,
'K': K,
}
return _solvers._firing_rate_integration(_mean_activity_for_given_input,
firing_rate_params,
input_funcs,
input_params,
**kwargs)
[docs]def mean_input(network):
'''
Calc mean inputs to populations as function of firing rates of populations.
See :func:`nnmt.binary._mean_input` for full documentation.
Parameters
----------
network : Network object
Model with the network parameters and previously calculated results
listed in :func:`nnmt.binary._mean_input`.
Returns
-------
array
Array of mean inputs to each population.
'''
required_params = ['J', 'K']
optional_params = ['J_ext', 'K_ext', 'm_ext']
try:
params = {key: network.network_params[key] for key in required_params}
except KeyError as param:
raise RuntimeError(f'You are missing {param} for this calculation.')
try:
params = {key: network.network_params[key] for key in optional_params}
except KeyError as param:
pass
try:
params['m'] = network.results[_prefix + 'mean_activity']
except KeyError as quantity:
raise RuntimeError(f'You first need to calculate the {quantity}.')
return _cache(network, _mean_input, params, _prefix + 'mean_input')
[docs]def _mean_input(m, J, K, J_ext=0, K_ext=0, m_ext=0):
"""
Calculates the mean inputs in a network of binary neurons.
Parameters
----------
m : array
Mean activity of each population.
J : array
Weight matrix.
K : array
Connectivity matrix.
J_ext : array
Weight matrix of external inputs.
K_ext : array
Connectivity matrix of external inputs.
m_ext : float
External input.
Returns
-------
array
Mean input of each population.
"""
return np.dot(K * J, m) + np.dot(K_ext * J_ext, m_ext)
[docs]def std_input(network):
'''
Calcs the standard deviation of the inputs in a network of binary neurons.
See :func:`nnmt.binary._std_input` for full documentation.
Parameters
----------
network : Network object
Model with the network parameters and previously calculated results
listed in :func:`nnmt.binary._std_input`.
Returns
-------
array
Array of standard deviations of inputs to each population.
'''
required_params = ['J', 'K']
optional_params = ['J_ext', 'K_ext', 'm_ext']
try:
params = {key: network.network_params[key] for key in required_params}
except KeyError as param:
raise RuntimeError(f'You are missing {param} for this calculation.')
try:
params = {key: network.network_params[key] for key in optional_params}
except KeyError as param:
pass
try:
params['m'] = network.results[_prefix + 'mean_activity']
except KeyError as quantity:
raise RuntimeError(f'You first need to calculate the {quantity}.')
return _cache(network, _std_input, params, _prefix + 'std_input')
[docs]def _std_input(m, J, K, J_ext=0, K_ext=0, m_ext=0):
"""
Calcs the standard deviation of the inputs in a network of binary neurons.
Parameters
----------
m : array
Mean activity of each population.
J : array
Weight matrix.
K : array
Connectivity matrix.
J_ext : array
Weight matrix of external inputs.
K_ext : array
Connectivity matrix of external inputs.
m_ext : float
External input.
Returns
-------
array
Standard deviations of input.
"""
return np.sqrt(np.dot(K * J**2, m * (1 - m))
+ np.dot(K_ext * J_ext**2, m_ext * (1 - m_ext)))
[docs]def working_point(network, **kwargs):
"""
Calculates working point (rates, mean, and std input) for binary network.
Calculates the firing rates using :func:`nnmt.binary.mean_activity`,
the mean input using :func:`nnmt.binary.mean_input`,
and the standard deviation of the input using
:func:`nnmt.binary.std_input`.
Parameters
----------
network : nnmt.models.Network or child class instance.
Network with the network parameters listed in
:func:`nnmt.binary._mean_activity`.
kwargs
For additional kwargs regarding the fixpoint iteration procedure see
:func:`nnmt._solvers._firing_rate_integration`.
Returns
-------
dict
Dictionary containing firing rates, mean input and std input.
"""
return {'mean_activity': mean_activity(network, **kwargs),
'mean_input': mean_input(network),
'std_input': std_input(network)}
[docs]def balanced_threshold(network, m_exp):
"""
Calculate threshold equal to input given expected mean activity (balance).
See :func:`nnmt.binary._balanced_threshold` for full documentation.
Parameters
----------
network : nnmt.models.Network or child class instance.
Network with the network parameters listed in the docstring of
:func:`nnmt.binary._balanced_threshold`.
m_exp : array
Expected mean activity for each population.
Returns
-------
array
Balanced threshold for each population.
"""
list_of_params = ['J', 'K']
try:
params = {key: network.network_params[key] for key in list_of_params}
except KeyError as param:
raise RuntimeError(
f"You are missing {param} for calculating the balanced "
"threshold!\nHave a look into the documentation for more details "
"on 'binary' parameters.")
params['m_exp'] = m_exp
return _cache(network, _balanced_threshold, params,
_prefix + 'balanced_threshold')
[docs]def _balanced_threshold(m_exp, J, K):
"""
Calculate threshold equal to input given expected mean activity (balance).
Parameters
----------
m_exp : array
Expected mean activity for each population.
J : array
Weight matrix.
K : array
Connectivity matrix.
Returns
-------
array
Balanced threshold for each population.
"""
return np.dot(K * J, m_exp)