{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Binary rates: simulation vs mean-field\n\n<img src=\"file://../../../../examples/binary_rate_simulation/binary.png\" width=\"1000\" alt=\"Plot of simulated and estimated rates\">\n\nHere we simulate an E-I network of binary neurons, calculate the mean-field\nestimate of the firing rates, and plot them together.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nnmt\nimport numpy as np\nimport matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Helper functions\n\nFirst, we define the functions used to perform the simulation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _sort_queue(q):\n    \"\"\"\n    Sorts a 2d array of id and update time point by update time points.\n\n    Sorts the queue of update time points\n        id_0, t_next_0\n        id_1, t_next_1\n        ...\n    by the update time point t_next.\n\n    Parameters\n    ----------\n    q : array\n        2d array of neuron ids and update time points.\n\n    Returns\n    -------\n    array\n        Sorted array of ids and update time points.\n    \"\"\"\n    return q[np.argsort(q[:, 1], kind='stable')]\n\n\ndef update_poisson(J, S, T, tau, thetas):\n    \"\"\"\n    Simulates a network of binary neurons.\n\n    Evolves the initial network state `S` in time by drawing exponentially\n    distributed update times (Poisson process) with time constant `tau` until\n    time `T` is reached, using connectivity matrix `J` and thresholds `thetas`.\n\n    Parameters\n    ----------\n    J : array\n        Connectivity matrix.\n    S : array\n        Initial state of each neuron.\n    T : float\n        Simulation time.\n    tau : float\n        Time constant of all binary neurons.\n    thetas : [array|list]\n        Thresholds of neurons.\n\n    Returns\n    -------\n    array\n        Update times.\n    array\n        State of each neuron at each update time point.\n    array\n        Mean population activity at each update time point.\n    \"\"\"\n    # check dimensions of J and S\n    assert J.shape[0] == J.shape[1] == S.shape[0]\n    # get total number of neurons\n    N = S.shape[0]\n    # create update queue\n    update_queue = np.empty((N, 2), dtype=np.float32)\n    update_queue[:, 0] = np.arange(N)\n    # draw first update time point for every neuron\n    # from an exponential distribution with mean tau\n    update_queue[:, 1] = np.random.exponential(tau, N)\n    # sort queue ascendingly according to update time\n    update_queue = _sort_queue(update_queue)\n    # initialize storage lists\n    ts, m, Ss = [], [], []\n    t = 0\n    while t < T:\n        # select neuron with next update time point\n        i, t = update_queue[0, :]\n        i = int(i)\n        # input to neuron i\n        h_i = J[i, :].dot(S)\n        # update state of neuron i\n        S[i] = np.heaviside(h_i-thetas[i], 0)\n        # draw new update time\n        update_queue[0, 1] += np.random.exponential(tau)\n        update_queue = _sort_queue(update_queue)\n        m.append(S.mean())\n        ts.append(t)\n        Ss.append(S.copy())\n    return np.array(ts), np.array(Ss).T, np.array(m)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here we define the functions that construct the network properties in a\nformat needed to perform the simulation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def fixed_indegree_connectivity(N, J, K):\n    \"\"\"\n    Constructs a fixed indegree matrix for the given network parameters.\n\n    Parameters\n    ----------\n    N : array of ints\n        Number of neurons in each population.\n    J : array of floats\n        Weight matrix.\n    K : array of ints\n        Indegree matrix.\n\n    Returns\n    -------\n    array\n        Connectivity matrix.\n    \"\"\"\n\n    W = np.zeros((N.sum(), N.sum()), dtype=np.float32)\n\n    # list of neurons (each one gets a unique number)\n    neurons = np.arange(N.sum())\n\n    population_ix = N.cumsum()\n\n    for pre_ix, pre_pop in enumerate(\n            np.array_split(neurons, population_ix[:-1])):\n        for post_ix, post_pop in enumerate(\n                np.array_split(neurons, population_ix[:-1])):\n            for post_neuron in post_pop:\n                pre_neurons = np.random.choice(pre_pop[pre_pop!=post_neuron],\n                                               size=K[post_ix][pre_ix],\n                                               replace=False)\n                W[post_neuron, pre_neurons] = J[post_ix][pre_ix]\n\n    return W\n\n\ndef neuron_thresholds(N, theta):\n    \"\"\"\n    Creates a list of thresholds for each neuron.\n\n    Parameters\n    ----------\n    N : array\n        Numbers of neurons in each population.\n    theta : array\n        Threshold for each population\n\n    Returns\n    -------\n    array\n        Threshold for each individual neuron numbered from 0 to N-1.\n    \"\"\"\n    thresholds = np.zeros(N.sum())\n    population_ix = np.append([0], N.cumsum())\n    for i in range(len(population_ix[:-1])):\n        thresholds[population_ix[i]: population_ix[i+1]] = theta[i]\n    return thresholds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Parameter definition\n\nHere we define the network parameters\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# numbers of neurons in each population and their respective name\nN = np.array([1000, 1000])\nlabel = ['E', 'I']\n\n# indegree matrix\nK = np.array([[150, 200],\n              [350, 200]])\n\n# weight matrix\nJ = np.array([[0.1, -0.2],\n              [0.1, -0.2]])\n\nnetwork_params = {'J': J, 'K': K, 'N': N}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Calculation of mean-field estimate\n\nLets define the network model. We decided to use the ``Plain`` model, as we\njust want to load the parameters into the models dicts and do not want to\ncalculate any dependent parameters from them.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# here we have to copy the dict due to a bug in ``nnmt.models.Network``\nnetwork = nnmt.models.Plain(dict(network_params))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We could set the firing threshold of the neurons directly, but here we\ndecided to calculate the threshold using a balanced condition for some\nexpected rates\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "expected_rates = [0.7, 0.4]\ntheta = nnmt.binary.balanced_threshold(network, expected_rates)\nnetwork.network_params['theta'] = theta"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We calculate the mean-field estimate of the rates\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rates_thy = nnmt.binary.mean_activity(network)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Simulation\n\nFor simulating the network we require a concrete realization of the network\nconnectivity\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "W = fixed_indegree_connectivity(**network_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "And we are going to use a list that defines the thresholds of each neuron\nseparately\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "thresholds = neuron_thresholds(network.network_params['N'],\n                               network.network_params['theta'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The simulation additionally requires to define the time constant of the\nneurons\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "sim_params = {'tau': 1., 'thetas': thresholds}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then we can run the simulation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# simulation time\nT = 10\n# initial state\nS = np.zeros(W.shape[0], dtype=np.float32)\n# simulate\nt, S, m = update_poisson(W, S, T=T, **sim_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we calculate the mean rates for each population\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rates_sim = np.array([pop.mean(axis=0)\n                      for pop in np.array_split(S, N.cumsum()[:-1])])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Plotting\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "fig, axs = plt.subplots(1, len(N), figsize=(6, 2.5))\nfor i in range(len(N)):\n    axs[i].plot(t, rates_sim[i], '-', color='k', label='sim')\n    axs[i].plot(t, rates_thy[i]*np.ones_like(t), '--', color='gray',\n                label='thy')\n    axs[i].set_xlim(0, T)\n    axs[i].set_ylim(-0.1, 1.1)\n    axs[i].set_xlabel('time')\n    axs[i].set_ylabel('mean activity')\n    axs[i].legend(loc=0)\n    axs[i].set_title(f'{label[i]} population')\n\nplt.tight_layout()\nplt.savefig('binary.png', dpi=600)\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}