{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline below each cell.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Spatial patterns\n\nThis example demonstrates the methods used in Figure 6 of :cite:t:`senk2020`.\nA figure illustrating the network structure of the model used here is set up in\n:doc:`network_structure`.\nThe same model is used in the example :doc:`mapping_lif_rate`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nnmt.spatial as spatial\nimport nnmt.linear_stability as linstab\nimport nnmt.lif.exp as mft  # main set of meanfield tools\nfrom nnmt.models.basic import Basic as BasicNetwork\nimport numpy as np\nimport scipy.optimize as sopt\nimport scipy.integrate as sint\nimport scipy.misc as smisc\nimport matplotlib as mpl\nimport matplotlib.gridspec as gridspec\nimport matplotlib.pyplot as plt\nplt.style.use('frontiers.mplstyle')\nmpl.rcParams.update({'legend.fontsize': 'medium',  # old: 5.0 was too small\n                     'axes.titlepad': 0.0,\n                     'figure.constrained_layout.use': False})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, we define parameters for data generation and plotting.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "params = {\n    # mean and standard deviation of input used for stability analysis (in V)\n    'mean_std_inputs_stability': np.array([10., 10.]) * 1e-3,\n\n    # labels and corresponding scaling parameters for plotted quantities\n    'quantities': {\n        'k_wavenumbers': {\n            'label': 'wavenumber $k$ (1/mm)',\n            'scale': 1e-3},\n        'eigenvalues': {\n            'label': r'eigenvalue $\\lambda$'},\n        'eigenvalues_real': {\n            'label': r'Re[$\\lambda$] (1000/s)',\n            'scale': 1e-3},\n        'eigenvalues_imag': {\n            'label': r'Im[$\\lambda$] (1000/s)',\n            'scale': 1e-3}},\n\n    # figure width in inch\n    'figwidth_1col': 85. / 25.4,\n\n    # colors for branches of Lambert W function\n    'colors_br': {0: '#994455',  # dark red\n                  -1: '#EE99AA',  # light red\n                  1: '#004488',  # dark blue\n                  -2: '#6699CC',  # light blue\n                  2: '#997700',  # dark yellow\n                  -3: '#EECC66'}}  # light yellow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Helper functions\nHere, we define a number of helper functions which are currently considered\ntoo specific for a global integration into NNMT.\nThese functions are concerned with solving the characteristic equation of the\nspiking and rate models and interpolating between them.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _solve_chareq_numerically_alpha(\n        lambda_rate, k, alpha, network, tau_rate, W_rate):\n    \"\"\"\n    Solves the full characteristic equation numerically.\n\n    Parameters\n    ----------\n    lambda_rate: complex float\n        Eigenvalue of rate model in 1/s.\n    k : float\n        Wave number in 1/m.\n    alpha : float\n        Interpolation parameter.\n    network : nnmt.models.Network or child class instance\n        Network instance.\n    tau_rate : np.array\n        Time constants of rate model in s.\n    W_rate : np.array\n        Weight matrix of rate model.\n\n    Returns\n    -------\n    lamb :\n        Numerically optimized eigenvalues as a function of the interpolation\n        parameter.\n    \"\"\"\n    def fsolve_complex(l_re_im):\n        lam = complex(l_re_im[0], l_re_im[1])\n\n        spatial_profile = spatial._ft_spatial_profile_boxcar(\n            k=k, width=network.network_params['width'])\n\n        eff_conn_spiking = _linalg_max_eigenvalue(\n            _effective_connectivity_spiking(lam, network) * spatial_profile)\n        eff_conn_rate = _linalg_max_eigenvalue(\n            _effective_connectivity_rate(\n                lam, tau_rate, W_rate) * spatial_profile)\n\n        eff_conn_alpha = (alpha * eff_conn_spiking\n                          + (1. - alpha) * eff_conn_rate)\n\n        roots = eff_conn_alpha * np.exp(-lam * d) - 1.\n\n        return [roots.real, roots.imag]\n\n    delay_mean = network.network_params['D_mean']\n    assert np.isscalar(delay_mean) or len(np.unique(delay_mean) == 1)\n    d = np.unique(delay_mean)[0]\n\n    lambda_guess_list = [np.real(lambda_rate), np.imag(lambda_rate)]\n    l_opt = sopt.fsolve(fsolve_complex, lambda_guess_list)\n    lamb = complex(l_opt[0], l_opt[1])\n    return lamb\n\n\ndef _solve_lambda_of_alpha_integral(\n        lambda_rate, k, alphas, network, tau_rate, W_rate):\n    \"\"\"\n    Integrates the derivative of the eigenvalue wrt. 
interpolation parameters.\n\n    Parameters\n    ----------\n    lambda_rate : complex float\n        Eigenvalue of rate model in 1/s.\n    k : float\n        Wave number in 1/m.\n    alphas : np.array of floats\n        All interpolation parameters.\n    network : nnmt.models.Network or child class instance\n        Network instance.\n    tau_rate : np.array\n        Time constants of rate model in s.\n    W_rate : np.array\n        Weight matrix of rate model.\n\n    Returns\n    -------\n    lambdas_of_alpha :\n        Numerically integrated eigenvalues as a function of interpolation\n        parameters.\n    \"\"\"\n    assert alphas[0] == 0, 'First alpha must be 0!'\n    lambda0_list = [lambda_rate.real, lambda_rate.imag]\n\n    def derivative(lambda_list, alpha):\n        lam = complex(lambda_list[0], lambda_list[1])\n        deriv = _d_lambda_d_alpha(lam, alpha, k, network, tau_rate, W_rate)\n        return [deriv.real, deriv.imag]\n\n    llist = sint.odeint(func=derivative, y0=lambda0_list, t=alphas)\n\n    lambdas_of_alpha = [complex(lam[0], lam[1]) for lam in llist]\n    return lambdas_of_alpha\n\n\ndef _effective_connectivity_spiking(lam, network):\n    \"\"\"\n    Computes the effective connectivity of the spiking model.\n\n    Parameters\n    ----------\n    lam : complex float\n        Eigenvalue in 1/s.\n    network : nnmt.models.Network or child class instance\n        Network instance.\n\n    Returns\n    -------\n    eff_conn :\n        Effective connectivity.\n    \"\"\"\n    omega = complex(0, -lam)\n    transfunc = mft.transfer_function(\n        network=network, freqs=np.array([omega / (2. 
* np.pi)]))\n\n    D = np.array([1])  # ignore delay distribution here\n    eff_conn = mft._effective_connectivity(\n        transfer_function=transfunc, D=D, J=network.network_params['J'],\n        K=network.network_params['K'], tau_m=network.network_params['tau_m'])\n    return eff_conn\n\n\ndef _effective_connectivity_rate(lam, tau_rate, W_rate):\n    \"\"\"\n    Computes the effective connectivity of the rate model.\n\n    Parameters\n    ----------\n    lam : complex float\n        Eigenvalue in 1/s.\n    network : nnmt.models.Network or child class instance\n        Network instance.\n\n    Returns\n    -------\n    eff_conn :\n        Effective connectivity.\n    \"\"\"\n    omega = complex(0, -lam)\n    eff_conn = W_rate / (1. + 1j * omega * tau_rate)\n    return eff_conn\n\n\ndef _d_lambda_d_alpha(lam, alpha, k, network, tau_rate, W_rate):\n    \"\"\"\n    Computes the derivative of the eigenvalue wrt. the interpolation parameter.\n\n    Parameters\n    ----------\n    lam : complex float\n        Eigenvalue of rate model in 1/s.\n    alpha : float\n        Interpolation parameter.\n    k : float\n        Wave number in 1/m.\n    network : nnmt.models.Network or child class instance\n        Network instance.\n    tau_rate : np.array\n        Time constants of rate model in s.\n    W_rate : np.array\n        Weight matrix of rate model.\n\n    Returns\n    -------\n    deriv :\n        Derivative.\n    \"\"\"\n    spatial_profile = spatial._ft_spatial_profile_boxcar(\n        k=k, width=network.network_params['width'])\n\n    eff_conn_spiking = _linalg_max_eigenvalue(\n        _effective_connectivity_spiking(lam, network) * spatial_profile)\n    eff_conn_rate = _linalg_max_eigenvalue(\n        _effective_connectivity_rate(lam, tau_rate, W_rate) * spatial_profile)\n\n    eff_conn_alpha = alpha * eff_conn_spiking + (1. 
- alpha) * eff_conn_rate\n\n    d_eff_conn_spiking_d_lambda = _linalg_max_eigenvalue(\n        _d_eff_conn_spiking_d_lambda(lam, network) * spatial_profile)\n\n    d_eff_conn_rate_d_lambda = _linalg_max_eigenvalue(\n        _d_eff_conn_rate_d_lambda(lam, tau_rate, W_rate) * spatial_profile)\n\n    d_eff_conn_alpha_d_lambda = (alpha * d_eff_conn_spiking_d_lambda\n                                 + (1. - alpha) * d_eff_conn_rate_d_lambda)\n\n    delay_mean = network.network_params['D_mean']\n    assert np.isscalar(delay_mean) or len(np.unique(delay_mean) == 1)\n    d = np.unique(delay_mean)[0]\n\n    nominator = eff_conn_spiking - eff_conn_rate\n    denominator = d_eff_conn_alpha_d_lambda - d * eff_conn_alpha\n\n    deriv = - nominator / denominator\n    return deriv\n\n\ndef _d_eff_conn_spiking_d_lambda(lam, network):\n    \"\"\"\n    Computes the derivative of the effective connectivity of the spiking model.\n\n    Parameters\n    ----------\n    l : complex float\n        Eigenvalue of rate model in 1/s.\n    network : nnmt.models.Network or child class instance\n        Network instance.\n\n    Returns\n    -------\n    deriv :\n        Derivative.\n    \"\"\"\n    def f(x):\n        return _effective_connectivity_spiking(x, network)\n    deriv = smisc.derivative(func=f, x0=lam, dx=1e-10)\n    return deriv\n\n\ndef _d_eff_conn_rate_d_lambda(lam, tau_rate, W_rate):\n    \"\"\"\n    Computes the derivative of the effective connectivity of the rate model.\n\n    Parameters\n    ----------\n    l : complex float\n        Eigenvalue of rate model in 1/s.\n    tau_rate : np.array\n        Time constants of rate model in s.\n    W_rate : np.array\n        Weight matrix of rate model.\n\n    Returns\n    -------\n    deriv:\n        Derivative.\n    \"\"\"\n    lp = 1. / (1. + lam * tau_rate)\n    deriv = -1. 
* W_rate * lp**2 * tau_rate\n    return deriv\n\n\ndef _linalg_max_eigenvalue(matrix):\n    \"\"\"\n    Computes the eigenvalue with the largest absolute value of a given matrix.\n\n    Note that this a general matrix operation and the eigenvalue should not be\n    confused with lambda, the temporal eigenvalue of a characteristic\n    equation.\n\n    Parameters\n    ----------\n    matrix : np.array\n        Matrix to calculate eigenvalues from.\n\n    Returns\n    -------\n    max_eigval : float\n        Maximum eigenvalue.\n    \"\"\"\n    eigvals = np.linalg.eigvals(matrix)\n    max_eigval = eigvals[np.argmax(np.abs(eigvals))]\n    return max_eigval"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We also define a helper function for adding labels to figure panels.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def _add_label(ax, label, xshift=0., yshift=0., scale_fs=1.):\n    \"\"\"\n    Writes a bold panel label, e.g. ``(A)``, at the upper-left corner of an\n    axes.\n\n    Parameters\n    ----------\n    ax : matplotlib.axes.Axes object\n        Axes to label.\n    label : str\n        Panel letter; it is enclosed in parentheses.\n    xshift : float\n        Horizontal offset of the label in axes coordinates.\n    yshift : float\n        Vertical offset of the label in axes coordinates.\n    scale_fs : float\n        Factor applied to the default font size.\n    \"\"\"\n    # (0, 1) is the upper-left corner in axes coordinates\n    x_anchor, y_anchor = 0., 1.\n    text = '(' + label + ')'\n    font_size = mpl.rcParams['font.size'] * scale_fs\n    ax.text(x_anchor + xshift, y_anchor + yshift, text,\n            transform=ax.transAxes, ha='left', va='bottom',\n            fontweight='bold', fontsize=font_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generating data\nWe instantiate a ``Basic`` model with a set of predefined network and\nanalysis parameters.\nHere, the relative inhibition is g = 5, in contrast to the original Figure 5 of\n:cite:t:`senk2020`, which uses g = 6.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print('Instantiating network model.')\n\n# both parameter files are given as paths relative to the working directory\nnetwork = BasicNetwork(\n    analysis_params='Senk2020_analysis_params.yaml',\n    network_params='Senk2020_network_params.yaml',\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The working point is set by a given mean and standard deviation of the\ninput.\nThe excitatory and inhibitory external firing rates required to preserve this\nworking point are computed, and the network parameters are adjusted accordingly.\nThen, we calculate the LIF transfer function and fit it with that of a\nlow-pass filter using a least-squares fit.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Fix the working point: compute the external rates that yield the target\n# mean (mu) and standard deviation (sigma) of the input and store them in\n# the network parameters.\nnu_ext = mft.external_rates_for_fixed_input(\n    network,\n    sigma_set=params['mean_std_inputs_stability'][1],\n    mu_set=params['mean_std_inputs_stability'][0])\nnetwork.change_parameters(\n    changed_network_params=dict(nu_ext=nu_ext), overwrite=True)\n\n# working point, LIF transfer function and its low-pass fit\nmft.working_point(network)\nmft.transfer_function(network)\nmft.fit_transfer_function(network)\n\n# time constants and weights of the fitted rate model\ntau_rate, W_rate = (network.results[mft._prefix + name]\n                    for name in ('tau_rate', 'W_rate'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The first loop for generating data iterates over branches of the Lambert W\nfunction and wave numbers of the spatial connectivity profile.\nWith the rate model and the parameters obtained by fitting the LIF transfer\nfunction, we can solve the characteristic equation for the eigenvalues\nanalytically.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print('Solving characteristic equation for rate model analytically.')\n\nbranches = sorted(params['colors_br'].keys())\nk_wavenumbers = network.analysis_params['k_wavenumbers']\neigenvalues = np.zeros((len(branches), len(k_wavenumbers)), dtype=complex)\nfor i, branch_nr in enumerate(branches):\n    for j, k_wavenumber in enumerate(k_wavenumbers):\n        # rate-model weights scaled by the spatial profile at this wavenumber\n        connectivity = W_rate * spatial._ft_spatial_profile_boxcar(\n            k_wavenumber, network.network_params['width'])\n        eigenvalues[i, j] = (\n            linstab._solve_chareq_lambertw_constant_delay(\n                branch_nr=branch_nr, tau=tau_rate,\n                delay=network.network_params['D_mean'],\n                connectivity=connectivity))\n\n# index (branch row, wavenumber column) of eigenvalue with maximum real part\nidx_max = list(\n    np.unravel_index(np.argmax(eigenvalues.real), eigenvalues.shape))\n\n# If the maximum lies on branch -1, swap the rows of branches -1 and 0 so\n# that branch 0 carries the eigenvalue with maximum real part.\nif branches[idx_max[0]] == -1:\n    idx_n1 = idx_max[0]  # row index of branch -1\n    idx_0 = branches.index(0)  # row index of branch 0\n    # the right-hand side is a copy (advanced indexing), so this is a swap\n    eigenvalues[[idx_n1, idx_0]] = eigenvalues[[idx_0, idx_n1]]\n    idx_max[0] = idx_0\n\neigenval_max = eigenvalues[idx_max[0], idx_max[1]]\nk_eigenval_max = k_wavenumbers[idx_max[1]]\nidx_k_eigenval_max = idx_max[1]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then, we linearly interpolate the eigenvalues from the rate model\n(analytical) to the spiking model (LIF, only numerical) with two methods,\nlooping over branch numbers:\n 1. solving the full characteristic equation numerically and\n 2. integrating the derivative of the eigenvalue with respect to the\n    interpolation parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print('Linear interpolation between rate and spiking models.')\n\n# interpolation parameters from the rate model (0) to the spiking model (1)\nalphas = np.linspace(0, 1, 5)\nlambdas_integral = np.zeros((len(branches), len(alphas)), dtype=complex)\nlambdas_chareq = np.zeros_like(lambdas_integral)\nfor i, branch_nr in enumerate(branches):\n    print(f'    branch number = {branch_nr}')\n\n    # start value: rate-model eigenvalue of this branch at the wavenumber\n    # with the largest real part of the eigenvalue from theory\n    lambda0 = eigenvalues[i, idx_k_eigenval_max]\n\n    # 1. solution by solving the characteristic equation numerically\n    lambdas_chareq[i] = [\n        _solve_chareq_numerically_alpha(\n            lambda_rate=lambda0, k=k_eigenval_max, alpha=alpha,\n            network=network, tau_rate=tau_rate, W_rate=W_rate)\n        for alpha in alphas]\n\n    # 2. solution by integrating d lambda / d alpha over alpha\n    lambdas_integral[i] = _solve_lambda_of_alpha_integral(\n        lambda_rate=lambda0, k=k_eigenval_max, alphas=alphas,\n        network=network, tau_rate=tau_rate, W_rate=W_rate)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "All results are stored in ``stability_results``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# collect all quantities needed for plotting in one dictionary\nstability_results = dict(\n    branches=branches,\n    k_wavenumbers=k_wavenumbers,\n    eigenvalues=eigenvalues,\n    eigenval_max=eigenval_max,\n    k_eigenval_max=k_eigenval_max,\n    idx_k_eigenval_max=idx_k_eigenval_max,\n    alphas=alphas,\n    lambdas_integral=lambdas_integral,\n    lambdas_chareq=lambdas_chareq,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Plotting\nWe generate a figure with two panels.\nThe figure spans one column.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print('Plotting.')\n\n# square figure of one-column width with two panels side by side\nfig = plt.figure(figsize=(params['figwidth_1col'],) * 2)\ngs = gridspec.GridSpec(1, 2, figure=fig)\nfig.subplots_adjust(left=0.15, right=0.93, bottom=0.19, top=0.95,\n                    wspace=1.2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In panel A, we plot the eigenvalues of the rate model vs. wavenumbers.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gs_rt = gridspec.GridSpecFromSubplotSpec(\n    2, 1, subplot_spec=gs[0, 0], hspace=0.1)\n\n# Add the axes to the explicit figure: the inline backend closes figures at\n# the end of each cell, so plt.subplot would silently start a new figure.\nax_real = fig.add_subplot(gs_rt[0])\n_add_label(ax_real, 'A', xshift=-0.6, yshift=0.02)\nax_imag = fig.add_subplot(gs_rt[1])\n\n# real and imaginary parts share one scale factor\nscale_ev = params['quantities']['eigenvalues_real']['scale']\nif params['quantities']['eigenvalues_imag']['scale'] != scale_ev:\n    raise ValueError('Real and imaginary parts of the eigenvalues must be '\n                     'plotted with the same scale.')\n\nbranches = stability_results['branches']\nk_wavenumbers = stability_results['k_wavenumbers']\neigenvalues = stability_results['eigenvalues']\n\nks = k_wavenumbers * params['quantities']['k_wavenumbers']['scale']\nfor i, branch_nr in enumerate(branches):\n    ax_real.plot(ks, np.real(eigenvalues)[i] * scale_ev,\n                 color=params['colors_br'][branch_nr],\n                 label=branch_nr)\n    ax_imag.plot(ks, np.imag(eigenvalues)[i] * scale_ev,\n                 color=params['colors_br'][branch_nr])\n\nax_real.set_ylabel(params['quantities']['eigenvalues_real']['label'],\n                   labelpad=5.5)\nax_imag.set_ylabel(params['quantities']['eigenvalues_imag']['label'],\n                   labelpad=0)\nax_real.set_title(params['quantities']['eigenvalues']['label'])\n# hide the x tick labels of the upper axes without fixing the ticks\nax_real.tick_params(labelbottom=False)\n# add whitespace via new lines to match axes of alpha plot\nax_imag.set_xlabel(params['quantities']['k_wavenumbers']['label'] + '\\n\\n')\n\n# legend with branches in the order given by params['colors_br']\nhandles_old, labels_old = ax_real.get_legend_handles_labels()\nhandle_of_label = dict(zip(labels_old, handles_old))\nlabels = [str(branch_nr) for branch_nr in params['colors_br']\n          if str(branch_nr) in handle_of_label]\nhandles = [handle_of_label[label] for label in labels]\nax_real.legend(handles, labels, title='branch number', ncol=3,\n               columnspacing=0.1, loc='center', bbox_to_anchor=(0.55, 0.2))\n\n# Limit the wavenumber axis to the range before the imaginary part of the\n# principal branch first becomes 0 (minus a manual offset in samples).\nlambdas_imag = np.imag(eigenvalues)\nidx_b0 = list(branches).index(0)\nidx_imag_zero = np.flatnonzero(lambdas_imag[idx_b0] == 0)\noffset = 5  # manual offset\nif idx_imag_zero.size > 0 and idx_imag_zero[0] > offset:\n    klim = ks[idx_imag_zero[0] - offset]\nelse:\n    klim = ks[-1]  # no usable zero: show the full range\n\nfor ax in [ax_real, ax_imag]:\n    ax.axhline(0, linestyle='-', color='k',\n               linewidth=mpl.rcParams['lines.linewidth'] * 0.5)\n    ax.set_xlim(ks[0], klim)\n\n# star for maximum real part\nfor ax, fun in zip([ax_real, ax_imag], [np.real, np.imag]):\n    ax.plot(stability_results['k_eigenval_max']\n            * params['quantities']['k_wavenumbers']['scale'],\n            np.abs(fun(stability_results['eigenval_max']) * scale_ev),\n            marker='*', markerfacecolor='k', markeredgecolor='none',\n            markersize=mpl.rcParams['lines.markersize'] * 2.)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In panel B, we plot the linear interpolation of the eigenvalues.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "gs_int = gridspec.GridSpecFromSubplotSpec(\n    2, 1, subplot_spec=gs[0, 1], hspace=0.1)\n\n# Add the axes to the explicit figure: the inline backend closes figures at\n# the end of each cell, so plt.subplot would silently start a new figure.\nax_real = fig.add_subplot(gs_int[0])\n_add_label(ax_real, 'B', xshift=-0.65, yshift=0.02)\nax_imag = fig.add_subplot(gs_int[1])\n\n# real and imaginary parts share one scale factor\nscale_ev = params['quantities']['eigenvalues_real']['scale']\nif params['quantities']['eigenvalues_imag']['scale'] != scale_ev:\n    raise ValueError('Real and imaginary parts of the eigenvalues must be '\n                     'plotted with the same scale.')\n\nbranches = stability_results['branches']\nalphas = stability_results['alphas']\nlambdas_integral = stability_results['lambdas_integral']\nlambdas_chareq = stability_results['lambdas_chareq']\n\nxlim = [-0.1, 1.1]\nfor ax, fun, lab in zip([ax_real, ax_imag],\n                        [np.real, np.imag],\n                        ['real', 'imag']):\n    # elements shared by all branches are added once per axes\n    ax.set_ylabel(params['quantities']['eigenvalues_' + lab]['label'],\n                  labelpad=0)\n    # lambda = 0\n    ax.plot(xlim, [0, 0], 'k-',\n            linewidth=mpl.rcParams['lines.linewidth'] * 0.5)\n    ax.set_xlim(xlim[0], xlim[1])\n\n    for i, branch_nr in enumerate(branches):\n        # markers: integration of d lambda / d alpha\n        ax.plot(alphas,\n                fun(lambdas_integral[i]) * scale_ev,\n                linestyle='',\n                color=params['colors_br'][branch_nr],\n                marker='o',\n                markersize=mpl.rcParams['lines.markersize'] * 1.5,\n                markeredgecolor='none')\n        # lines: numerical solution of the characteristic equation\n        ax.plot(alphas, fun(lambdas_chareq[i]) * scale_ev,\n                linestyle='-',\n                color=params['colors_br'][branch_nr],\n                markeredgecolor='none')\n\n    # star marker: rate-model eigenvalue with maximum real part (alpha = 0)\n    ax.plot(0,\n            np.abs(fun(stability_results['eigenval_max'])) * scale_ev,\n            marker='*',\n            markerfacecolor='k',\n            markeredgecolor='none',\n            markersize=mpl.rcParams['lines.markersize'] * 2.)\n\nxticks = [0, 0.5, 1]\nax_real.set_title(params['quantities']['eigenvalues']['label'])\nax_real.set_xticks(xticks)\nax_real.set_xticklabels([])\nax_imag.set_xlabel(r'interpolation parameter $\\alpha$')\nax_imag.set_xticks(xticks)\nax_imag.set_xticklabels(['0.0\\nrate\\nmodel', '0.5', '1.0\\nspiking\\nmodel'])\n\n# legend for symbols\nintegral = mpl.lines.Line2D([], [], color='k',\n                            marker='o', linestyle='', label='integral')\nchareq = mpl.lines.Line2D([], [], color='k',\n                          marker=None, linestyle='-', label='char. eq.')\nax_real.legend(\n    handles=[integral, chareq], bbox_to_anchor=(0.6, 0.7), loc='center')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The final figure is saved to file.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Save and display the explicit figure object: the inline backend closes\n# figures at the end of each cell, so plt.savefig here would act on a new,\n# empty figure.\nfig.savefig('spatial_patterns.eps')\nfig"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}