{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Transfer Functions (Schuecker 2015)\n\nHere we calculate the transfer functions as in :cite:t:`schuecker2015`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from collections import defaultdict\n\nimport numpy as np\nimport matplotlib as mpl\nimport matplotlib.gridspec as gridspec\nimport matplotlib.pyplot as plt\nimport matplotlib.ticker\n\nimport nnmt\n\n# Frontiers figure style. A style name ending in '.mplstyle' is treated as a\n# file path, so the style file is loaded from the working directory.\nplt.style.use('frontiers.mplstyle')\n\n# Local overrides applied on top of the style file.\nstyle_overrides = {\n    'legend.fontsize': 'medium',  # old: 5.0 was too small\n    'axes.titlepad': 0.0,\n}\nmpl.rcParams.update(style_overrides)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The parameters used to calculate the transfer functions\nin :cite:t:`schuecker2015` are collected in a YAML file, which is loaded here.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Parameters used in Schuecker 2015, shared with the integration tests.\nPARAMS_FILE = ('../../tests/fixtures/integration/config/'\n               'Schuecker2015_parameters.yaml')\nparams = nnmt.input_output.load_val_unit_dict_from_yaml(PARAMS_FILE)\n\n# Two unit-stripped copies of the parameters:\n# - network_params keeps the units of the YAML file (the plot legend\n#   assumes mV),\n# - si_network_params is converted to SI units before stripping.\nnetwork_params = params.copy()\nnnmt.utils._strip_units(network_params)\n\nsi_network_params = params.copy()\nnnmt.utils._convert_to_si_and_strip_units(si_network_params)\n\n# Log-spaced frequency grid with the zero frequency prepended.\nlog_frequencies = np.logspace(si_network_params['f_start_exponent'],\n                              si_network_params['f_end_exponent'],\n                              si_network_params['n_freqs'])\nfrequencies = np.concatenate(([0.0], log_frequencies))\nomegas = 2 * np.pi * frequencies\n\n# Suffixes of the two parameter sets in the YAML file\n# (mean_input_1 / sigma_1 and mean_input_2 / sigma_2).\nindices = [1, 2]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Calculate results for different input means and standard deviations.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Firing rates and transfer functions for both parameter sets. Judging by\n# the [:, j] indexing used when the results are regrouped below, each\n# transfer function has one column per mean input of its set.\nabsolute_values = []\nphases = []\nnu0_fbs = []\nfor index in indices:\n    mean_input_si = si_network_params[f'mean_input_{index}']\n    sigma_si = si_network_params[f'sigma_{index}']\n\n    # Stationary firing rates for filtered synapses (via shift)\n    nu0_fb = nnmt.lif.exp._firing_rate_shift(\n        mean_input_si,\n        sigma_si,\n        si_network_params['V_reset'],\n        si_network_params['theta'],\n        si_network_params['tau_m'],\n        si_network_params['tau_r'],\n        si_network_params['tau_s'])\n\n    transfer_function = nnmt.lif.exp._transfer_function(\n        mean_input_si,\n        sigma_si,\n        si_network_params['tau_m'],\n        si_network_params['tau_s'],\n        si_network_params['tau_r'],\n        si_network_params['theta'],\n        si_network_params['V_reset'],\n        omegas,\n        method='shift',\n        synaptic_filter=False)\n\n    # The result is returned in SI units (1/(s*V)); the original figure in\n    # the paper uses 1/(s*mV), i.e. a factor 1/1000. A new array is created\n    # rather than dividing the returned one in place.\n    transfer_function = transfer_function / 1000\n\n    # Quantities plotted in Schuecker 2015: amplitude and phase in degrees.\n    absolute_values.append(np.abs(transfer_function))\n    phases.append(np.angle(transfer_function, deg=True))\n    nu0_fbs.append(nu0_fb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Prepare the data for plotting by collecting the results into a nested dictionary keyed by the input standard deviation (sigma) and the mean input (mu).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pre_results = dict(\n    absolute_values=absolute_values,\n    phases=phases,\n    nu0_fbs=nu0_fbs)\n\n# Nested lookup test_results['sigma'][sigma]['mu'][mu] -> results for one\n# (mu, sigma) pair, with mu and sigma in the units of network_params.\n# Plain dicts make a mistyped key raise KeyError instead of silently\n# returning a default value.\ntest_results = {'sigma': {}}\nfor i, index in enumerate(indices):\n    sigma = network_params[f'sigma_{index}']\n    # setdefault merges parameter sets that share the same sigma instead of\n    # letting the later set overwrite the earlier one.\n    mu_results = test_results['sigma'].setdefault(sigma, {}).setdefault(\n        'mu', {})\n    for j, mu in enumerate(network_params[f'mean_input_{index}']):\n        mu_results[mu] = {\n            'absolute_value': pre_results['absolute_values'][i][:, j],\n            'phase': pre_results['phases'][i][:, j],\n            'nu0_fb': pre_results['nu0_fbs'][i][j]}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plotting\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Figure size: 180 mm x 75 mm, converted to inches.\nwidth = 180. / 25.4\nheight = 75. / 25.4\n\ncolors = ['black', 'grey']  # indexed by the position of mu within a sigma\nlw = 1\n\nfig = plt.figure(figsize=(width, height))\ngrid_specification = gridspec.GridSpec(1, 2, figure=fig)\naxA = fig.add_subplot(grid_specification[0])\naxB = fig.add_subplot(grid_specification[1])\n\nfor sigma, sigma_results in test_results['sigma'].items():\n    # solid lines for sigma = 4.0, dashed lines otherwise\n    ls = '-' if sigma == 4.0 else '--'\n    for i, (mu, result) in enumerate(sigma_results['mu'].items()):\n        # The first entry is the zero frequency, which has no position on a\n        # log axis; it is excluded from both panels.\n        axA.semilogx(frequencies[1:],\n                     result['absolute_value'][1:],\n                     color=colors[i],\n                     linestyle=ls,\n                     linewidth=lw)\n        axB.semilogx(frequencies[1:],\n                     result['phase'][1:],\n                     color=colors[i],\n                     linestyle=ls,\n                     linewidth=lw,\n                     label=f'({np.round(mu, 1)}, {sigma})')\n\naxA.set_xlabel(r'frequency $\\omega/2\\pi\\quad(1/\\mathrm{s})$')\naxA.set_ylabel(r'amplitude $|N_{\\mathrm{cn}}\\left(\\omega\\right)|\\quad(\\mathrm{s}\\cdot\\mathrm{mV})^{-1}$',\n               labelpad=0)\n\naxB.set_xlabel(r'frequency $\\omega/2\\pi\\quad(1/\\mathrm{s})$')\naxB.set_ylabel(r'phase $\\angle N_{\\mathrm{cn}}\\left(\\omega\\right)\\quad(^{\\circ})$',\n               labelpad=2)\n\naxA.set_xticks([1e-1, 1e0, 1e1, 1e2])\naxA.set_yticks([0, 6, 12])\n\naxB.set_xticks([1e-1, 1e0, 1e1, 1e2])\naxB.set_yticks([-60, -30, 0])\n\nlabel_prms = dict(x=-0.05, y=1.1, fontsize=10, fontweight='bold',\n                  va='top', ha='right')\naxA.text(s='A', transform=axA.transAxes, **label_prms)\naxB.text(s='B', transform=axB.transAxes, **label_prms)\n\nfor ax in (axA, axB):\n    # One locator instance per axis: a matplotlib locator keeps a reference\n    # to the axis it is attached to and must not be shared between axes.\n    ax.xaxis.set_minor_locator(matplotlib.ticker.LogLocator(\n        base=10.0,\n        subs=np.arange(1.0, 10.0) * 0.1,\n        numticks=10))\n    ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())\n\naxB.legend(title=r'$(\\mu, \\sigma)$ in mV', title_fontsize=None,\n           handlelength=2, labelspacing=0.0,\n           loc='lower left')\n\nfig.savefig('transfer_functions_Schuecker2015.eps')"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.13"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}