{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# render figures as static images below each cell; by default the inline\n# backend closes all figures at the end of each cell, so a figure that is\n# built up over several cells has to be accessed through its handle\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Mapping LIF to rate network\n\nThis example demonstrates the methods used in Figure 5 of :cite:t:`senk2020`.\nThe transfer function of a leaky integrate-and-fire (LIF) neuron is fitted with\nthat of a low-pass filter for different working points.\nA figure illustrating the network structure of the model used here is created in\n:doc:`network_structure`.\nThe same model is used in the example :doc:`spatial_patterns`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nnmt.lif.exp as mft  # main set of meanfield tools\nfrom nnmt.models.basic import Basic as BasicNetwork\nimport numpy as np\nimport matplotlib as mpl\nimport matplotlib.gridspec as gridspec\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Patch\nfrom matplotlib import ticker\n# figure style sheet; presumably a local file in the working directory\n# (TODO confirm it is distributed alongside this example)\nplt.style.use('frontiers.mplstyle')\n# overrides applied on top of the style sheet\nmpl.rcParams.update({'legend.fontsize': 'medium',  # old value 5.0 was too small\n                     'axes.titlepad': 0.0})  # zero padding between axes and title"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we define parameters for data generation and plotting.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "params = {\n    # means and standard deviations of the inputs to scan (in V); all\n    # combinations of the two arrays are evaluated\n    'mean_inputs_scan': np.arange(6., 14., 2.) * 1e-3,\n    'std_inputs_scan': np.arange(6., 14., 2.) * 1e-3,\n\n    # pairs of mean and standard deviation of inputs to show transfer function\n    # and fit (in V); only pairs that lie on the scan grid above are plotted\n    'mean_std_inputs_tf': np.array([[6., 6.],\n                                    [6., 12.],\n                                    [10., 10.],\n                                    [12., 12.]]) * 1e-3,\n\n    # mean and standard deviation of input used for stability analysis (in V);\n    # marked by a star in the scan panels\n    'mean_std_inputs_stability': np.array([10., 10.]) * 1e-3,\n\n    # figure width in inch\n    # (setting the value a bit smaller results in 180 mm width of .eps output)\n    'figwidth_2cols': 7.08, # < 180. / 25.4\n\n    # labels and corresponding scaling parameters for plotted quantities;\n    # 'scale' is the factor applied to the stored values before plotting\n    'quantities': {\n        'mean_input': {\n            'label': r'mean input $\\mu$ (mV)',\n            'scale': 1e3},\n        'std_input': {\n            'label': r'std input $\\sigma$ (mV)',\n            'scale': 1e3},\n        'nu_ext_exc': {\n            'label': 'exc. external rate\\n' + r'$\\nu_\\mathrm{ext,E}$ (1000/s)',\n            'scale': 1e-3},\n        'nu_ext_inh': {\n            'label': 'inh. external rate\\n' + r'$\\nu_\\mathrm{ext,I}$ (1000/s)',\n            'scale': 1e-3},\n        'firing_rates': {\n            'label': 'rate\\n' + r'$\\nu$ (1/s)',\n            'scale': 1.},\n        'tau_rate': {\n            'label': 'fit time constant\\n' + r'$\\tau$ (ms)',\n            'scale': 1e3},\n        'W_rate': {\n            'label': 'fit exc. 
weight\\n' + r'$w_\\mathrm{E}$',\n            'scale': 1.},  # unitless\n        'fit_error': {\n            'label': 'fit error\\n' + r'$\\eta$ (%)',\n            'scale': 1e2},\n        'transfer_function': {\n            'label': r'transfer function $N_\\mathrm{cn,s}$',\n            'scale': 1e-3},\n        'transfer_function_amplitude': {\n            'label':\n                r'amplitude $|N_\\mathrm{cn,s}|\\quad(\\mathrm{s}\\cdot\\mathrm{mV})^{-1}$'},\n        'transfer_function_phase': {\n            'label': r'phase $\\angle N_\\mathrm{cn,s}\\quad(\\circ)$', },\n        'frequencies': {\n            'label': r'frequency $\\mathrm{Im}[\\lambda]/(2\\pi)$ (Hz)',\n            'scale': 1.}},\n\n    # color definitions\n    # numbers from discrete rainbow scheme of https://personal.sron.nl/~pault\n    'colors': {\n        'light_grey': '#BBBBBB',\n        'dark_grey': '#555555',\n        'dark_purple': '#882E72',  # no. 9\n        'light_purple': '#D1BBD7',  # no. 3\n        'dark_green': '#4EB265',  # no. 15\n        'light_green': '#CAE0AB',  # no. 17\n        'dark_orange': '#E8601C',  # no. 24\n        'light_orange': '#F6C141'},  # no. 20\n\n    # colors for transfer function [dark for LIF trans. func., light for fit];\n    # one pair per selected working point, assigned in scan order\n    'colors_tf': [\n        ['dark_purple', 'light_purple'],\n        ['dark_green', 'light_green'],\n        ['dark_orange', 'light_orange'],\n        ['dark_grey', 'light_grey']]}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We also define a helper function for adding labels to figure panels.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _add_label(ax, label, xshift=0., yshift=0., scale_fs=1.):\n    \"\"\"\n    Write a bold panel label at the upper-left corner of an axes.\n\n    The anchor is the corner (0, 1) in axes coordinates, moved by the given\n    shifts; the lower-left corner of the text is placed at the anchor.\n\n    Parameters:\n    -----------\n    ax : matplotlib.axes.Axes object\n        Axes to label.\n    label : str\n        Label text, typically a panel letter.\n    xshift : float\n        Horizontal offset of the label in axes coordinates.\n    yshift : float\n        Vertical offset of the label in axes coordinates.\n    scale_fs : float\n        Factor applied to the default font size (rcParams['font.size']).\n    \"\"\"\n    fontsize = scale_fs * mpl.rcParams['font.size']\n    x_pos, y_pos = 0. + xshift, 1. + yshift\n    ax.text(x_pos, y_pos, label, transform=ax.transAxes,\n            ha='left', va='bottom',\n            fontweight='bold', fontsize=fontsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generating data\nWe instantiate a ``Basic`` model with a set of pre-defined network and\nanalysis parameters.\nHere, the relative inhibition is g = 5, in contrast to the original Figure 5 of\n:cite:t:`senk2020`, which uses g = 6.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# pre-defined network and analysis parameters are passed as YAML file names\nnetwork = BasicNetwork(analysis_params='Senk2020_analysis_params.yaml',\n                       network_params='Senk2020_network_params.yaml')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "All results will be stored in ``tf_scan_results``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# result containers: one entry per working point (mean index, std index);\n# transfer functions additionally have a frequency axis\ndims = (len(params['mean_inputs_scan']), len(params['std_inputs_scan']))\ntf_scan_results = {\n    'frequencies': network.analysis_params['omegas'] / (2. * np.pi)}\ntf_scan_results.update(\n    (name, np.zeros(dims))\n    for name in ('nu_ext_exc', 'nu_ext_inh', 'firing_rates',\n                 'tau_rate', 'fit_error', 'W_rate'))\ntf_scan_results.update(\n    (name, np.zeros(dims + (len(network.analysis_params['omegas']),),\n                    dtype=complex))\n    for name in ('transfer_function', 'transfer_function_fit'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The main loop for generating the data iterates over working points, which are\ndefined as pairs of mean and standard deviation of the input.\nFor each working point, we first compute the excitatory and inhibitory\nexternal firing rates required to attain this working point and adjust the\nnetwork parameters accordingly.\nThen, we calculate the LIF transfer function and fit it with that of a\nlow-pass filter using a least-squares fit.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print('Iterating over working points and fitting the LIF transfer function.')\n\n# scan all combinations of mean and standard deviation of the input\nfor i, mu in enumerate(params['mean_inputs_scan']):\n    for j, sigma in enumerate(params['std_inputs_scan']):\n\n        print(f'    (mu, sigma) = ({mu * 1e3}, {sigma * 1e3}) mV')\n\n        # fix working point via external rates\n        nu_ext = mft.external_rates_for_fixed_input(\n            network, mu_set=mu, sigma_set=sigma)\n\n        network.change_parameters(\n            changed_network_params={'nu_ext': nu_ext},\n            overwrite=True)\n\n        # calculate transfer function and its fit\n        mft.working_point(network)\n        mft.transfer_function(network)\n        mft.fit_transfer_function(network)\n\n        # store results\n        tf_scan_results['nu_ext_exc'][i, j] = nu_ext[0]\n        tf_scan_results['nu_ext_inh'][i, j] = nu_ext[1]\n\n        # 1D results (assert equal values for populations up to numerical\n        # precision, store only one); a tolerance is used as for the 2D\n        # results below, since np.unique would require bit-identical floats\n        for key in ['firing_rates', 'tau_rate', 'fit_error']:\n            res = network.results[mft._prefix + key]\n            assert np.ndim(res) == 1 and np.isclose(res, res[0]).all()\n            tf_scan_results[key][i, j] = res[0]\n\n        # 2D results (assert equal rows, store only first value (E->E,I))\n        for key in ['W_rate']:\n            res = network.results[mft._prefix + key]\n            assert np.ndim(res) == 2 and np.isclose(res, res[0]).all()\n            tf_scan_results[key][i, j] = res[0, 0]\n\n        # 2D results (assert equal columns for populations, store only one)\n        for key in ['transfer_function', 'transfer_function_fit']:\n            res = network.results[mft._prefix + key]\n            # compare every column with the first one\n            assert np.ndim(res) == 2 and np.isclose(res, res[:, :1]).all()\n            tf_scan_results[key][i, j] = res[:, 0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Plotting\nWe generate a figure with three panels (A, B, and C) showing the results of\nscanning over the input.\nThe figure spans two columns (a width of about 180 mm).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print('Plotting.')\n\n# two-column figure, half as high as it is wide\nfig = plt.figure(\n    figsize=(params['figwidth_2cols'], 0.5 * params['figwidth_2cols']))\n# one row of ten grid columns; the panels occupy slices of these columns\ngs = gridspec.GridSpec(nrows=1, ncols=10, figure=fig)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we plot the results of scanning over the full range of working points.\nPanel A contains the external rates required to fix each working point and\nthe predicted firing rates of the neuronal populations.\nPanel C contains the results from fitting the transfer function, i.e.,\nthe time constants, weights, and fit errors.\nIn all of these plots, a star marks the working point used for the stability\nanalysis.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gs_wp = gridspec.GridSpecFromSubplotSpec(\n    2, 3, subplot_spec=gs[0, :6], hspace=0.3, wspace=0)\n\nmus = params['mean_inputs_scan']  # first index\nsigmas = params['std_inputs_scan']  # second index\nmu_star = params['mean_std_inputs_stability'][0]\nsigma_star = params['mean_std_inputs_stability'][1]\n\nfor k, key in enumerate([\n    'nu_ext_exc', 'nu_ext_inh', 'firing_rates',  # panel A\n        'tau_rate', 'W_rate', 'fit_error']):  # panel C\n    # Use the explicit figure handle: the inline backend closes figures at\n    # the end of each cell, so pyplot's current figure is no longer `fig`.\n    ax = fig.add_subplot(gs_wp[k])\n    # transpose: first index (mean) along x, second index (std) along y\n    img = ax.pcolormesh(\n        np.transpose(\n            tf_scan_results[key] *\n            params['quantities'][key]['scale']))\n\n    # pcolormesh places ticks by default to lower bound, therefore add 0.5\n    ax.set_xticks(np.arange(len(mus)) + 0.5)\n    ax.set_yticks(np.arange(len(sigmas)) + 0.5)\n    # round before casting: astype(int) truncates (e.g. 11.999... -> 11)\n    ax.set_xticklabels(np.rint(\n        mus * params['quantities']['mean_input']['scale']).astype(int))\n    ax.set_yticklabels(np.rint(\n        sigmas * params['quantities']['std_input']['scale']).astype(int))\n\n    if k == 1 or k == 4:\n        ax.set_xlabel(params['quantities']['mean_input']['label'])\n\n    if k == 0 or k == 3:\n        ax.set_ylabel(params['quantities']['std_input']['label'])\n    else:\n        ax.set_yticklabels([])\n\n    xshift = -0.6\n    yshift = 0.22\n    if k == 0:\n        _add_label(ax, 'A', xshift=xshift, yshift=yshift)\n    if k == 3:\n        _add_label(ax, 'C', xshift=xshift, yshift=yshift)\n\n    cb = fig.colorbar(img, ax=ax)\n    cb.ax.tick_params(pad=0)\n    cb.locator = ticker.MaxNLocator(nbins=4)\n    cb.update_ticks()\n\n    # star for mu and sigma used in this circuit: linear map from input\n    # values to cell-center coordinates (0.5 offset for pcolormesh)\n    xmu = ((len(mus) - 1) * (mu_star - np.min(mus))\n           / (np.max(mus) - np.min(mus)) + 0.5)\n    ysigma = ((len(sigmas) - 1) * (sigma_star - np.min(sigmas))\n              / (np.max(sigmas) - np.min(sigmas)) + 0.5)\n    ax.plot(xmu, ysigma,\n            marker='*', markerfacecolor='white', markeredgecolor='none',\n            markersize=mpl.rcParams['lines.markersize'] * 2.5)\n    ax.plot(xmu, ysigma,\n            marker='*', markerfacecolor='k', markeredgecolor='none',\n            markersize=mpl.rcParams['lines.markersize'] * 2.)\n\n    ax.set_title(params['quantities'][key]['label'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In panel B, we plot the LIF transfer function and its fit for a few selected\nworking points.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gs_tf = gridspec.GridSpecFromSubplotSpec(\n    2, 1, subplot_spec=gs[0, 7:], hspace=0)\n# Use the explicit figure handle: the inline backend closes figures at the\n# end of each cell, so pyplot's current figure is no longer `fig`.\nax_amplitude = fig.add_subplot(gs_tf[0])\n_add_label(ax_amplitude, 'B', xshift=-0.4, yshift=0.02)\nax_phase = fig.add_subplot(gs_tf[1])\n\nfrequencies = (\n    tf_scan_results['frequencies']\n    * params['quantities']['frequencies']['scale'])\ntf_scale = params['quantities']['transfer_function']['scale']\nfit_markersize = mpl.rcParams['lines.markersize'] * 0.1\n\nleg_handles, leg_labels = [[], []], []\nc = 0\nfor i, mu in enumerate(params['mean_inputs_scan']):\n    for j, sigma in enumerate(params['std_inputs_scan']):\n        # tolerance-based lookup; exact float equality between the scan grid\n        # and the selected pairs is fragile\n        is_selected = np.isclose(\n            params['mean_std_inputs_tf'], [mu, sigma]).all(axis=1).any()\n        if not is_selected:\n            continue\n        # dark color for the LIF transfer function, light color for the fit\n        cols = [params['colors'][x] for x in params['colors_tf'][c]]\n\n        transfer_function = (\n            tf_scan_results['transfer_function'][i, j] * tf_scale)\n        transfer_function_fit = (\n            tf_scan_results['transfer_function_fit'][i, j] * tf_scale)\n\n        # amplitude\n        ax_amplitude.plot(frequencies, np.abs(transfer_function), c=cols[0])\n        ax_amplitude.plot(\n            frequencies, np.abs(transfer_function_fit),\n            c=cols[1], linestyle='none', marker='o',\n            markersize=fit_markersize)\n\n        # phase in degrees (np.angle is arctan2(imag, real))\n        ax_phase.plot(\n            frequencies, np.angle(transfer_function, deg=True), c=cols[0])\n        ax_phase.plot(\n            frequencies, np.angle(transfer_function_fit, deg=True),\n            c=cols[1], linestyle='none', marker='o',\n            markersize=fit_markersize)\n\n        leg_handles[0].append(Patch(facecolor=cols[0]))\n        leg_handles[1].append(Patch(facecolor=cols[1]))\n        # round before casting: int() truncates (e.g. 11.999... -> 11)\n        leg_labels.append(\n            f'({int(round(mu * 1e3))}, {int(round(sigma * 1e3))})')\n        c += 1\n\n# axis formatting does not depend on the working point; apply it once\nax_amplitude.set_title(params['quantities']['transfer_function']['label'])\nax_phase.set_xlabel(params['quantities']['frequencies']['label'])\nfor ax, ylabel in zip(\n        [ax_amplitude, ax_phase],\n        [params['quantities']['transfer_function_amplitude']['label'],\n         params['quantities']['transfer_function_phase']['label']]):\n    if np.any(frequencies > 0):\n        ax.set_xscale('log')\n    ax.set_ylabel(ylabel)\n    ax.set_xlim(frequencies[0], frequencies[-1])\n# hide the x tick labels (major and minor) of the upper panel\nax_amplitude.tick_params(axis='x', which='both', labelbottom=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
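        "For panel B, we create custom legends: one identifies the working points by\ncolor, the other distinguishes the spiking model from the rate-model fit.\n\n"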
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Working-point legend: with ncol=2 the entries fill the left column first,\n# so the dark patches (without text) form the left column and the light\n# patches with the (mu, sigma) labels the right one; the negative column\n# spacing pulls both columns together.\nleg_handles = [*leg_handles[0], *leg_handles[1]]\nleg_labels = ['' for _ in leg_labels] + leg_labels\nax_phase.legend(\n    handles=leg_handles,\n    labels=leg_labels,\n    title=r'$(\\mu, \\sigma)$ in mV',\n    ncol=2,\n    columnspacing=-0.5,\n    handlelength=1.,\n    handletextpad=0.5)\n\n# legend distinguishing the spiking model (line) from the rate-model fit\n# (small markers)\nrate = mpl.lines.Line2D(\n    [], [], color='k', linestyle='none', marker='o',\n    markersize=0.1 * mpl.rcParams['lines.markersize'],\n    label='rate model (fit)')\nspiking = mpl.lines.Line2D([], [], color='k', label='spiking model')\nax_amplitude.legend(handles=[spiking, rate])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The final figure is saved to file.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Save via the explicit figure handle: the inline backend closes figures\n# at the end of each cell, so plt.savefig would write pyplot's new, empty\n# current figure instead of `fig`.\nfig.savefig('mapping_lif_rate.eps')\n# show the completed figure in the notebook\nfig"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}