{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Display matplotlib figures inline in the notebook.\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Sensitivity Measure (Bos 2016)\n\nHere we calculate the sensitivity measure of the :cite:t:`potjans2014`\nmicrocircuit model, including the modifications made in :cite:t:`bos2016`.\n\nThis example reproduces Figs. 6 and 7 of :cite:t:`bos2016`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nnmt\n",
        "import numpy as np\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.gridspec as gridspec\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "# NOTE(review): InsetPosition is deprecated since Matplotlib 3.8 in favour of\n",
        "# Axes.inset_axes; it is still used by the Fig. 6 plotting cell below.\n",
        "from mpl_toolkits.axes_grid1.inset_locator import InsetPosition\n",
        "\n",
        "\n",
        "def colorbar(mappable, cax=None):\n",
        "    \"\"\"Add a colorbar for ``mappable`` to the figure that contains it.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    mappable : matplotlib.cm.ScalarMappable\n",
        "        Artist drawn on an Axes (e.g. the result of ``imshow`` or\n",
        "        ``pcolormesh``) whose color mapping the colorbar shows.\n",
        "    cax : matplotlib.axes.Axes, optional\n",
        "        Axes to draw the colorbar into. If None, a new axes 5% as wide as\n",
        "        the parent axes is appended to its right with a small gap.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    matplotlib.colorbar.Colorbar\n",
        "        The colorbar created by ``Figure.colorbar``.\n",
        "    \"\"\"\n",
        "    ax = mappable.axes\n",
        "    fig = ax.figure\n",
        "    if cax is None:\n",
        "        # Only build a divider when space has to be carved out:\n",
        "        # make_axes_locatable() installs a new axes locator on the parent\n",
        "        # axes, which is unwanted when the caller supplies its own cax.\n",
        "        divider = make_axes_locatable(ax)\n",
        "        cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
        "    return fig.colorbar(mappable, cax=cax)\n",
        "\n",
        "\n",
        "plt.style.use('frontiers.mplstyle')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, create an instance of the network model class `Microcircuit`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Network and analysis parameters of the Bos 2016 setup, read from the\n",
        "# test-suite YAML fixtures (paths relative to the working directory).\n",
        "network_params_file = (\n",
        "    '../../tests/fixtures/integration/config/Bos2016_network_params.yaml')\n",
        "analysis_params_file = (\n",
        "    '../../tests/fixtures/integration/config/Bos2016_analysis_params.yaml')\n",
        "\n",
        "microcircuit = nnmt.models.Microcircuit(network_params=network_params_file,\n",
        "                                        analysis_params=analysis_params_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The frequency resolution used in the original publication is quite high.\nTo reduce it for faster execution, set `reduce_frequency_resolution = True`\nbelow; by default (`False`) the original resolution is kept.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Set to True to trade frequency resolution for execution speed.\n",
        "reduce_frequency_resolution = False\n",
        "\n",
        "if reduce_frequency_resolution:\n",
        "    microcircuit.change_parameters(\n",
        "        changed_analysis_params={'df': 1},\n",
        "        overwrite=True,\n",
        "    )\n",
        "    # Recompute the dependent analysis parameters after changing 'df';\n",
        "    # presumably change_parameters() does not refresh them itself -- confirm.\n",
        "    derived_analysis_params = \\\n",
        "        microcircuit._calculate_dependent_analysis_parameters()\n",
        "    microcircuit.analysis_params.update(derived_analysis_params)\n",
        "\n",
        "# 'omegas' holds angular frequencies; divide by 2*pi to get frequencies.\n",
        "frequencies = microcircuit.analysis_params['omegas'] / (2. * np.pi)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Calculate all necessary quantities (working point, transfer function,\ndelay distribution matrix, and effective connectivity) and finally the\nsensitivity measure for all eigenmodes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Shorthand for the LIF model with exponentially shaped postsynaptic currents.\n",
        "lif_exp = nnmt.lif.exp\n",
        "\n",
        "# Return values are discarded; presumably each step stores its result on\n",
        "# `microcircuit` for the later steps, so keep this order.\n",
        "lif_exp.working_point(microcircuit, method='taylor')        # working point\n",
        "lif_exp.transfer_function(microcircuit, method='taylor')    # transfer function\n",
        "nnmt.network_properties.delay_dist_matrix(microcircuit)     # delay distribution matrix\n",
        "lif_exp.effective_connectivity(microcircuit)                # effective connectivity matrix\n",
        "\n",
        "# Sensitivity measure for every eigenmode, keyed by the eigenmode index as a string.\n",
        "sensitivity_dict = lif_exp.sensitivity_measure_all_eigenmodes(microcircuit)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Print the relevant entries of the sensitivity measure dictionary for each\neigenmode to see which eigenvalues are needed to reproduce Figs. 6 and 7 of\nBos 2016.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Print the key quantities of every eigenmode, labelled with its index, so\n",
        "# that the eigenmodes needed for Figs. 6 and 7 can be identified.\n",
        "for mode in map(str, range(8)):\n",
        "    entry = sensitivity_dict[mode]\n",
        "    print(f'eigenmode {mode}:')\n",
        "    for key in ('critical_frequency', 'critical_eigenvalue', 'k', 'k_per'):\n",
        "        print(f'    {key}: {entry[key]}')\n",
        "\n",
        "# Eigenmode indices identified manually from the output above.\n",
        "eigenvalues_to_plot_high = [str(i) for i in [1, 0, 3, 2]]\n",
        "eigenvalue_to_plot_low = str(6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plotting: sensitivity measure corresponding to the high-frequency peak (Fig. 6)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Two-column figure, 180 mm x 90 mm (figsize is in inches).\n",
        "width = 180. / 25.4\n",
        "height = 90. / 25.4\n",
        "\n",
        "fig = plt.figure(figsize=(width, height), constrained_layout=True)\n",
        "grid_specification = gridspec.GridSpec(2, 2, figure=fig)\n",
        "\n",
        "labels = ['2/3E', '2/3I', '4E', '4I', '5E', '5I', '6E', '6I']\n",
        "\n",
        "# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; the colormap\n",
        "# registry already returns a copy that is safe to modify.\n",
        "colormap = mpl.colormaps['coolwarm']\n",
        "# Masked entries are fully transparent by default, which .eps cannot\n",
        "# handle, so draw them opaque white instead.\n",
        "colormap.set_bad('w', 1.)\n",
        "\n",
        "# Symmetric color limits [-z, z] shared by all heatmaps (and colorbars).\n",
        "z = 1\n",
        "\n",
        "label_prms = dict(x=-0.3, y=1.2, fontsize=10, fontweight='bold',\n",
        "                  va='top', ha='right')\n",
        "panel_labels = ['(A)', '(B)', '(C)', '(D)']\n",
        "\n",
        "\n",
        "def plot_sensitivity_heatmap(ax, ev, kind, show_target_labels):\n",
        "    \"\"\"Draw one component of the sensitivity measure of eigenmode ``ev``.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    ax : matplotlib.axes.Axes\n",
        "        Axes to draw into.\n",
        "    ev : str\n",
        "        Eigenmode key into ``sensitivity_dict``.\n",
        "    kind : {'amp', 'freq'}\n",
        "        Selects the ``'sensitivity_amp'`` or ``'sensitivity_freq'`` entry.\n",
        "    show_target_labels : bool\n",
        "        Label the y axis with the target populations if True; otherwise\n",
        "        hide the y tick labels.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    matplotlib.collections.QuadMesh\n",
        "        The heatmap, to be used as the mappable of a colorbar.\n",
        "    \"\"\"\n",
        "    frequency = sensitivity_dict[ev]['critical_frequency']\n",
        "    projection_of_sensitivity_measure = sensitivity_dict[ev][\n",
        "        'sensitivity_' + kind]\n",
        "\n",
        "    rounded_frequency = str(int(np.round(frequency, 0)))\n",
        "    plot_title = (r'$\\mathbf{Z}_{j=%s}^{\\mathrm{%s}}(' % (ev, kind)\n",
        "                  + f'{rounded_frequency}' + r'\\,\\mathrm{Hz})$')\n",
        "    ax.set_title(plot_title)\n",
        "\n",
        "    # Zero entries are masked so that they are drawn in the 'bad' color.\n",
        "    data = np.ma.masked_where(projection_of_sensitivity_measure == 0,\n",
        "                              projection_of_sensitivity_measure)\n",
        "\n",
        "    # Unlike imshow, pcolormesh draws row 0 at the bottom; flip the rows so\n",
        "    # that the first target population appears at the top.\n",
        "    heatmap = ax.pcolormesh(np.flipud(data), vmin=-z, vmax=z, cmap=colormap,\n",
        "                            edgecolors='k', linewidth=0.6)\n",
        "    ax.set_aspect('equal')\n",
        "\n",
        "    # Ticks at the cell centers.\n",
        "    tick_positions = np.arange(len(labels)) + 0.5\n",
        "    ax.set_xticks(tick_positions)\n",
        "    ax.set_yticks(tick_positions)\n",
        "    ax.set_xticklabels(labels, rotation=90)\n",
        "    ax.set_xlabel('sources')\n",
        "    if show_target_labels:\n",
        "        ax.set_yticklabels(list(reversed(labels)))\n",
        "        ax.set_ylabel('targets')\n",
        "    else:\n",
        "        ax.set_yticklabels([])\n",
        "    return heatmap\n",
        "\n",
        "\n",
        "for ev, subpanel, panel_label in zip(\n",
        "        eigenvalues_to_plot_high, grid_specification, panel_labels):\n",
        "    # Each panel: amplitude heatmap, frequency heatmap, colorbar slot.\n",
        "    gs = gridspec.GridSpecFromSubplotSpec(1, 3, height_ratios=[1],\n",
        "                                          width_ratios=[1, 1, 0.2],\n",
        "                                          subplot_spec=subpanel)\n",
        "\n",
        "    ax = fig.add_subplot(gs[0])\n",
        "    ax.text(s=panel_label, transform=ax.transAxes, **label_prms)\n",
        "    plot_sensitivity_heatmap(ax, ev, 'amp', show_target_labels=True)\n",
        "\n",
        "    ax = fig.add_subplot(gs[1])\n",
        "    heatmap = plot_sensitivity_heatmap(ax, ev, 'freq',\n",
        "                                       show_target_labels=False)\n",
        "\n",
        "    # Pin the colorbar axes just right of the frequency heatmap; gs[2] only\n",
        "    # reserves space for it in the layout.\n",
        "    colorbar_ax = fig.add_subplot(gs[2])\n",
        "    colorbar_width = 0.1\n",
        "    ip = InsetPosition(ax, [1.05, 0, colorbar_width, 1])\n",
        "    colorbar_ax.set_axes_locator(ip)\n",
        "    colorbar(heatmap, cax=colorbar_ax)\n",
        "\n",
        "fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0.1, wspace=0.1)\n",
        "\n",
        "fig.savefig('figures/sensitivity_measure_high_gamma_Bos2016.eps')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plotting: sensitivity measure corresponding to low frequencies (Fig. 7)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Single-column figure, 85 mm (3.34646 in) wide.\n",
        "fig = plt.figure(figsize=(3.34646, 3.34646/2),\n",
        "                 constrained_layout=True)\n",
        "colormap = 'coolwarm'\n",
        "labels = ['23E', '23I', '4E', '4I', '5E', '5I', '6E', '6I']\n",
        "\n",
        "ax = fig.add_subplot(111)\n",
        "\n",
        "ev = eigenvalue_to_plot_low\n",
        "\n",
        "frequency = sensitivity_dict[ev]['critical_frequency']\n",
        "projection_of_sensitivity_measure = sensitivity_dict[ev][\n",
        "    'sensitivity_amp']\n",
        "\n",
        "# Symmetric color limits set by the largest absolute entry.\n",
        "z = np.max(abs(projection_of_sensitivity_measure))\n",
        "rounded_frequency = str(np.round(frequency, 2))\n",
        "\n",
        "plot_title = (r'$\\mathbf{Z}^{\\mathrm{amp}}(' + f'{rounded_frequency}'\n",
        "              + r'\\,\\mathrm{Hz})$')\n",
        "ax.set_title(plot_title)\n",
        "\n",
        "heatmap = ax.imshow(projection_of_sensitivity_measure,\n",
        "                    vmin=-z,\n",
        "                    vmax=z,\n",
        "                    cmap=colormap)\n",
        "\n",
        "cbar = colorbar(heatmap)\n",
        "\n",
        "ax.set_xticks(np.arange(len(labels)))\n",
        "ax.set_yticks(np.arange(len(labels)))\n",
        "ax.set_xticklabels(labels)\n",
        "ax.set_yticklabels(labels)\n",
        "\n",
        "# Shrink the tick labels of this figure only: assigning to plt.rcParams\n",
        "# would leak into every figure created later in the same kernel.\n",
        "ax.tick_params(labelsize='x-small')\n",
        "cbar.ax.tick_params(labelsize='x-small')\n",
        "\n",
        "ax.set_xlabel('sources')\n",
        "ax.set_ylabel('targets')\n",
        "\n",
        "fig.savefig('figures/sensitivity_measure_0Hz_Bos2016.png',\n",
        "            bbox_inches='tight')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}