Source code for bmtk.utils.reports.compartment.plotting

import matplotlib.pyplot as plt
import numpy as np
import copy
import warnings

from .compartment_report import CompartmentReport


def __get_population(report, population):
    """Helper function to figure out which population of nodes to use."""
    pops = report.populations
    if population is None:
        # If only one population exists in spikes object/file select that one
        if len(pops) > 1:
            raise Exception('Compartment Report contains more than one population of nodes. Use "population" parameter '
                            'to specify population to display.')

        else:
            return pops[0]

    elif population not in pops:
        raise Exception('Could not find node population "{}" in , only found {}'.format(population, pops))

    else:
        return population

def __get_node_groups(report, node_groups, population):
    """Helper function for parsing the 'node_groups' params"""
    if node_groups is None:
        # If none are specified by user make a 'node_group' consisting of all nodes
        selected_nodes = report.node_ids(population=population)
        return [{'node_ids': selected_nodes, 'c': 'b', 'label': 'averaged'}], selected_nodes
    else:
        # Fetch all node_ids which can be used to filter the data.
        node_groups = copy.deepcopy(node_groups)  # Make a copy since later we may be altering the dictionary
        selected_nodes = np.array(node_groups[0]['node_ids'])
        for grp in node_groups[1:]:
            if 'node_ids' not in grp:
                raise AttributeError('Could not find "node_ids" key in node_groups parameter.')
            selected_nodes = np.concatenate((selected_nodes, np.array(grp['node_ids'])))

        return node_groups, selected_nodes


[docs]def plot_traces(report, population=None, node_ids=None, sections='origin', average=False, node_groups=None, times=None, title=None, show_legend=None, show=True, save_as=None, plt_style=None): """Displays the time trace of one or more nodes from a SONATA CompartmentReport file. To plot a group of individual variable traces (based on their soma):: plot_traces('/path/to/report.h5', node_ids=[0, 10, 20, ...]) If node_ids=None (default) then all nodes in the report will be displayed. For large networks then can become difficult to visualize so it's recommended you use average=True:: plot_traces('/path/to/report.h5', average=True) Users also have the option of taking the averages of multiple subsets of nodes using the "node_groups" options. "node_groups" should be a list of dictionary, each dict with a 'node_ids': [list of ids], and optionally a 'label' and 'c' (color). For example support nodes [0, 70) are excitatory cells, nodes [70, 100) are inhibitory, and we want display excitatory and inhibitory averages and blue and red, respectivly:: plot_traces('/path/to/report.h5', node_groups=[{'node_ids': range(0, 70), 'label': 'exc', 'c': 'b'}, {'node_ids': range(70, 100), 'label': 'inh', 'c': 'r'}]) :param report: Path to SONATA report file or CompartmentReport object :param population: string. If the report more than one population of nodes, use this to determine which nodes to plot. If only one population exists and population=None then the function will find it by default. :param node_ids: int or list of ints. Individual node to display the variable :param sections: 'origin', 'all', or list of ids, Compartments/elements to display, By default will only show values at the soma. :param average: If True will take the averages of all/selected nodes. Default False :param node_groups: None or list of dicts. Used to group sets of nodes by labels and color. Each grouping should be a dictionary with a 'node_ids' key with a list of the ids. You can also add 'label' and 'c' keys for label and color. If None all nodes will be labeled and colored the same. :param times: (float, float), start and stop times of simulation :param title: str, adds a title to the plot :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not save plot. :return: matplotlib figure.Figure object """ if plt_style is not None: plt.style.use(plt_style) if node_groups is not None and node_ids is not None: warnings.warn('plot_traces is called with both "node_ids" and "node_groups" parameters.', UserWarning) elif node_ids is not None: if average and (not np.isscalar(node_ids)) and len(node_ids) > 1: node_groups = {'node_ids': node_ids, 'label': average} return plot_traces_averaged(report=report, population=population, sections=sections, node_groups=node_groups, times=times, title=title, show_legend=show_legend, show=show, save_as=save_as) return plot_traces_individual(report=report, population=population, node_ids=node_ids, sections=sections, times=times, title=title, show_legend=show_legend, show=show, save_as=save_as) elif node_groups is not None: return plot_traces_averaged(report=report, population=population, sections=sections, node_groups=node_groups, times=times, title=title, show_legend=show_legend, show=show, save_as=save_as) elif average is not None: return plot_traces_averaged(report=report, population=population, sections=sections, times=times, title=title, show_legend=show_legend, show=show, save_as=save_as) else: return plot_traces_individual(report=report, population=population, sections=sections, times=times, title=title, show_legend=show_legend, show=show, save_as=save_as)
[docs]def plot_traces_individual(report, population=None, node_ids=None, sections='origin', times=None, title=None, show_legend=None, show=True, save_as=None): """Used the plot time traces of individual nodes from a SONATA compartment report file. Recommended use plot_traces instead, which will call this function if necessarcy. :param report: Path to SONATA report file or CompartmentReport object. :param population: string. If the report more than one population of nodes, use this to determine which nodes to plot. If only one population exists and population=None then the function will find it by default. :param node_ids: int or list of ints. Individual node to display the variable. :param sections: 'origin', 'all', or list of ids, Compartments/elements to display, By default will only show values at the soma. :param times: (float, float), start and stop times of simulation. :param title: str, adds a title to the plot. :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not save plot. :return: matplotlib figure.Figure object """ cr = CompartmentReport.load(report) pop = __get_population(report=cr, population=population) # get node populations name if node_ids is None: node_ids = cr.node_ids(population=pop) elif np.isscalar(node_ids): node_ids = [node_ids] trace_times = cr.time_trace(population=pop) if times is not None: times_indx = np.argwhere((trace_times >= times[0]) & (trace_times <= times[1])) trace_times = trace_times[times_indx] else: # Get start and stop time from the traces data times = (trace_times[0], trace_times[-1]) fig, axes = plt.subplots() for node_id in node_ids: trace_data = cr.data(node_id=node_id, sections=sections, population=pop, time_window=times) axes.plot(trace_times, trace_data, label=node_id) axes.set_xlim(times[0], times[1]) axes.set_xlabel('time ({})'.format('ms')) if (show_legend is None and len(node_ids) <= 10) or show_legend: axes.legend() if title: axes.set_title(title) if save_as: plt.savefig(save_as) if show: plt.show() return fig
[docs]def plot_traces_averaged(report, population=None, sections='origin', node_groups=None, times=None, title=None, show_background=True, show_legend=None, show=True, save_as=None): """Used to plot averages across multiple nodes in a SONATA Compartment Report file. Recommended that you use "plot_traces" function. To plot multiple averages use the "node_groups" options. "node_groups" should be a list of dictionary, each dict with a 'node_ids': [list of ids], and optionally a 'label' and 'c' (color). For example support nodes [0, 70) are excitatory cells, nodes [70, 100) are inhibitory, and we want display excitatory and inhibitory averages and blue and red, respectively:: plot_traces('/path/to/report.h5', node_groups=[{'node_ids': range(0, 70), 'label': 'exc', 'c': 'b'}, {'node_ids': range(70, 100), 'label': 'inh', 'c': 'r'}]) :param report: Path to SONATA report file or CompartmentReport object. :param population: string. If the report more than one population of nodes, use this to determine which nodes to plot. If only one population exists and population=None then the function will find it by default. :param sections: 'origin', 'all', or list of ids, Compartments/elements to display, By default will only show values at the soma. :param node_groups: None or list of dicts. Used to group sets of nodes by labels and color. Each grouping should be a dictionary with a 'node_ids' key with a list of the ids. You can also add 'label' and 'c' keys for label and color. If None all nodes will be labeled and colored the same. :param times: (float, float), start and stop times of simulation. :param title: str, adds a title to the plot. :param show_background: shows all the individual traces greyed in the background. :param show_legend: Set True or False to determine if legend should be displayed on the plot. The default (None) function itself will guess if legend should be shown. :param show: bool to display or not display plot. default True. :param save_as: None or str: file-name/path to save the plot as a png/jpeg/etc. If None or empty string will not save plot. :return: matplotlib figure.Figure object """ cr = CompartmentReport.load(report) pop = __get_population(report=cr, population=population) # get node populations name node_groups, selected_ids = __get_node_groups(report=cr, node_groups=node_groups, population=pop) nodes_array = cr.node_ids(population=pop) trace_times = cr.time_trace(population=pop) traces_data = cr.data(sections=sections, population=pop) if times is not None: min_ts, max_ts = times[0], times[1] times_indx = np.argwhere((trace_times >= min_ts) & (trace_times <= max_ts)) times_indx = times_indx.flatten() traces_data = traces_data[times_indx, :] trace_times = trace_times[times_indx] else: min_ts, max_ts = trace_times[0], trace_times[-1] fig, axes = plt.subplots() has_labels = False if show_background: nodes_indx = np.argwhere(np.isin(nodes_array, selected_ids)) nodes_indx = nodes_indx.flatten() background_data = traces_data[:, nodes_indx] axes.plot(trace_times, background_data, c='lightgray') for node_grp in node_groups: grp_ids = node_grp.pop('node_ids') has_labels = has_labels or 'label' in node_grp nodes_indx = np.argwhere(np.isin(nodes_array, grp_ids)) nodes_indx = nodes_indx.flatten() if len(nodes_indx) == 0: continue grp_data = traces_data[:, nodes_indx] grp_mean = grp_data.mean(axis=1) axes.plot(trace_times, grp_mean, **node_grp) axes.set_xlim(min_ts, max_ts) axes.set_xlabel('time ({})'.format('ms')) y_label = cr.variable(population=pop) if cr.units(population=pop) is not None: y_label += '({})'.format(cr.units(population=pop)) axes.set_ylabel(cr.variable(population=pop)) if title: axes.set_title(title) if (show_legend is None or show_legend) and has_labels: axes.legend() if save_as: plt.savefig(save_as) if show: plt.show() return fig