# Source code for bmtk.utils.reports.compartment.compartment_reader

import h5py
import numpy as np

from .core import CompartmentReaderABC
from bmtk.utils.hdf5_helper import get_attribute_h5


class _CompartmentPopulationReaderVer01(CompartmentReaderABC):
    """Read one population's compartment report from an HDF5 group.

    The group must contain a ``data`` dataset, indexed here as
    ``[time_step, element_column]``, and a ``mapping`` group holding
    ``node_ids``, ``index_pointer``, ``element_ids``, ``element_pos`` and
    ``time`` (read as ``[start, stop, dt]``). Any other dataset found in
    ``mapping`` is exposed as a custom column.

    Every ``population`` argument is accepted only for interface compatibility
    and is ignored, because an instance serves exactly one population.
    """

    # Names of the standard mapping datasets; anything else is a custom column.
    sonata_columns = ['element_ids', 'element_pos', 'index_pointer', 'node_ids', 'time']

    def __init__(self, pop_grp, pop_name):
        """Index the report stored under ``pop_grp``.

        :param pop_grp: HDF5 group containing ``data`` and ``mapping``.
        :param pop_name: name reported by :attr:`populations`.
        """
        self._pop_grp = pop_grp
        self._data_grp = pop_grp['data']
        self._mapping = pop_grp['mapping']
        self._population = pop_name

        self._gid2data_table = {}
        # NOTE(review): indexing an h5py group with a missing key presumably
        # raises KeyError on the line above, so this guard is likely never hit.
        if self._mapping is None:
            raise Exception('could not find /mapping group')

        # Load both datasets into memory once; indexing the HDF5 datasets
        # inside the loop would issue two reads per node.
        gids = self._mapping[self.node_ids_ds][()]
        index_pointer = self._mapping['index_pointer'][()]
        for indx, gid in enumerate(gids):
            # Columns [pointer[i], pointer[i + 1]) of ``data`` belong to node i.
            self._gid2data_table[gid] = slice(index_pointer[indx], index_pointer[indx + 1])

        time_ds = self._mapping['time']
        self._t_start = float(time_ds[0])
        self._t_stop = float(time_ds[1])
        self._dt = float(time_ds[2])
        self._n_steps = int((self._t_stop - self._t_start) / self._dt)

        self._custom_cols = {col: grp for col, grp in self._mapping.items() if
                             col not in self.sonata_columns and isinstance(grp, h5py.Dataset)}

    def _get_index(self, node_id):
        """Return the column slice of ``data`` belonging to ``node_id``.

        :raises KeyError: if ``node_id`` is not listed in the mapping.
        """
        if node_id not in self._gid2data_table:
            raise KeyError('node_id {} not found in {}/mapping/node_ids in file {}'.format(
                node_id, self._pop_grp.name, self._pop_grp.file.filename
            ))

        return self._gid2data_table[node_id]

    @property
    def populations(self):
        """Single-item list holding this reader's population name."""
        return [self._population]

    @property
    def data_ds(self):
        """The ``data`` HDF5 object of this population."""
        return self._data_grp

    @property
    def node_ids_ds(self):
        """Name of the mapping dataset that stores the node ids."""
        return 'node_ids'

    def get_population(self, population, default=None):
        """Not supported on a single-population reader."""
        raise NotImplementedError()

    def units(self, population=None):
        """Return the ``units`` attribute of the data, or None if absent."""
        return get_attribute_h5(self.data_ds, 'units', None)

    def variable(self, population=None):
        """Return the ``variable`` attribute of the data, or None if absent."""
        return get_attribute_h5(self.data_ds, 'variable', None)

    def tstart(self, population=None):
        """Return the start time read from ``mapping/time[0]``."""
        return self._t_start

    def tstop(self, population=None):
        """Return the stop time read from ``mapping/time[1]``."""
        return self._t_stop

    def dt(self, population=None):
        """Return the time step read from ``mapping/time[2]``."""
        return self._dt

    def n_steps(self, population=None):
        """Return ``int((tstop - tstart) / dt)``."""
        return self._n_steps

    def time(self, population=None):
        """Return the raw ``mapping/time`` dataset contents."""
        return self._mapping['time'][()]

    def time_trace(self, population=None):
        """Return ``n_steps`` evenly spaced values from tstart to tstop inclusive."""
        # NOTE(review): with endpoint=True the spacing is
        # (tstop - tstart) / (n_steps - 1), which differs from dt - confirm
        # this is intended before relying on exact time values.
        return np.linspace(self.tstart(), self.tstop(), num=self._n_steps, endpoint=True)

    def node_ids(self, population=None):
        """Return all node ids stored in the mapping."""
        return self._mapping[self.node_ids_ds][()]

    def index_pointer(self, population=None):
        """Return the full ``mapping/index_pointer`` array."""
        return self._mapping['index_pointer'][()]

    def element_pos(self, node_id=None, population=None):
        """Return ``element_pos`` for every element, or only those of ``node_id``."""
        if node_id is None:
            return self._mapping['element_pos'][()]
        else:
            return self._mapping['element_pos'][self._get_index(node_id)]

    def element_ids(self, node_id=None, population=None):
        """Return ``element_ids`` for every element, or only those of ``node_id``."""
        if node_id is None:
            return self._mapping['element_ids'][()]
        else:
            return self._mapping['element_ids'][self._get_index(node_id)]

    def n_elements(self, node_id=None, population=None):
        """Return the number of elements of ``node_id`` (of the whole report if None)."""
        return len(self.element_pos(node_id))

    def data(self, node_id=None, population=None, time_window=None, sections='all', **opts):
        """Return recorded values as an array indexed ``[time_step, element]``.

        :param node_id: restrict the columns to this node; None selects every
            column of the report.
        :param population: ignored.
        :param time_window: optional ``(begin, end)`` pair, converted to step
            indices with tstart/dt and clamped to ``[0, n_steps]``.
        :param sections: ``'all'`` or ``'origin'`` return every column of the
            node; any other value is taken as an element id, or an iterable of
            element ids, and only the matching columns are returned. Only used
            when ``node_id`` is given.
        :param opts: ignored.
        :raises KeyError: if ``node_id`` is unknown.
        :raises Exception: if ``time_window`` does not have exactly two items.
        """
        element_mask = None
        if node_id is None:
            column_sel = slice(0, self._data_grp.shape[1])
        else:
            column_sel = self._get_index(node_id)
            # NOTE(review): 'origin' returns every compartment of the node,
            # exactly like 'all'. An earlier comment described it as "the first
            # compartment"; behaviour is kept as-is - confirm the intent.
            whole_node = np.isscalar(sections) and sections in ('origin', 'all')
            if not whole_node and self.n_elements(node_id) != 1:
                # Keep only the columns whose element id was requested. The
                # mask is applied in memory after reading the node's block, so
                # it is independent of HDF5 fancy-indexing restrictions.
                wanted = [sections] if np.isscalar(sections) else list(sections)
                element_mask = np.isin(self.element_ids(node_id), wanted)

        if time_window is None:
            time_slice = slice(0, self._n_steps)
        else:
            if len(time_window) != 2:
                raise Exception('Invalid time_window, expecting tuple [being, end].')

            window_beg = max(int((time_window[0] - self.tstart()) / self.dt()), 0)
            window_end = min(int((time_window[1] - self.tstart()) / self.dt()), self._n_steps)
            time_slice = slice(window_beg, window_end)

        filtered_data = np.array(self._data_grp[time_slice, column_sel])
        if element_mask is not None:
            filtered_data = filtered_data[:, element_mask]
        return filtered_data

    def custom_columns(self, population=None):
        """Return ``{name: values}`` for every non-standard mapping dataset."""
        return {k: v[()] for k, v in self._custom_cols.items()}

    def get_column(self, column_name, population=None):
        """Return the full contents of ``mapping/<column_name>``."""
        return self._mapping[column_name][()]

    def get_element_data(self, node_id, population=None):
        """Placeholder; currently returns None."""
        pass

    def get_report_description(self, population=None):
        """Placeholder; currently returns None."""
        pass

    def __getitem__(self, population):
        """Return this reader regardless of the key."""
        return self


class _CompartmentPopulationReaderVer00(_CompartmentPopulationReaderVer01):
    """Population reader for the report layout that names its mapping
    datasets ``gids`` and ``element_id`` rather than ``node_ids`` and
    ``element_ids``.
    """

    # Standard mapping dataset names for this layout.
    sonata_columns = ['element_id', 'element_pos', 'index_pointer', 'gids', 'time']

    @property
    def node_ids_ds(self):
        """Name of the mapping dataset that stores the node ids."""
        return 'gids'

    def node_ids(self, population=None):
        """Return all node ids stored in the mapping (``population`` is ignored)."""
        ids_dataset = self._mapping[self.node_ids_ds]
        return ids_dataset[()]

    def element_ids(self, node_id=None, population=None):
        """Return ``element_id`` for every element, or only those of ``node_id``."""
        element_ds = self._mapping['element_id']
        # An empty tuple reads the whole dataset; otherwise read the node's slice.
        selection = () if node_id is None else self._get_index(node_id)
        return element_ds[selection]


class CompartmentReaderVer01(CompartmentReaderABC):
    """Read a compartment report file that may hold several populations.

    If the root group contains a ``report`` group, each of its children is
    opened as one population. Otherwise the root group itself is read as a
    single population named ``'pop_na'``. Every accessor takes an optional
    ``population`` and falls back to :attr:`default_population`.
    """

    def __init__(self, filename, mode='r', **params):
        """Open ``filename`` and build one reader per population.

        :param filename: path of the HDF5 report file.
        :param mode: mode passed to ``h5py.File``.
        :param params: optional ``h5_root`` (path of the group to use as the
            report root, default ``'/'``) and ``default_population``.
        :raises Exception: if ``default_population`` is given but not present.
        """
        self._h5_handle = h5py.File(filename, mode)
        try:
            self._h5_root = self._h5_handle[params['h5_root']] if 'h5_root' in params else self._h5_handle['/']
            self._popgrps = {}
            self._mapping = None

            # Look for 'report' in the group it is read from, so that a custom
            # h5_root is honoured instead of testing the file's top level.
            if 'report' in self._h5_root:
                report_grp = self._h5_root['report']
                for pop_name, pop_grp in report_grp.items():
                    self._popgrps[pop_name] = _CompartmentPopulationReaderVer01(pop_grp=pop_grp, pop_name=pop_name)
            else:
                self._default_population = 'pop_na'
                self._popgrps[self._default_population] = _CompartmentPopulationReaderVer00(
                    pop_grp=self._h5_root, pop_name=self._default_population)

            if 'default_population' in params:
                # If user has specified a default population
                self._default_population = params['default_population']
                if self._default_population not in self._popgrps.keys():
                    raise Exception('Unknown population {} found in report.'.format(self._default_population))
            elif len(self._popgrps.keys()) == 1:
                # If there is only one population in the report default to that
                self._default_population = list(self._popgrps.keys())[0]
            else:
                self._default_population = None
        except Exception:
            # Do not leave the file open when the reader could not be built.
            self._h5_handle.close()
            raise

    @property
    def default_population(self):
        """Population used when none is passed.

        :raises Exception: if no default could be determined.
        """
        if self._default_population is None:
            raise Exception('Please specify a node population.')
        return self._default_population

    @property
    def populations(self):
        """List of the population names found in the report."""
        return list(self._popgrps.keys())

    def get_population(self, population, default=None):
        """Return the reader for ``population``, or ``default`` if unknown."""
        if population not in self.populations:
            return default
        return self[population]

    def units(self, population=None):
        """Return the units of the population's data."""
        population = population or self.default_population
        return self[population].units()

    def variable(self, population=None):
        """Return the recorded variable name of the population's data."""
        population = population or self.default_population
        return self[population].variable()

    def tstart(self, population=None):
        """Return the population's start time."""
        population = population or self.default_population
        return self[population].tstart()

    def tstop(self, population=None):
        """Return the population's stop time."""
        population = population or self.default_population
        return self[population].tstop()

    def dt(self, population=None):
        """Return the population's time step."""
        population = population or self.default_population
        return self[population].dt()

    def time_trace(self, population=None):
        """Return the population's array of time values."""
        population = population or self.default_population
        return self[population].time_trace()

    def node_ids(self, population=None):
        """Return the node ids recorded for the population."""
        population = population or self.default_population
        return self[population].node_ids()

    def element_pos(self, node_id=None, population=None):
        """Return element positions, optionally restricted to ``node_id``."""
        population = population or self.default_population
        return self[population].element_pos(node_id)

    def element_ids(self, node_id=None, population=None):
        """Return element ids, optionally restricted to ``node_id``."""
        population = population or self.default_population
        return self[population].element_ids(node_id)

    def n_elements(self, node_id=None, population=None):
        """Return the number of elements, optionally restricted to ``node_id``."""
        population = population or self.default_population
        return self[population].n_elements(node_id)

    def data(self, node_id=None, population=None, time_window=None, sections='all', **opt_attrs):
        """Return recorded values; arguments are forwarded to the population reader."""
        population = population or self.default_population
        return self[population].data(node_id=node_id, time_window=time_window, sections=sections, **opt_attrs)

    def custom_columns(self, population=None):
        """Return the population's non-standard mapping columns."""
        population = population or self.default_population
        return self[population].custom_columns(population)

    def get_column(self, column_name, population=None):
        """Return the contents of one mapping column of the population."""
        population = population or self.default_population
        return self[population].get_column(column_name)

    def get_node_description(self, node_id, population=None):
        """Delegate to the population reader's ``get_node_description``."""
        # NOTE(review): the population readers visible in this file define
        # get_element_data but not get_node_description; unless the base class
        # provides it this call raises AttributeError - confirm.
        population = population or self.default_population
        return self[population].get_node_description(node_id)

    def get_report_description(self, population=None):
        """Delegate to the population reader's ``get_report_description``."""
        population = population or self.default_population
        return self[population].get_report_description()

    def __getitem__(self, population):
        """Return the reader for ``population``; KeyError if it is unknown."""
        return self._popgrps[population]