# bmtk.analyzer.spike_trains: helpers to locate, load, summarize and plot simulation spike-trains.

import os
import numpy as np
import pandas as pd
from functools import partial
from six import string_types

from bmtk.utils import sonata
from bmtk.utils.sonata.config import SonataConfig
from bmtk.utils.reports import SpikeTrains
from bmtk.utils.reports.spike_trains import plotting
from bmtk.simulator.utils import simulation_reports


def _pick_report_spikes_file(report_params):
    """Return the preferred spikes-file path from a "spikes_report" params mapping, or None.

    BMTK can write the same spikes in SONATA, CSV and NWB format. The SONATA path ("spikes_file") is preferred,
    then the CSV path ("spikes_file_csv"), and finally the NWB path ("spikes_file_nwb").
    """
    for key in ('spikes_file', 'spikes_file_csv', 'spikes_file_nwb'):
        path = report_params.get(key, None)
        if path is not None:
            return path
    return None


def _load_spikes(spikes_f):
    """Load a spikes file with SpikeTrains.load, raising a clear ValueError if the file does not exist."""
    if not os.path.exists(spikes_f):
        raise ValueError('Did not find spike-trains file {}. Make sure the simulation has completed.'.format(
            spikes_f))
    return SpikeTrains.load(spikes_f)


def _find_spikes(spikes_file=None, config_file=None, population=None):
    """Locate and load the spike-trains to analyze.

    :param spikes_file: explicit path to a spikes file. If set it takes precedence over config_file.
    :param config_file: path to a SONATA simulation config. Every "spikes_report" found through the
        simulation_reports module contributes one candidate file (SONATA preferred over CSV, CSV over NWB).
    :param population: node population to look for. If given, every candidate file containing it is loaded and
        merged together. If None, there must be exactly one candidate file and it must contain exactly one
        node population.
    :return: tuple (population name, SpikeTrains object).
    :raises ValueError: if no candidate file is found, a candidate file does not exist, or the population cannot
        be resolved unambiguously.
    """
    candidate_spikes = []

    # Get spikes file(s)
    if spikes_file:
        # User has explicitly set the location of the spike files
        candidate_spikes.append(spikes_file)

    elif config_file is not None:
        # Otherwise search the config.json for all possible output spikes_files. We can use the simulation_reports
        # module to find any spikes output file specified in config's "output" or "reports" section.
        config = SonataConfig.from_json(config_file)
        for report in simulation_reports.from_config(config):
            if report.module != 'spikes_report':
                continue
            report_spikes = _pick_report_spikes_file(report.params)
            if report_spikes is not None:
                candidate_spikes.append(report_spikes)

        # TODO: Should we also look in the "inputs" for displaying input spike statistics?

    if not candidate_spikes:
        raise ValueError('Could not find an output spikes-file. Use "spikes_file" parameter option.')

    # Find file(s) that contain spikes for the specified "population" of nodes; merge them if there are several.
    if population is not None:
        spikes_obj = None
        for spikes_f in candidate_spikes:
            st = _load_spikes(spikes_f)
            if population in st.populations:
                if spikes_obj is None:
                    spikes_obj = st
                else:
                    spikes_obj.merge(st)

        if spikes_obj is None:
            raise ValueError('Could not find spikes file with node population "{}".'.format(population))
        return population, spikes_obj

    # No "population" given: only guess when the choice is unambiguous.
    if len(candidate_spikes) > 1:
        raise ValueError('Found more than one spike-trains file')

    spikes_f = candidate_spikes[0]
    spikes_obj = _load_spikes(spikes_f)

    populations = spikes_obj.populations
    if len(populations) > 1:
        raise ValueError('Spikes file {} contains more than one node population.'.format(spikes_f))
    if len(populations) == 0:
        raise ValueError('Spikes file {} does not contain any node populations.'.format(spikes_f))
    return populations[0], spikes_obj


def _find_nodes(population, config=None, nodes_file=None, node_types_file=None):
    """Return the node population object named ``population``.

    An explicit ``nodes_file`` (with optional ``node_types_file``) takes precedence. Otherwise each entry of
    ``config.nodes`` is opened in order and the first one that contains ``population`` is used.

    :raises ValueError: if the population cannot be located.
    """
    if nodes_file is None:
        # Fall back on the network files listed in the simulation config (if any).
        node_sources = [] if config is None else config.nodes
        for grp in node_sources:
            net = sonata.File(data_files=grp['nodes_file'], data_type_files=grp['node_types_file'])
            if population in net.nodes.population_names:
                return net.nodes[population]
        raise ValueError('Could not find nodes file with node population "{}".'.format(population))

    net = sonata.File(data_files=nodes_file, data_type_files=node_types_file)
    if population not in net.nodes.population_names:
        raise ValueError('node population "{}" not found in {}'.format(population, nodes_file))
    return net.nodes[population]


def _as_exclude_list(group_excludes):
    """Normalize the ``group_excludes`` option into a list of group values to skip."""
    if group_excludes is None:
        return []
    if np.isscalar(group_excludes):
        # A single value such as 'L1' (strings count as scalars) or an integer layer id; wrap it so that
        # membership tests work instead of failing on a non-iterable.
        return [group_excludes]
    return list(group_excludes)


def _grouped_nodes_df(nodes, group_by):
    """Concatenate the 'node_id' and ``group_by`` columns of every node group that has the ``group_by`` attribute.

    :raises ValueError: if no node group has the ``group_by`` attribute.
    """
    frames = [grp.to_dataframe()[['node_id', group_by]] for grp in nodes.groups if group_by in grp.all_columns]
    if not frames:
        raise ValueError('Could not find any nodes with group_by attribute "{}"'.format(group_by))
    # Concatenate once rather than growing the frame inside the loop.
    return pd.concat(frames, ignore_index=True)


def _node_groups_from_df(grouped_df, group_by, excludes):
    """Build the list of {'node_ids': ndarray, 'label': value} dicts, one per ``group_by`` value not excluded."""
    return [
        {'node_ids': np.array(grp['node_id']), 'label': grp_key}
        for grp_key, grp in grouped_df.groupby(group_by)
        if grp_key not in excludes
    ]


def _plot_helper(plot_fnc, config_file=None, population=None, times=None, title=None, show=True, save_as=None,
                 group_by=None, group_excludes=None,
                 spikes_file=None, nodes_file=None, node_types_file=None):
    """Shared driver for the plot_* functions.

    Locates the spikes (see _find_spikes), fills in a default title and, when a config is given and ``times`` is
    None, the config's (tstart, tstop). If ``group_by`` is set, node groups are built from that node attribute
    (skipping values in ``group_excludes``). Finally calls ``plot_fnc`` with the keyword arguments spike_trains,
    node_groups, population, times, title, show and save_as, and returns its result.
    """
    sonata_config = SonataConfig.from_json(config_file) if config_file else None
    pop, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population)

    # Create the title
    if title is None:
        title = "Nodes in network '{}'".format(pop)

    # Get start and stop times from config if needed
    if sonata_config and times is None:
        times = (sonata_config.tstart, sonata_config.tstop)

    # Create node-groups
    node_groups = None
    if group_by is not None:
        nodes = _find_nodes(population=pop, config=sonata_config, nodes_file=nodes_file,
                            node_types_file=node_types_file)
        grouped_df = _grouped_nodes_df(nodes, group_by)
        node_groups = _node_groups_from_df(grouped_df, group_by, _as_exclude_list(group_excludes))

    return plot_fnc(
        spike_trains=spike_trains, node_groups=node_groups, population=pop, times=times, title=title, show=show,
        save_as=save_as
    )


def plot_raster(config_file=None, population=None, with_histogram=True, times=None, title=None, show=True,
                save_as=None, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None,
                node_types_file=None, plt_style=None):
    """Create a raster plot (plus optional histogram) from the results of the simulation.

    Will use the SONATA simulation config's "output"/"reports" sections to locate where the spike-trains file was
    created and display them::

        plot_raster(config_file='config.json')

    If the path to the spikes file is different from (or missing in) the SONATA config then use the
    "spikes_file" option instead::

        plot_raster(spikes_file='/my/path/to/spikes.h5')

    You may also group together different subsets of nodes using specific attributes of the network with the
    "group_by" option, and use the "group_excludes" option to exclude specific subsets. For example, to color and
    label different subsets of nodes based on their cortical "layer", but exclude plotting the L1 nodes::

        plot_raster(config_file='config.json', group_by='layer', group_excludes='L1')

    :param config_file: path to SONATA simulation configuration.
    :param population: name of the node population whose spikes will be plotted. If None, the spikes file must
        contain exactly one node population, which is then used automatically.
    :param with_histogram: If True a histogram will be shown as a small subplot below the scatter plot.
        Default True.
    :param times: (float, float), start and stop times of simulation. If None and config_file is given, the
        config's tstart/tstop values are used.
    :param title: str, adds a title to the plot. If None (default) a title is generated from the population name.
    :param show: bool to display or not display plot. default True.
    :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will
        not save plot.
    :param group_by: Attribute of the "nodes" file used to group subsets of nodes.
    :param group_excludes: str, list of str, or None. When using "group_by", allows users to exclude certain
        groupings based on the attribute value.
    :param spikes_file: Path to a spikes file. If given it takes precedence over the spikes output listed in
        config_file (config_file, if also given, is still used for the times and the node files).
    :param nodes_file: path to nodes hdf5 file containing "population". By default this will be resolved using
        the config.
    :param node_types_file: path to node-types csv file containing "population". By default this will be resolved
        using the config.
    :param plt_style: passed through unchanged to bmtk's plotting.plot_raster (presumably a matplotlib style
        name -- confirm against that function).
    :return: matplotlib figure.Figure object
    """
    plot_fnc = partial(plotting.plot_raster, with_histogram=with_histogram, plt_style=plt_style)
    return _plot_helper(
        plot_fnc,
        config_file=config_file, population=population, times=times, title=title, show=show, save_as=save_as,
        group_by=group_by, group_excludes=group_excludes,
        spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file
    )
def plot_rates(config_file=None, population=None, smoothing=False, smoothing_params=None, times=None, title=None,
               show=True, save_as=None, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None,
               node_types_file=None, plt_style=None):
    """Calculate and plot the rates of each node recorded during the simulation - averaged across the entirety
    of the simulation.

    Will use the SONATA simulation config's "output"/"reports" sections to locate where the spike-trains file was
    created and display them::

        plot_rates(config_file='config.json')

    If the path to the spikes file is different from (or missing in) the SONATA config then use the
    "spikes_file" option instead::

        plot_rates(spikes_file='/my/path/to/spikes.h5')

    You may also group together different subsets of nodes using specific attributes of the network with the
    "group_by" option, and use the "group_excludes" option to exclude specific subsets. For example, to color and
    label different subsets of nodes based on their cortical "layer", but exclude plotting the L1 nodes::

        plot_rates(config_file='config.json', group_by='layer', group_excludes='L1')

    :param config_file: path to SONATA simulation configuration.
    :param population: name of the node population whose spikes will be plotted. If None, the spikes file must
        contain exactly one node population, which is then used automatically.
    :param smoothing: Bool or function. Used to smooth the data. By default (False) no smoothing will be done.
        If True a moving-average smoothing function will be used. Alternatively pass a smoothing function.
    :param smoothing_params: dict, parameters used when passing a smoothing function.
    :param times: (float, float), start and stop times of simulation. If None and config_file is given, the
        config's tstart/tstop values are used.
    :param title: str, adds a title to the plot. If None (default) a title is generated from the population name.
    :param show: bool to display or not display plot. default True.
    :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will
        not save plot.
    :param group_by: Attribute of the "nodes" file used to group subsets of nodes.
    :param group_excludes: str, list of str, or None. When using "group_by", allows users to exclude certain
        groupings based on the attribute value.
    :param spikes_file: Path to a spikes file. If given it takes precedence over the spikes output listed in
        config_file (config_file, if also given, is still used for the times and the node files).
    :param nodes_file: Path to nodes hdf5 file containing "population". By default this will be resolved using
        the config.
    :param node_types_file: Path to node-types csv file containing "population". By default this will be resolved
        using the config.
    :param plt_style: passed through unchanged to bmtk's plotting.plot_rates (presumably a matplotlib style
        name -- confirm against that function).
    :return: matplotlib figure.Figure object
    """
    plot_fnc = partial(plotting.plot_rates, smoothing=smoothing, smoothing_params=smoothing_params,
                       plt_style=plt_style)
    return _plot_helper(
        plot_fnc,
        config_file=config_file, population=population, times=times, title=title, show=show, save_as=save_as,
        group_by=group_by, group_excludes=group_excludes,
        spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file
    )
def plot_rates_boxplot(config_file=None, population=None, times=None, title=None, show=True, save_as=None,
                       group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None,
                       plt_style=None):
    """Creates a box plot of the firing rates taken from nodes recorded during the simulation.

    Will use the SONATA simulation config's "output"/"reports" sections to locate where the spike-trains file was
    created and display them::

        plot_rates_boxplot(config_file='config.json')

    If the path to the spikes file is different from (or missing in) the SONATA config then use the
    "spikes_file" option instead::

        plot_rates_boxplot(spikes_file='/my/path/to/spikes.h5')

    You may also group together different subsets of nodes using specific attributes of the network with the
    "group_by" option, and use the "group_excludes" option to exclude specific subsets. For example, to color and
    label different subsets of nodes based on their cortical "layer", but exclude plotting the L1 nodes::

        plot_rates_boxplot(config_file='config.json', group_by='layer', group_excludes='L1')

    :param config_file: path to SONATA simulation configuration.
    :param population: name of the node population whose spikes will be plotted. If None, the spikes file must
        contain exactly one node population, which is then used automatically.
    :param times: (float, float), start and stop times of simulation. If None and config_file is given, the
        config's tstart/tstop values are used.
    :param title: str, adds a title to the plot. If None (default) a title is generated from the population name.
    :param show: bool to display or not display plot. default True.
    :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will
        not save plot.
    :param group_by: Attribute of the "nodes" file used to group subsets of nodes.
    :param group_excludes: str, list of str, or None. When using "group_by", allows users to exclude certain
        groupings based on the attribute value.
    :param spikes_file: Path to a spikes file. If given it takes precedence over the spikes output listed in
        config_file (config_file, if also given, is still used for the times and the node files).
    :param nodes_file: Path to nodes hdf5 file containing "population". By default this will be resolved using
        the config.
    :param node_types_file: Path to node-types csv file containing "population". By default this will be resolved
        using the config.
    :param plt_style: passed through unchanged to bmtk's plotting.plot_rates_boxplot (presumably a matplotlib
        style name -- confirm against that function).
    :return: matplotlib figure.Figure object
    """
    plot_fnc = partial(plotting.plot_rates_boxplot, plt_style=plt_style)
    return _plot_helper(
        plot_fnc,
        config_file=config_file, population=population, times=times, title=title, show=show, save_as=save_as,
        group_by=group_by, group_excludes=group_excludes,
        spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file
    )
def spike_statistics(spikes_file, simulation=None, population=None, simulation_time=None, group_by=None,
                     network=None, config_file=None, **filterparams):
    """Get spike statistics (firing_rate, counts, inter-spike interval) of the nodes.

    Without ``simulation``, returns one row per (population, node_id) that has at least one spike, with columns
    "count" (number of spikes) and "isi" (mean inter-spike interval in the units of the spike timestamps; 0.0 for
    nodes with a single spike).

    With ``simulation``, the per-node values are left-joined onto
    ``simulation.net.node_properties(**filterparams)``. Nodes without spikes get 0.0, and a "firing_rate" column is
    added (count divided by ``simulation.simulation_time(units='s')``). The mean and std of "firing_rate",
    "count" and "isi" are then returned for each ``group_by`` value.

    :param spikes_file: Path to spikes file. If given it takes precedence over the spikes output listed in
        config_file.
    :param simulation: optional simulation object providing ``net.node_properties()`` and ``simulation_time()``.
    :param population: node population to load; if None the spikes file must contain exactly one population.
    :param simulation_time: currently unused.
    :param group_by: node-property column(s) to aggregate by; required when ``simulation`` is given.
    :param network: currently unused.
    :param config_file: path to SONATA config used to locate the spikes file when spikes_file is not given.
    :param filterparams: keyword filters forwarded to ``simulation.net.node_properties()``.
    :return: pandas dataframe
    :raises TypeError: if ``simulation`` is given without ``group_by``.
    """
    # TODO: Should be implemented in bmtk.utils.spike_trains.stats.py
    _, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population)

    def calc_stats(node_spikes):
        times = np.sort(node_spikes['timestamps'])
        diffs = np.diff(times)
        isi = np.mean(diffs) if diffs.size > 0 else 0.0
        return pd.Series({'count': len(times), 'isi': isi}, index=['count', 'isi'])

    spike_counts_df = spike_trains.to_dataframe().groupby(['population', 'node_ids']).apply(calc_stats)
    spike_counts_df.index.names = ['population', 'node_id']

    if simulation is None:
        return spike_counts_df

    if group_by is None:
        # pandas would fail later with a vaguer TypeError from groupby(None); fail early with a clear message.
        raise TypeError('spike_statistics: "group_by" is required when "simulation" is given')

    nodes_df = simulation.net.node_properties(**filterparams)
    sim_time_s = simulation.simulation_time(units='s')
    spike_counts_df['firing_rate'] = spike_counts_df['count'] / sim_time_s
    # NOTE(review): relies on nodes_df being indexed compatibly with (population, node_id) -- confirm.
    vals_df = pd.merge(nodes_df, spike_counts_df, left_index=True, right_index=True, how='left')
    vals_df = vals_df.fillna({'count': 0.0, 'firing_rate': 0.0, 'isi': 0.0})
    # String aggregators give the same results/labels as np.mean/np.std did (pandas mapped those to its own
    # mean/std) without the deprecation of passing numpy callables to GroupBy.agg.
    return vals_df.groupby(group_by)[['firing_rate', 'count', 'isi']].agg(['mean', 'std'])
def to_dataframe(config_file=None, spikes_file=None, population=None):
    """Load the simulation's spike-trains into a pandas DataFrame.

    The spikes file is located the same way as for the plotting functions. An explicit ``spikes_file`` takes
    precedence; otherwise the spikes output listed in ``config_file`` is used.

    :param config_file: path to SONATA simulation config; used to locate the spikes file when spikes_file is None.
    :param spikes_file: path to a spikes file.
    :param population: node population that must be present in the spikes file(s). If None the spikes file must
        contain exactly one node population.
    :return: the DataFrame produced by ``SpikeTrains.to_dataframe()``. NOTE(review): when ``population`` is given,
        rows from other populations in the same file(s) do not appear to be filtered out -- confirm.
    """
    _, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population)
    return spike_trains.to_dataframe()