{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Response nonlinearities\n\nIn this example, we reproduce the different types of response nonlinearities of\nan EI network that were uncovered by Sanzeni et al. (2020). To this end, we\nneed to determine the self-consistent rates of EI networks with specific\nindegrees and synaptic weights for changing external input.\n\nMost of this script handles the necessary parameters; the relevant\ncalculation is performed by the function `nnmt.lif.delta._firing_rates`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.gridspec as gridspec\n\nfrom nnmt.lif.delta import _firing_rates\n\n# try loading svgutils to add the network sketch\ntry:\n    import os\n    import svgutils.transform as sg\n    insert_sketch = True\nexcept ImportError:\n    insert_sketch = False\n\n# use matplotlib style file\nplt.style.use('frontiers.mplstyle')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we define the common parameters for all networks: the time constants\nand the reset and threshold voltages.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# parameters shared by all networks\n",
        "params_all = {\n",
        "    # time constants in s\n",
        "    'tau_m': 20.*1e-3,\n",
        "    'tau_r': 2.*1e-3,\n",
        "    # reset and threshold voltage relative to leak in mV\n",
        "    'V_0_rel': 10.,\n",
        "    'V_th_rel': 20.,\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next, we define the indegrees and the synaptic weights corresponding to the\ndifferent nonlinearities.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# In every block below, g_E and g_I scale the second column of J relative\n",
        "# to the first one, and a_E and a_I scale the external indegrees K_ext\n",
        "# relative to the first column of K. The four names are rebound for each\n",
        "# network, so after this cell they hold the values of the last block.\n",
        "\n",
        "# Saturation-driven nonlinearity\n",
        "g_E, g_I = 8, 7\n",
        "a_E, a_I = 4, 2\n",
        "params_sdn = {\n",
        "    'J': np.array([[0.2, -g_E*0.2],\n",
        "                   [0.2, -g_I*0.2]]),\n",
        "    'K': np.array([[400., 100.],\n",
        "                   [400., 100.]]),\n",
        "    'J_ext': np.array([0.2, 0.2]),\n",
        "    'K_ext': np.array([a_E*400., a_I*400.]),\n",
        "}\n",
        "# Saturation-driven multisolution\n",
        "g_E, g_I = 2.08, 1.67\n",
        "a_E, a_I = 1, 1\n",
        "params_sdm = {\n",
        "    'J': np.array([[0.2, -g_E*0.2],\n",
        "                   [2.4*0.2/2.5, -g_I*2.4*0.2/2.5]]),\n",
        "    'K': np.array([[400., 100.],\n",
        "                   [400., 100.]]),\n",
        "    'J_ext': np.array([0.2, 2.4*0.2/2.5]),\n",
        "    'K_ext': np.array([a_E*400., a_I*400.]),\n",
        "}\n",
        "# Response-onset supersaturation\n",
        "g_E, g_I = 4.5, 2.9\n",
        "a_E, a_I = 1, 1\n",
        "params_ros = {\n",
        "    'J': np.array([[0.2, -g_E*0.2],\n",
        "                   [0.2, -g_I*0.2]]),\n",
        "    'K': np.array([[400., 100.],\n",
        "                   [400., 100.]]),\n",
        "    'J_ext': np.array([0.2, 0.2]),\n",
        "    'K_ext': np.array([a_E*400., a_I*400.]),\n",
        "}\n",
        "# Mean-driven multisolution\n",
        "g_E, g_I = 4.1, 2.46\n",
        "a_E, a_I = 1, 0.2\n",
        "params_mdm = {\n",
        "    'K': np.array([[800., 200.],\n",
        "                   [400., 100.]]),\n",
        "    'J': np.array([[0.2, -g_E*0.2],\n",
        "                   [0.2, -g_I*0.2]]),\n",
        "    'J_ext': np.array([0.2, 0.2]),\n",
        "    'K_ext': np.array([a_E*800., a_I*400.]),\n",
        "}\n",
        "# Noise-driven multisolution\n",
        "g_E, g_I = 7, 6\n",
        "a_E, a_I = 1, 0.7\n",
        "params_ndm = {\n",
        "    'K': np.array([[400., 100.],\n",
        "                   [400., 100.]]),\n",
        "    'J': np.array([[0.5, -g_E*0.5],\n",
        "                   [0.5, -g_I*0.5]]),\n",
        "    'J_ext': np.array([0.5, 0.5]),\n",
        "    'K_ext': np.array([a_E*400., a_I*400.]),\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We introduce a helper function to handle the parameters. The firing rates\nare determined using the `nnmt.lif.delta._firing_rates` function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def solve_theory(params, nu_0, nu_ext_min, nu_ext_max, nu_ext_steps, method):\n",
        "    \"\"\"Scan the external rate and collect the self-consistent rates.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    params : dict\n",
        "        Network-specific parameters. They are merged with the common\n",
        "        parameters in `params_all` (common values win on duplicate keys)\n",
        "        and passed to `_firing_rates` as keyword arguments. The dict itself\n",
        "        is left unmodified.\n",
        "    nu_0 : tuple\n",
        "        Passed on to `_firing_rates` as `nu_0` (presumably the initial guess\n",
        "        of the fixed-point search).\n",
        "    nu_ext_min, nu_ext_max : float\n",
        "        First and last external rate of the scan; both are included.\n",
        "    nu_ext_steps : int\n",
        "        Number of equally spaced external rates in the scan.\n",
        "    method : str\n",
        "        Passed on to `_firing_rates` as `fixpoint_method`.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    nu_ext_arr : np.ndarray\n",
        "        The external rates, shape `(nu_ext_steps,)`.\n",
        "    nu_arr : np.ndarray\n",
        "        The result of `_firing_rates` for each external rate, shape\n",
        "        `(nu_ext_steps, 2)`. Rows for which `_firing_rates` raised a\n",
        "        `RuntimeError` are set to nan.\n",
        "    \"\"\"\n",
        "    # combine common and specific parameters in a new dict; updating `params`\n",
        "    # in place would silently modify the dict owned by the caller\n",
        "    all_params = {**params, **params_all}\n",
        "    # create an array with all external rates and an array for the results\n",
        "    nu_ext_arr = np.linspace(nu_ext_min, nu_ext_max, nu_ext_steps)\n",
        "    nu_arr = np.zeros((nu_ext_steps, 2))\n",
        "    # iterate through the ext. rates and determine the self-consistent rates\n",
        "    for i, nu_ext in enumerate(nu_ext_arr):\n",
        "        try:\n",
        "            nu_arr[i] = _firing_rates(nu_0=nu_0, nu_ext=nu_ext,\n",
        "                                      fixpoint_method=method, **all_params)\n",
        "        except RuntimeError:\n",
        "            # set non-convergent solutions to nan for convenience\n",
        "            nu_arr[i] = (np.nan, np.nan)\n",
        "    return nu_ext_arr, nu_arr"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, we calculate the firing rates for each nonlinearity. By default, we use\nthe `ODE` method. If this does not converge, we use `LSTSQ`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Every network is scanned over a range of external rates. Where a network\n",
        "# is solved several times (suffixes _a, _b, ...), the calls differ in the\n",
        "# start rates nu_0, the scanned range, and the fixpoint method.\n",
        "print('Calculating self-consistent rates...')\n",
        "print('Saturation-driven nonlinearity...')\n",
        "nu_ext_sdn, nu_sdn = solve_theory(params_sdn, (0, 0), 1, 100, 50,\n",
        "                                  method='ODE')\n",
        "print('Saturation-driven multisolution...')\n",
        "nu_ext_sdm_a, nu_sdm_a = solve_theory(params_sdm, (0, 0), 1, 9, 10,\n",
        "                                      method='ODE')\n",
        "nu_ext_sdm_b, nu_sdm_b = solve_theory(params_sdm, (500, 500), 1, 100, 50,\n",
        "                                      method='ODE')\n",
        "nu_ext_sdm_c, nu_sdm_c = solve_theory(params_sdm, (10, 10), 1, 20, 10,\n",
        "                                      method='LSTSQ')\n",
        "nu_ext_sdm_d, nu_sdm_d = solve_theory(params_sdm, (100, 100), 1, 20, 10,\n",
        "                                      method='LSTSQ')\n",
        "print('Response-onset supersaturation...')\n",
        "nu_ext_ros_a, nu_ros_a = solve_theory(params_ros, (0, 0), 0.5, 50, 50,\n",
        "                                      method='ODE')\n",
        "nu_ext_ros_b, nu_ros_b = solve_theory(params_ros, (10, 10), 7.5, 12.5, 50,\n",
        "                                      method='ODE')\n",
        "print('Mean-driven multisolution...')\n",
        "nu_ext_mdm_a, nu_mdm_a = solve_theory(params_mdm, (0, 0), 0.1, 5, 25,\n",
        "                                      method='LSTSQ')\n",
        "nu_ext_mdm_b, nu_mdm_b = solve_theory(params_mdm, (50, 50), 0.1, 10, 50,\n",
        "                                      method='LSTSQ')\n",
        "nu_ext_mdm_c, nu_mdm_c = solve_theory(params_mdm, (10, 0), 0.1, 5, 25,\n",
        "                                      method='LSTSQ')\n",
        "print('Noise-driven multisolution...')\n",
        "nu_ext_ndm_a, nu_ndm_a = solve_theory(params_ndm, (0, 0), 0.05, 5, 50,\n",
        "                                      method='ODE')\n",
        "nu_ext_ndm_b, nu_ndm_b = solve_theory(params_ndm, (10, 10), 0.05, 5, 50,\n",
        "                                      method='ODE')\n",
        "nu_ext_ndm_c, nu_ndm_c = solve_theory(params_ndm, (5, 4), 0.05, 5, 50,\n",
        "                                      method='LSTSQ')\n",
        "nu_ext_ndm_d, nu_ndm_d = solve_theory(params_ndm, (2, 0), 0.05, 5, 50,\n",
        "                                      method='LSTSQ')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, we plot the results. Again, we introduce a helper function to handle\nthe parameters. Panel (A) is reserved for a sketch of the network, which can\nonly be added to the pdf output and is not shown by `plt.show()`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def plot_rates(ax, nu_ext_arr_lst, nu_arr_lst, xmax, ymax, xlabel, ylabel,\n",
        "               title, colors, label, label_prms):\n",
        "    \"\"\"Draw one panel: rates against external rate for several scans.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ax : matplotlib.axes.Axes\n",
        "        Axes to draw into.\n",
        "    nu_ext_arr_lst : list of np.ndarray\n",
        "        External rates of each scan (x values).\n",
        "    nu_arr_lst : list of np.ndarray\n",
        "        Rates of each scan (y values), paired by position with\n",
        "        `nu_ext_arr_lst`.\n",
        "    xmax, ymax : float\n",
        "        Upper axis limits. The lower limits are 0 and ticks are placed at\n",
        "        0, half the maximum, and the maximum.\n",
        "    xlabel, ylabel, title : str\n",
        "        Axis labels and panel title.\n",
        "    colors : list\n",
        "        Colors used as the color cycle of the panel.\n",
        "    label : str\n",
        "        Panel label text (e.g. 'B').\n",
        "    label_prms : dict\n",
        "        Keyword arguments passed to `ax.text` for the panel label; the\n",
        "        position is interpreted in axes coordinates.\n",
        "    \"\"\"\n",
        "    ax.set_prop_cycle(color=colors)\n",
        "    for nu_ext_arr, nu_arr in zip(nu_ext_arr_lst, nu_arr_lst):\n",
        "        ax.plot(nu_ext_arr, nu_arr, 'o')\n",
        "    ax.set_xlim(0, xmax)\n",
        "    ax.set_ylim(0, ymax)\n",
        "    ax.set_xticks((0, xmax/2, xmax))\n",
        "    ax.set_yticks((0, ymax/2, ymax))\n",
        "    ax.set_xlabel(xlabel)\n",
        "    ax.set_ylabel(ylabel)\n",
        "    ax.set_title(title)\n",
        "    ax.text(s=label, transform=ax.transAxes, **label_prms)\n",
        "\n",
        "\n",
        "print('Plotting...')\n",
        "fig = plt.figure(figsize=(7.08661, 2.95276),  # two column figure, 180mm wide\n",
        "                 constrained_layout=True)\n",
        "gs = gridspec.GridSpec(2, 3, figure=fig)\n",
        "label_prms = dict(x=-0.3, y=1.45, fontsize=10, fontweight='bold',\n",
        "                  va='top', ha='right')\n",
        "colors = ['#4c72b0', '#c44e52']\n",
        "# Sketch\n",
        "ax_sketch = fig.add_subplot(gs[0, 0])\n",
        "ax_sketch.axis('off')\n",
        "ax_sketch.set_title('network\\nsketch')\n",
        "ax_sketch.text(s='A', transform=ax_sketch.transAxes, **label_prms)\n",
        "# Saturation-driven nonlinearity\n",
        "plot_rates(fig.add_subplot(gs[0, 1]), [nu_ext_sdn], [nu_sdn], 100, 500,\n",
        "           '', r'rate $\\nu$ (1/s)', 'saturation-driven\\nnonlinearity',\n",
        "           colors, 'B', label_prms)\n",
        "# Saturation-driven multisolution\n",
        "plot_rates(fig.add_subplot(gs[0, 2]),\n",
        "           [nu_ext_sdm_a, nu_ext_sdm_b, nu_ext_sdm_c, nu_ext_sdm_d],\n",
        "           [nu_sdm_a, nu_sdm_b, nu_sdm_c, nu_sdm_d], 100, 500,\n",
        "           '', '', 'saturation-driven\\nmulti-solution',\n",
        "           colors, 'C', label_prms)\n",
        "# Response-onset supersaturation\n",
        "plot_rates(fig.add_subplot(gs[1, 0]), [nu_ext_ros_a, nu_ext_ros_b],\n",
        "           [nu_ros_a, nu_ros_b], 50, 5,\n",
        "           r'external rate $\\nu_X$ (1/s)', r'rate $\\nu$ (1/s)',\n",
        "           'response-onset\\nsupersaturation', colors, 'D', label_prms)\n",
        "# Mean-driven multisolution\n",
        "plot_rates(fig.add_subplot(gs[1, 1]),\n",
        "           [nu_ext_mdm_a, nu_ext_mdm_b, nu_ext_mdm_c],\n",
        "           [nu_mdm_a, nu_mdm_b, nu_mdm_c], 10, 50,\n",
        "           r'external rate $\\nu_X$ (1/s)', '', 'mean-driven\\nmulti-solution',\n",
        "           colors, 'E', label_prms)\n",
        "# Noise-driven multisolution\n",
        "plot_rates(fig.add_subplot(gs[1, 2]),\n",
        "           [nu_ext_ndm_a, nu_ext_ndm_b, nu_ext_ndm_c, nu_ext_ndm_d],\n",
        "           [nu_ndm_a, nu_ndm_b, nu_ndm_c, nu_ndm_d], 5, 10,\n",
        "           r'external rate $\\nu_X$ (1/s)', '', 'noise-driven\\nmulti-solution',\n",
        "           colors, 'F', label_prms)\n",
        "\n",
        "# insert sketch using svgutil, try saving as pdf using inkscape\n",
        "sketch_fn = 'brunel_sketch.svg'\n",
        "plot_fn = 'response_nonlinearities'\n",
        "# `os` is only imported together with svgutils, hence it may only be\n",
        "# touched if insert_sketch is True (guaranteed by the short-circuit `and`)\n",
        "export_sketch = insert_sketch and os.path.isfile(sketch_fn)\n",
        "if export_sketch:\n",
        "    svg_mpl = sg.from_mpl(fig, savefig_kw=dict(transparent=True))\n",
        "    w_svg, h_svg = svg_mpl.get_size()\n",
        "    # NOTE(review): appending 'pt' assumes that the size returned for a\n",
        "    # figure created by from_mpl carries no unit - confirm for the\n",
        "    # installed svgutils version\n",
        "    svg_mpl.set_size((w_svg+'pt', h_svg+'pt'))\n",
        "    svg_sketch = sg.fromfile(sketch_fn).getroot()\n",
        "    svg_sketch.moveto(x=25, y=30, scale_x=1.0)\n",
        "    svg_mpl.append(svg_sketch)\n",
        "    svg_mpl.save(f'{plot_fn}.svg')\n",
        "    # NOTE(review): `--export-pdf` looks like the pre-1.0 inkscape option;\n",
        "    # newer releases use `--export-filename` - confirm for the installed\n",
        "    # version. A failing call is reported below and the svg is kept.\n",
        "    os_return = os.system(f'inkscape --export-pdf={plot_fn}.pdf {plot_fn}.svg')\n",
        "    if os_return == 0:\n",
        "        os.remove(f'{plot_fn}.svg')\n",
        "    else:\n",
        "        print('Conversion to pdf using inkscape failed, keeping svg...')\n",
        "elif insert_sketch:\n",
        "    # svgutils is available but the sketch file is missing: skip the export\n",
        "    # instead of raising before the figure is shown\n",
        "    print(f'Sketch file {sketch_fn} not found, skipping svg/pdf export...')\n",
        "\n",
        "# show figure (without sketch)\n",
        "ax_sketch.annotate('(sketch)', xy=(0., 0.5))\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}