# Source code for bmtk.utils.reports.spike_trains.plotting

# Copyright 2020. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
import six
import matplotlib.pyplot as plt
import types
import copy
from functools import partial

from .spike_trains import SpikeTrains
from .spike_trains_api import SpikeTrainsAPI

from matplotlib.ticker import MaxNLocator


def __get_spike_trains(spike_trains):
    """Make sure a SpikeTrains/SpikeTrainsAPI object is always returned.

    :param spike_trains: path to a (SONATA) spikes file, or an already loaded SpikeTrains/SpikeTrainsAPI object.
    :return: the loaded (for a path) or passed-through (for an object) spike-trains object.
    :raises AttributeError: if spike_trains is neither a string path nor a SpikeTrains/SpikeTrainsAPI object.
    """
    if isinstance(spike_trains, six.string_types):
        # Load spikes from file
        return SpikeTrains.load(spike_trains)

    elif isinstance(spike_trains, (SpikeTrains, SpikeTrainsAPI)):
        return spike_trains

    # Containers (lists, tuples, ...) are not accepted, so the message only lists the supported input types.
    raise AttributeError('Could not parse spike_trains of type {}. Pass in a file-path, SpikeTrains object or '
                         'SpikeTrainsAPI object.'.format(type(spike_trains).__name__))


def __get_population(spike_trains, population):
    """Helper function to figure out which population of nodes to use."""
    pops = spike_trains.populations
    if population is None:
        # If only one population exists in spikes object/file select that one
        if len(pops) > 1:
            raise Exception('SpikeTrains contains more than one population of nodes. Use "population" parameter '
                            'to specify population to display.')

        else:
            return pops[0]

    elif population not in pops:
        raise Exception('Could not find node population "{}" in SpikeTrains, only found {}'.format(population, pops))

    else:
        return population


def __get_node_groups(spike_trains, node_groups, population):
    """Helper function for parsing the 'node_groups' params"""
    if node_groups is None:
        # If none are specified by user make a 'node_group' consisting of all nodes
        selected_nodes = spike_trains.node_ids(population=population)
        return [{'node_ids': selected_nodes, 'c': 'b'}], selected_nodes
    else:
        # Fetch all node_ids which can be used to filter the data.
        node_groups = copy.deepcopy(node_groups)  # Make a copy since later we may be altering the dictionary
        selected_nodes = np.array(node_groups[0]['node_ids'])
        for grp in node_groups[1:]:
            if 'node_ids' not in grp:
                raise AttributeError('Could not find "node_ids" key in node_groups parameter.')
            selected_nodes = np.concatenate((selected_nodes, np.array(grp['node_ids'])))

        return node_groups, selected_nodes


def plot_raster(spike_trains, with_histogram=True, population=None, node_groups=None, times=None, title=None,
                show=True, save_as=None, plt_style=None):
    """Create a raster plot (plus an optional histogram) from a SpikeTrains object or SONATA spike-trains file.

    By default all nodes are displayed. To only display a subset of nodes and/or group together different nodes
    (by node_id) with their own dot colors and labels, use node_groups, which should be a list of dicts::

        plot_raster('/path/to/my/spike.h5',
                    node_groups=[{'node_ids': range(0, 70), 'c': 'b', 'label': 'pyr'},    # first 70 nodes are blue pyr cells
                                 {'node_ids': range(70, 100), 'c': 'r', 'label': 'inh'}])  # last 30 nodes are red inh cells

    The histogram is not grouped; it shows the spike times of all selected nodes together.

    :param spike_trains: SpikeTrains object or path to a (SONATA) spikes file.
    :param with_histogram: If True a histogram of spike times is shown as a small subplot below the raster plot.
        Default True.
    :param population: string. If a spikes-file contains more than one population of nodes, use this to determine
        which nodes to actually plot. If only one population exists and population=None then the function will find
        it by default.
    :param node_groups: None or list of dicts. Used to group sets of nodes by labels and color. Each grouping should be
        a dictionary with a 'node_ids' key with a list of the ids. All other keys (e.g. 'label' and 'c') are passed on
        to matplotlib's scatter(). If None all nodes will be labeled and colored the same.
    :param times: (float, float). Used to set start and stop time. If not specified the first and last spike times in
        the data are used.
    :param title: str, Use to add a title. Default no title.
    :param show: bool to display or not display plot. Default True.
    :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not
        save plot.
    :param plt_style: None or a matplotlib style (name or path) passed to plt.style.use() before plotting. Note that
        plt.style.use() changes the global matplotlib style.
    :return: matplotlib figure.Figure object
    """
    if plt_style is not None:
        plt.style.use(plt_style)

    spike_trains = __get_spike_trains(spike_trains=spike_trains)
    pop = __get_population(spike_trains=spike_trains, population=population)
    node_groups, selected_ids = __get_node_groups(spike_trains=spike_trains, node_groups=node_groups, population=pop)

    # Only show a legend if one of the node_groups have an explicit label, otherwise matplotlib will show an empty
    # legend box which looks bad
    show_legend = False

    # If the last (or first) M nodes don't spike matplotlib would cut off the y range, but they should be shown as
    # empty rows. To do this keep track of the range of all requested node_ids, not just the spiking ones.
    min_id, max_id = np.inf, -1

    spikes_df = spike_trains.to_dataframe(populations=pop, with_population_col=False)
    spikes_df = spikes_df[spikes_df['node_ids'].isin(selected_ids)]
    if times is not None:
        min_ts, max_ts = times[0], times[1]
        spikes_df = spikes_df[(spikes_df['timestamps'] >= times[0]) & (spikes_df['timestamps'] <= times[1])]
    else:
        min_ts = np.min(spikes_df['timestamps'])
        max_ts = np.max(spikes_df['timestamps'])

    if with_histogram:
        fig, axes = plt.subplots(2, 1, gridspec_kw={'height_ratios': [7, 1]}, squeeze=True)
        raster_axes = axes[0]
        bottom_axes = hist_axes = axes[1]
    else:
        fig, axes = plt.subplots(1, 1)
        bottom_axes = raster_axes = axes
        hist_axes = None

    for node_grp in node_groups:
        # node_groups is a private copy, so removing 'node_ids' is safe; the remaining keys go to scatter()
        grp_ids = node_grp.pop('node_ids')
        grp_spikes = spikes_df[spikes_df['node_ids'].isin(grp_ids)]

        # If label exists for at-least one group we want to show the legend
        show_legend = show_legend or 'label' in node_grp

        # Finds min/max node_id for all node groups
        min_id = np.min([np.min(grp_ids), min_id])
        max_id = np.max([np.max(grp_ids), max_id])

        raster_axes.scatter(grp_spikes['timestamps'], grp_spikes['node_ids'], lw=1, s=12, marker='|', **node_grp)

    raster_axes.yaxis.set_major_locator(MaxNLocator(integer=True))
    if show_legend:
        raster_axes.legend(loc='upper right', markerscale=1.5)
    if title:
        raster_axes.set_title(title)
    raster_axes.set_ylabel('node_ids')
    # add buffering to range else the rows at the ends look cut-off.
    raster_axes.set_ylim(min_id - 0.5, max_id + 1)

    # Raster and histogram share one x-range so the stacked plots line up in time.
    x_range = (min_ts - 0.5, max_ts + 1)
    raster_axes.set_xlim(*x_range)
    bottom_axes.set_xlabel('timestamps ({})'.format(spike_trains.units(population=pop)))

    if with_histogram:
        # Add a histogram (100 bins) of all selected spikes below the raster
        hist_axes.hist(spikes_df['timestamps'], 100)
        hist_axes.set_xlim(*x_range)
        hist_axes.get_yaxis().set_visible(False)
        # The histogram carries the time axis, so drop the raster's x ticks
        raster_axes.set_xticks([])

    if save_as:
        fig.savefig(save_as)

    if show:
        plt.show()

    return fig


def moving_average(data, window_size=10):
    """Smooth a 1D sequence with a centered moving average.

    Each output value is the mean of the window of ``window_size`` points around the input point (``window_size // 2``
    points before it, the rest including the point itself after). Windows are truncated at the edges of the data, so
    edge values average fewer points.

    :param data: 1D sequence (list, numpy array, pandas Series, ...). Values are indexed positionally.
    :param window_size: number of points in each averaging window (converted with int()). Default 10.
    :return: list of averaged values, same length as data.
    :raises ValueError: if window_size is less than 1.
    """
    size = int(window_size)
    if size < 1:
        raise ValueError('window_size must be at least 1, got {}'.format(window_size))

    # Positional indexing regardless of the input container (e.g. a Series indexed by node_id)
    values = np.asarray(data)
    before = size // 2
    after = size - before  # exclusive end offset; keeps exactly `size` points for odd sizes too

    # Slicing clamps the end automatically; only the start must be clamped to avoid negative (wrap-around) indices
    return [np.mean(values[max(0, i - before):i + after]) for i in range(len(values))]


def plot_rates(spike_trains, population=None, node_groups=None, times=None, smoothing=False, smoothing_params=None,
               title=None, show=True, save_as=None, plt_style=None):
    """Calculate and plot the rates of each node in a SpikeTrains object or SONATA Spike-Trains file.

    If start and stop times are not specified from the "times" parameter, the interval between the first and last
    spike of the selected nodes is used.

    If you want to only display a subset of nodes and/or group together different nodes (by node_id) by dot colors and
    labels then you can use the node_groups, which should be a list of dicts::

        plot_rates('/path/to/my/spike.h5',
                   node_groups=[{'node_ids': range(0, 70), 'c': 'b', 'label': 'pyr'},
                                {'node_ids': range(70, 100), 'c': 'r', 'label': 'inh'}])

    Rates are computed as (number of spikes) / (interval / 1000.0) and labelled in Hz, i.e. timestamps are assumed to
    be in milliseconds. Only nodes with at least one spike inside the interval appear in the plot.

    :param spike_trains: SpikeTrains object or path to a (SONATA) spikes file.
    :param population: string. If a spikes-file contains more than one population of nodes, use this to determine
        which nodes to actually plot. If only one population exists and population=None then the function will find
        it by default.
    :param node_groups: None or list of dicts. Used to group sets of nodes by labels and color. Each grouping should be
        a dictionary with a 'node_ids' key with a list of the ids. All other keys (e.g. 'label' and 'c') are passed on
        to matplotlib's plot(). If None all nodes will be labeled and colored the same.
    :param times: (float, float). Used to set start and stop time. If not specified will try to find values from
        spiking data.
    :param smoothing: Bool or callable. Used to smooth the data. By default (False) no smoothing will be done. If True
        a moving average smoothing function is used. If a callable (function, functools.partial, callable object,
        ...) is given, it is applied to each group's per-node rates and its result is plotted.
    :param smoothing_params: dict, keyword arguments bound to the smoothing function (the user callable or the default
        moving average, e.g. {'window_size': 5}).
    :param title: str, Use to add a title. Default no title.
    :param show: bool to display or not display plot. Default True.
    :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not
        save plot.
    :param plt_style: None or a matplotlib style (name or path) passed to plt.style.use() before plotting. Note that
        plt.style.use() changes the global matplotlib style.
    :return: matplotlib figure.Figure object
    """
    if plt_style is not None:
        plt.style.use(plt_style)

    spike_trains = __get_spike_trains(spike_trains=spike_trains)
    pop = __get_population(spike_trains=spike_trains, population=population)
    node_groups, selected_ids = __get_node_groups(spike_trains=spike_trains, node_groups=node_groups, population=pop)

    # Determine if smoothing will be applied to the data
    smoothing_params = smoothing_params or {}  # pass in empty parameters
    if callable(smoothing):
        # Accept any callable, not only plain Python functions: partials, callable objects and numpy function
        # objects are truthy and would otherwise fall through to the default moving average below.
        smoothing_fnc = partial(smoothing, **smoothing_params)
    elif smoothing:
        smoothing_fnc = partial(moving_average, **smoothing_params)
    else:
        smoothing_fnc = lambda d: d  # Use a filler function that won't do anything

    # get data
    spikes_df = spike_trains.to_dataframe(populations=pop, with_population_col=False)
    spikes_df = spikes_df[spikes_df['node_ids'].isin(selected_ids)]
    if times is not None:
        recording_interval = times[1] - times[0]
        spikes_df = spikes_df[(spikes_df['timestamps'] >= times[0]) & (spikes_df['timestamps'] <= times[1])]
    else:
        recording_interval = np.max(spikes_df['timestamps']) - np.min(spikes_df['timestamps'])

    # Iterate through each group of nodes and add to the same plot
    fig, axes = plt.subplots()
    show_legend = False  # Only show labels if one of the node group has label value
    for node_grp in node_groups:
        # If label exists for at-least one group we want to show the legend
        show_legend = show_legend or 'label' in node_grp
        # node_groups is a private copy, so removing 'node_ids' is safe; the remaining keys go to plot()
        grp_ids = node_grp.pop('node_ids')
        grp_spikes = spikes_df[spikes_df['node_ids'].isin(grp_ids)]
        spike_rates = grp_spikes.groupby('node_ids').size() / (recording_interval / 1000.0)
        axes.plot(np.array(spike_rates.index), smoothing_fnc(spike_rates), '.', **node_grp)

    axes.set_ylabel('Firing Rates (Hz)')
    axes.set_xlabel('node_ids')
    if show_legend:
        axes.legend()
    if title:
        axes.set_title(title)

    if save_as:
        fig.savefig(save_as)

    if show:
        plt.show()

    return fig


def plot_rates_boxplot(spike_trains, population=None, node_groups=None, times=None, title=None, show=True,
                       save_as=None, plt_style=None):
    """Creates a box plot of the firing rates taken from a SpikeTrains object or SONATA Spike-Trains file.

    If start and stop times are not specified from the "times" parameter, the interval between the first and last
    spike of the selected nodes is used.

    By default will plot all nodes together. To only display a subset of the nodes and/or create groups of nodes use
    the node_groups options::

        plot_rates_boxplot(
            '/path/to/my/spike.h5',
            node_groups=[{'node_ids': range(0, 70), 'label': 'pyr'},
                         {'node_ids': range(70, 100), 'label': 'inh'}]
        )

    Rates are computed as (number of spikes) / (interval / 1000.0) and labelled in Hz, i.e. timestamps are assumed to
    be in milliseconds. Only nodes with at least one spike inside the interval contribute to a box.

    :param spike_trains: SpikeTrains object or path to a (SONATA) spikes file.
    :param population: string. If a spikes-file contains more than one population of nodes, use this to determine
        which nodes to actually plot. If only one population exists and population=None then the function will find
        it by default.
    :param node_groups: None or list of dicts. Each grouping becomes one box and must be a dictionary with a
        'node_ids' key with a list of the ids. An optional 'label' key sets the box's x-axis label (default
        'Node Group <i>', or 'All Nodes' when there is a single unlabeled group). Other keys (e.g. 'c') are ignored by
        this plot. If None all nodes are put into a single box.
    :param times: (float, float). Used to set start and stop time. If not specified will try to find values from
        spiking data.
    :param title: str, Use to add a title. Default no title.
    :param show: bool to display or not display plot. Default True.
    :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not
        save plot.
    :param plt_style: None or a matplotlib style (name or path) passed to plt.style.use() before plotting. Note that
        plt.style.use() changes the global matplotlib style.
    :return: matplotlib figure.Figure object
    """
    if plt_style is not None:
        plt.style.use(plt_style)

    spike_trains = __get_spike_trains(spike_trains=spike_trains)
    pop = __get_population(spike_trains=spike_trains, population=population)
    node_groups, selected_ids = __get_node_groups(spike_trains=spike_trains, node_groups=node_groups, population=pop)

    spikes_df = spike_trains.to_dataframe(populations=pop, with_population_col=False)
    spikes_df = spikes_df[spikes_df['node_ids'].isin(selected_ids)]
    if times is not None:
        recording_interval = times[1] - times[0]
        spikes_df = spikes_df[(spikes_df['timestamps'] >= times[0]) & (spikes_df['timestamps'] <= times[1])]
    else:
        recording_interval = np.max(spikes_df['timestamps']) - np.min(spikes_df['timestamps'])

    fig, axes = plt.subplots()
    rates_data = []
    rates_labels = []

    # A single unlabeled group (e.g. the default "all nodes" group) gets a descriptive tick label
    if len(node_groups) == 1 and 'label' not in node_groups[0]:
        node_groups[0]['label'] = 'All Nodes'

    for i, node_grp in enumerate(node_groups):
        rates_labels.append(node_grp.get('label', 'Node Group {}'.format(i)))
        grp_spikes = spikes_df[spikes_df['node_ids'].isin(node_grp['node_ids'])]
        spike_rates = grp_spikes.groupby('node_ids').size() / (recording_interval / 1000.0)
        rates_data.append(spike_rates)

    axes.boxplot(rates_data)
    axes.set_ylabel('Firing Rates (Hz)')
    axes.set_xticklabels(rates_labels)
    if title:
        axes.set_title(title)

    if save_as:
        fig.savefig(save_as)

    if show:
        plt.show()

    return fig