from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from .simulator_module import SimulatorMod
from bmtk.simulator.core.io_tools import io
from bmtk.utils import lazy_property

try:
    import pynwb
    has_pynwb = True
except ImportError as ie:
    has_pynwb = False

try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    bcast = comm.bcast
    MPI_rank = comm.Get_rank()
    MPI_size = comm.Get_size()
    has_mpi = True
except ImportError:
    MPI_rank = 0
    MPI_size = 1
    bcast = lambda v, n: v
    has_mpi = False
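
# NOTE: the serial fallback for bcast above mirrors the (value, root) call signature of
# mpi4py's comm.bcast, so code below can call bcast(val, 0) with or without MPI present.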

try:
    # The ndx-aibs extensions for NWB are required to load the Neuropixels data modules.
    # Warning: this can take some time to load.
    file_dir = Path(__file__).parent
    namespace_path = (file_dir/"ndx-aibs-ecephys.namespace.yaml").resolve()
    pynwb.load_namespaces(str(namespace_path))
except Exception:
    io.log_debug('ECEphysUnitsModule: Unable to load ndx-aibs-ecephys.namespace.yaml')

class ECEphysUnitsModule(SimulatorMod):
    """
    TODO:
    - Have option to specify the nwb file units and/or get them from the NWB
    - Have option to save units-node mapping to output folder
    """
    def __init__(self, name, **kwargs):
        self._name = name
        self._node_set = kwargs['node_set']
        if not has_pynwb:
            io.log_exception('ECEphysUnitsModule: pynwb is not installed (pip install pynwb), unable to use module.')

        # Load a strategy for mapping SONATA node_ids to NWB unit_ids
        self._mapping_name = kwargs.get('mapping', 'invalid_strategy').lower()
        if self._mapping_name in ['units_map']:
            self._mapping_strategy = UnitIdMapStrategy(**kwargs)
        elif self._mapping_name in ['sample', 'sample_without_replacement']:
            self._mapping_strategy = SamplingStrategy(with_replacement=False, **kwargs)
        elif self._mapping_name in ['sample_with_replacement']:
            self._mapping_strategy = SamplingStrategy(with_replacement=True, **kwargs)
        else:
            io.log_exception('ECEphysUnitsModule: Invalid "mapping" parameter, options: units_map, sample, sample_with_replacement')

    def initialize(self, sim):
        raise NotImplementedError()
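
# Illustrative only: the kwargs consumed above ('node_set', 'mapping') and by the strategy
# classes below ('input_file', 'units', 'interval', 'interval_offset', 'missing_ids') are
# normally supplied from the simulation config. A hypothetical "inputs" entry might look like:
#
#   "LGN_spikes": {
#       "input_type": "spikes",
#       "module": "ecephys_probe",
#       "node_set": "lgn",
#       "input_file": "./session_715093703.nwb",
#       "mapping": "sample_without_replacement",
#       "units": {"location": "LGd"},
#       "interval": {"interval_name": "drifting_gratings", "ori": 90.0}
#   }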

class NWBFileWrapper(object):
    # A simple wrapper class for nwb files, mainly to keep track of the file path, which I
    # can't get from pynwb.
    # TODO: Implement a Singleton so that the same nwb file isn't loaded multiple times
    def __init__(self, nwb_path):
        if isinstance(nwb_path, pynwb.file.NWBFile):
            self._id = nwb_path.identifier
            self._io = nwb_path
        else:
            self._id = nwb_path
            self._io = pynwb.NWBHDF5IO(nwb_path, 'r').read()

    @property
    def uuid(self):
        return self._id

    def __getattr__(self, name):
        return getattr(self.__dict__['_io'], name)

class TimeWindow(object):
    """
    A class for dealing with different strategies for storing and interpreting time-window intervals [start, stop],
    mainly for use in filtering spike times. This includes converting between seconds/milliseconds, parsing an NWB
    stimulus table, looking up time windows for individual units, and handling defaults. To initialize::

        tw = TimeWindow(defaults=[interval1, interval2, ...], nwb_files=[session1.nwb, session2.nwb, ...])

    where interval<i> is a time window associated with all unit spikes in session<i>.nwb, and may be None (in which
    case unit spike times will not be filtered).

    If individual units have unique time intervals, you can pass in a pandas DataFrame with columns
    [unit_ids, start_times, stop_times], with values in ms::

        tw.units_lu = units_times_table_df

    To fetch the time window associated with a unit, for example unit_id 9999 in session_0.nwb, call::

        window = tw[9999, 'session_0.nwb']

    It will first check whether unit 9999 has a special [start_time, stop_time] in the units_lu table, and if not,
    fall back to the default for 'session_0.nwb', returned in seconds. If neither the unit_id nor a session default
    is found, it returns None.
    """
    def __init__(self, defaults=None, nwb_files=None):
        self._units_lu = None
        self._default_windows = None
        self.conversion_factor = 1/1000.0

        # By request, the "time_window" option can be
        #  - None
        #  - a single window; "time_window": [100, 200]
        #  - a stim_table filter; "time_window": {stim_name: gratings, ori: 90.0, tf: 2.0, ...}
        #  - one unique window for each of the nwb_ids/files; "time_window": [[0, 100], [300, 400], ...]
        # To handle this we will 1) convert each possible option into a list of 0, 1, or more windows, 2) check
        # that the number of time windows makes sense for the number of nwb_ids/files, and 3) create a map of
        # defaults for each nwb_id/file.
        if defaults is not None:
            time_windows = self._tolist(defaults)
            n_windows = len(time_windows)
            if n_windows > 1 and len(nwb_files) != n_windows:
                # There can be no default time window, or one default shared by all nwb_files; otherwise the
                # number of time_windows must correspond to the number of nwb_files
                io.log_exception('ECEphysUnitsModule: Cannot match each "time_window" with the "input_file"s.')
            if n_windows == 1:
                # convert [interval] -> [interval, interval, interval, ...], one per nwb file
                time_windows = time_windows*len(nwb_files)

            self._default_windows = {nwb.uuid: self._parse_tw(tw, nwb) for tw, nwb in zip(time_windows, nwb_files)}

    @property
    def units_lu(self):
        return self._units_lu

    @units_lu.setter
    def units_lu(self, units_table):
        if 'start_times' in units_table and 'stop_times' in units_table:
            units_table['start_times'] = units_table['start_times']*self.conversion_factor
            units_table['stop_times'] = units_table['stop_times']*self.conversion_factor

        if 'unit_ids' in units_table.columns:
            units_table = units_table.set_index('unit_ids')

        self._units_lu = units_table
    def _tolist(self, window):
        if isinstance(window, dict):
            # Is a stimulus_table filter, ex. {"interval_name": "gratings", "ori": 90.0, ...}
            return [window]
        elif isinstance(window, (tuple, list)) and len(window) == 0:
            # Is an empty list
            return []
        elif isinstance(window, (tuple, list)) and isinstance(window[0], (tuple, list, dict, np.ndarray)):
            # Is a list of intervals, ex. [[0.0, 100.0], [200.0, 300.0], {stim: gratings}, ...]
            return window
        else:
            # Assume it is a single interval, ex. [0.0, 100.0]
            return [window]
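
    # For reference, _tolist() normalizes every accepted "time_window" form into a list
    # of windows, e.g.:
    #   _tolist([0.0, 100.0])                   -> [[0.0, 100.0]]
    #   _tolist([[0.0, 100.0], [200.0, 300.0]]) -> [[0.0, 100.0], [200.0, 300.0]]
    #   _tolist({'interval_name': 'flashes'})   -> [{'interval_name': 'flashes'}]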
    def _parse_tw(self, interval, nwb_file):
        """Converts intervals, including windows and stim-table filters, into the appropriate format [start (s), stop (s)]"""
        if isinstance(interval, dict):
            # If it is a dictionary, try to find the time interval by filtering on nwb.intervals, eg. the
            # stimulus_table. The filter is a dictionary {'interval_name': 'flashes', 'col1': val1, 'col2': val2, ...}
            filters = interval.copy()
            stim_name = filters.pop('interval_name', None)
            stim_idx = filters.pop('interval_index', 'all')

            # In the NWB there are separate tables for each stimulus, and sometimes they are stored in the
            # nwb as <flashes>_presentations.
            if stim_name is None:
                io.log_exception('Stimulus table filter missing "interval_name"')
            if stim_name in nwb_file.intervals.keys():
                interval_df = nwb_file.intervals[stim_name].to_dataframe()
            elif stim_name + '_presentations' in nwb_file.intervals.keys():
                interval_df = nwb_file.intervals[stim_name + '_presentations'].to_dataframe()
            else:
                io.log_exception('interval name "{}" not found in {}'.format(stim_name, nwb_file.uuid))

            # Apply the remaining filters to the stimulus table; if no rows match, fall back to an
            # unbounded window
            interval_df = filter_table(interval_df, filters)
            if len(interval_df) == 0:
                return [0.0, np.inf]

            if stim_idx == 'all':
                start_time = interval_df['start_time'].min()
                stop_time = interval_df['stop_time'].max()
            else:
                start_time = interval_df.iloc[stim_idx]['start_time']
                stop_time = interval_df.iloc[stim_idx]['stop_time']

            # The NWB stim_table and units_table use seconds; do not convert the time window
            return [start_time, stop_time]
        else:
            # Is an interval [start, stop] that was entered manually in units of ms; convert
            # to seconds so it matches nwb spike_times units (s)
            return [interval[0]/1000.0, interval[1]/1000.0]
    def __getitem__(self, unit_info):
        unit_id, nwb_uuid = unit_info[0], unit_info[1]
        if (self.units_lu is not None) and (unit_id in self.units_lu.index):
            unit = self.units_lu.loc[unit_id]
            return [unit['start_times'], unit['stop_times']]
        elif self._default_windows and nwb_uuid in self._default_windows.keys():
            return self._default_windows[nwb_uuid]
        else:
            return None

class MappingStrategy(object):
    def __init__(self, **kwargs):
        self._nwb_paths = kwargs['input_file']
        self._filters = kwargs.get('units', {})
        self._simulation_onset = kwargs.get('interval_offset', 0.0)/1000.0
        self._missing_ids = kwargs.get('missing_ids', 'fail')
        self._cache_spike_times = kwargs.get('cache', False)
        self._spike_times_cache = {}

        default_window = kwargs.get('interval', None)
        self._time_window = TimeWindow(defaults=default_window, nwb_files=self.nwb_files)
        self._units_table = None
        self._units2nodes_map = None

    @lazy_property
    def nwb_files(self):
        if not isinstance(self._nwb_paths, (list, tuple)):
            self._nwb_paths = [self._nwb_paths]

        nwb_files = []
        for nwb_path in self._nwb_paths:
            nwb_files.append(NWBFileWrapper(nwb_path))

        return nwb_files

    @property
    def units2nodes_map(self):
        return self._units2nodes_map

    @property
    def units_table(self):
        if self._units_table is None:
            # Combine the units and channels tables from the nwb file(s)
            merged_table = None
            for nwb_file in self.nwb_files:
                units_table = self._load_units_table(nwb_file)
                units_table = self._filter_units_table(units_table)
                units_table = units_table[['spike_times']]
                units_table['nwb_uid'] = nwb_file.uuid
                merged_table = units_table if merged_table is None else pd.concat((merged_table, units_table))

            # if merged_table is None or len(merged_table) == 0:
            #     io.log_exception('ECEphysUnitsModule: Could not parse units table from nwb_file(s).')

            self._units_table = merged_table

        return self._units_table

    def _load_units_table(self, nwb_file):
        units = nwb_file.units.to_dataframe().reset_index()
        units = units.rename(columns={'id': 'unit_id', 'peak_channel_id': 'channel_id'})
        channels = nwb_file.electrodes.to_dataframe()
        channels = channels.reset_index().rename(columns={'id': 'channel_id'})
        return units.merge(channels, how='left', on='channel_id').set_index('unit_id')

    def _filter_units_table(self, units_table):
        units_table = filter_table(units_table, self._filters)
        return units_table
    def build_map(self, node_set):
        raise NotImplementedError()

    def get_spike_trains(self, node_id, source_population):
        # TODO: Is it worth caching spike-trains so we don't have to do a lookup + filtering
        #  every time?
        if node_id not in self._units2nodes_map:
            msg = 'ECEphysUnitsModule: Could not find mapping for node_id {}.'.format(node_id)
            if self._missing_ids == 'fail':
                io.log_exception(msg)
            elif self._missing_ids == 'warn':
                io.log_warning(msg)
            return np.array([])

        unit_id = self._units2nodes_map[node_id]
        spike_times = np.array(self.units_table.loc[unit_id]['spike_times'])
        nwb_uid = self.units_table.loc[unit_id]['nwb_uid']
        time_window = self._time_window[unit_id, nwb_uid]
        if time_window is not None:
            # Keep only spikes inside the window, then re-reference them to the window start
            # (plus any simulation onset offset)
            spike_times = spike_times[
                (time_window[0] <= spike_times) & (spike_times <= time_window[1])
            ]
            spike_times = spike_times - time_window[0] + self._simulation_onset

        spike_times = spike_times*1000.0  # convert from seconds to milliseconds
        return spike_times
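
# Worked example of the windowing math in get_spike_trains(): with time_window = [2.0, 3.0]
# (seconds) and _simulation_onset = 0.0, a recorded spike at 2.5 s is kept
# (2.0 <= 2.5 <= 3.0), re-referenced to 2.5 - 2.0 + 0.0 = 0.5 s, and returned as 500.0 ms.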

class UnitIdMapStrategy(MappingStrategy):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        try:
            self._mapping_path = kwargs['units']
        except KeyError:
            io.log_exception('ECEphysUnitsModule: Could not find "units" csv path for units-to-nodes mapping.')

    def _filter_units_table(self, units_table):
        return units_table

    def build_map(self, node_set):
        try:
            # TODO: Include population name
            mapping_file_df = self._mapping_path if isinstance(self._mapping_path, pd.DataFrame) else pd.read_csv(self._mapping_path, sep=' ')
            self._units2nodes_map = mapping_file_df[['node_ids', 'unit_ids']].set_index('node_ids').to_dict()['unit_ids']

            if 'start_times' in mapping_file_df:
                self._time_window.units_lu = mapping_file_df[['unit_ids', 'start_times', 'stop_times']].set_index('unit_ids')
        except (FileNotFoundError, UnicodeDecodeError):
            io.log_exception('ECEphysUnitsModule: {} should be a space separated file with columns "node_ids", "unit_ids"'.format(self._mapping_path))
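
# For reference, the mapping file read by UnitIdMapStrategy.build_map() is space separated
# with required "node_ids"/"unit_ids" columns plus optional per-unit windows in ms, e.g.
# (illustrative values):
#
#   node_ids unit_ids start_times stop_times
#   0 950921187 1000.0 2000.0
#   1 950921172 1500.0 2500.0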

class SamplingStrategy(MappingStrategy):
    def __init__(self, with_replacement=False, **kwargs):
        super().__init__(**kwargs)
        self._with_replacement = with_replacement

    def build_map(self, node_set):
        node_ids = node_set.node_ids
        unit_ids = self.units_table.index.values

        # When sampling without replacement there is no way to map a smaller set of unit_ids
        # onto all possible node_ids. Ignore, warn the user, or fail depending on the config
        # file option
        if (not self._with_replacement) and len(node_ids) > len(unit_ids):
            if self._missing_ids == 'fail':
                # Fail application
                io.log_exception('ECEphysUnitsModule: Not enough NWB unit_ids to map onto node_set.')

            # Not all node_ids will have spikes. TODO: Should the kept subset be chosen at random?
            node_ids = node_ids[:len(unit_ids)]

        # When running with MPI, we need to make sure the sampling of the unit_id maps is
        # consistent across all cores. Shuffle on rank 0 and broadcast the new order to all
        # other ranks
        if MPI_rank == 0:
            shuffled_unit_ids = np.random.choice(
                unit_ids, size=len(node_ids), replace=self._with_replacement
            )
        else:
            shuffled_unit_ids = None
        shuffled_unit_ids = bcast(shuffled_unit_ids, 0)

        # TODO: Include population name
        # Creates a mapping between SONATA node_ids and NWB unit_ids
        self._units2nodes_map = pd.DataFrame({
            'node_ids': node_ids,
            'unit_ids': shuffled_unit_ids
        }).set_index('node_ids').to_dict()['unit_ids']
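
# The resulting _units2nodes_map is a plain dict keyed by node_id, e.g. (illustrative
# values) {0: 950921187, 1: 950921172, ...}, which get_spike_trains() uses for lookups.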

def filter_table(table_df, filters_dict):
    if filters_dict:
        # Filter out only those specified units
        mask = True
        for col_name, filter_val in filters_dict.items():
            try:
                if isinstance(filter_val, str):
                    mask &= table_df[col_name] == filter_val
                elif isinstance(filter_val, (list, np.ndarray, tuple)):
                    mask &= table_df[col_name].isin(filter_val)
                elif isinstance(filter_val, dict):
                    # A comparison filter, eg. {'column': 'firing_rate', 'operation': '>', 'value': 1.0},
                    # built into an expression string and evaluated against the table with eval()
                    col = filter_val.get('column', col_name)
                    op = filter_val['operation']
                    val = filter_val['value']
                    val = '"{}"'.format(val) if isinstance(val, str) else val
                    expr = 'table_df["{}"] {} {}'.format(col, op, val)
                    mask &= eval(expr)
                else:
                    mask &= table_df[col_name] == filter_val
            except KeyError as ke:
                col = filter_val.get('column', col_name) if isinstance(filter_val, dict) else col_name
                io.log_exception('ECEphysUnitsModule: Could not find "{}" column in units/electrodes table.'.format(col))

        table_df = table_df[mask]

    return table_df
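
# Illustrative usage of filter_table(); the column names below are examples and not
# guaranteed to exist in any given NWB units/electrodes table:
#
#   visp_units = filter_table(units_df, {'location': 'VISp'})
#   good_units = filter_table(units_df, {'quality': ['good']})
#   fast_units = filter_table(units_df, {'firing_rate': {'operation': '>=', 'value': 5.0}})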