# Copyright 2020. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import numpy as np
import pandas as pd
import csv
import time
from array import array
from .core import SortOrder, pop_na, comm, MPI_size, MPI_rank, comm_barrier
from .core import col_node_ids, col_population, col_timestamps
from .spike_trains_api import SpikeTrainsAPI
def _spikes_filter1(p, t, time_window, populations):
return p in populations and time_window[0] <= t <= time_window[1]
def _spikes_filter2(p, t, populations):
return p in populations
def _spikes_filter3(p, t, time_window):
return time_window[0] <= t <= time_window[1]
def _create_filter(populations, time_window):
    """Build a spike predicate ``fnc(p=<population>, t=<timestamp>) -> bool``.

    :param populations: None (no population filtering), a single population name, or a list of names
    :param time_window: None (no time filtering) or an inclusive (start, stop) pair
    :return: callable accepting keyword args ``p`` and ``t``
    """
    if populations is None and time_window is None:
        # nothing to filter on
        return lambda p, t: True
    if populations is None:
        return lambda p, t: time_window[0] <= t <= time_window[1]

    # normalize a single population name into a list
    if np.isscalar(populations):
        populations = [populations]
    if time_window is None:
        return lambda p, t: p in populations
    return lambda p, t: p in populations and time_window[0] <= t <= time_window[1]
def _create_empty_df(with_population_col=True):
    """Return a spikes DataFrame with the standard columns but zero rows."""
    if with_population_col:
        columns = [col_timestamps, col_population, col_node_ids]
    else:
        columns = [col_timestamps, col_node_ids]
    return pd.DataFrame(columns=columns)
class STMemoryBuffer(SpikeTrainsAPI):
    """A Class for creating, storing and reading multi-population spike-trains - especially for saving the spikes of a
    large scale network simulation. Keeps a running tally of the (timestamp, population-name, node_id) for each
    individual spike.

    The spikes are stored in memory, so very large and/or highly active simulations may run into memory issues. Not
    designed to work with parallel simulations (see STMPIBuffer).
    """
    def __init__(self, default_population=None, store_type='array', **kwargs):
        # 'population' kwarg kept for backwards compatibility with older callers
        self._default_population = default_population or kwargs.get('population', None) or pop_na
        self._store_type = store_type  # either 'array' or 'list', see _create_store()
        self._units = kwargs.get('units', 'ms')  # for backwards compatability default to milliseconds
        self._pops = {}  # population-name --> {col_node_ids: <ids>, col_timestamps: <times>}

    def add_spike(self, node_id, timestamp, population=None, **kwargs):
        """Record a single spike for the given node in the given (or default) population."""
        population = population or self._default_population

        if population not in self._pops:
            self._create_store(population)

        self._pops[population][col_node_ids].append(node_id)
        self._pops[population][col_timestamps].append(timestamp)

    def add_spikes(self, node_ids, timestamps, population=None, **kwargs):
        """Record multiple spikes at once.

        :param node_ids: a scalar (one node, many spike times) or a sequence matched element-wise to timestamps
        :param timestamps: sequence of spike times
        :raises ValueError: if node_ids and timestamps differ in length
        """
        population = population or self._default_population

        if np.isscalar(node_ids):
            node_ids = [node_ids]*len(timestamps)

        if len(node_ids) != len(timestamps):
            raise ValueError('node_ids and timestamps must by of the same length')

        if population not in self._pops:
            self._create_store(population)
        pop_data = self._pops[population]
        pop_data[col_node_ids].extend(node_ids)
        pop_data[col_timestamps].extend(timestamps)

    def _create_store(self, population):
        """Helper for creating storage data struct of a population, so add_spike/add_spikes is consistent."""
        # Benchmark Notes:
        #   Tested with numpy, lists and arrays. np.concate/append is too slow to consider. regular list is ~2-3x
        #   faster than array, but require 2-4x the amount of memory. For larger and parallelized applications
        #   (> 100 million spikes) use array since the amount of memory required can exceed amount available. But
        #   if memory is not an issue use list.
        if self._store_type == 'list':
            self._pops[population] = {col_node_ids: [], col_timestamps: []}
        elif self._store_type == 'array':
            self._pops[population] = {col_node_ids: array('I'), col_timestamps: array('d')}
        else:
            # fix: 'Uknown' typo in the error message
            raise AttributeError('Unknown store type {} for SpikeTrains'.format(self._store_type))

    def import_spikes(self, obj, **kwargs):
        # not implemented for the in-memory buffer
        pass

    def flush(self):
        pass  # not necessary since everything is stored in memory

    def close(self):
        pass  # don't need to do anything

    @property
    def populations(self):
        """Names of all populations that have at least one recorded spike."""
        return list(self._pops.keys())

    def node_ids(self, population=None):
        """Sorted, unique node_ids that spiked in the given (or default) population; [] if unknown."""
        population = population if population is not None else self._default_population
        if population not in self._pops:
            return []
        return np.unique(self._pops[population][col_node_ids]).astype(np.uint)

    def units(self, population=None):
        return self._units

    def set_units(self, v, population=None):
        self._units = v

    def n_spikes(self, population=None):
        """Number of spikes recorded for the given (or default) population; 0 if unknown."""
        population = population if population is not None else self._default_population
        if population not in self._pops:
            return 0
        return len(self._pops[population][col_timestamps])

    def get_times(self, node_id, population=None, time_window=None, **kwargs):
        """Return an array of spike times for one node, optionally restricted to [start, stop] time_window."""
        population = population if population is not None else self._default_population
        pop = self._pops.get(population, None)
        if pop is None:
            # fix: be consistent with node_ids()/n_spikes() - an unknown population returns an
            # empty result instead of raising a KeyError.
            return np.array([])

        # filter by node_id and (if specified) by time.
        mask = np.array(pop[col_node_ids]) == node_id
        ts = np.array(pop[col_timestamps])
        if time_window:
            mask &= (time_window[0] <= ts) & (ts <= time_window[1])
        return ts[mask]

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, **kwargs):
        """Build a pandas DataFrame of spikes for all (or the selected) populations."""
        if populations is None:
            selected_pops = list(self.populations)
        elif np.isscalar(populations):
            selected_pops = [populations]
        else:
            selected_pops = populations

        ret_df = None
        for pop_name in selected_pops:
            # unknown populations contribute zero rows rather than raising
            pop_data = self._pops.get(pop_name, {col_node_ids: [], col_timestamps: []})
            pop_df = pd.DataFrame({
                col_node_ids: pop_data[col_node_ids],
                col_timestamps: pop_data[col_timestamps]
            })
            if with_population_col:
                pop_df[col_population] = pop_name

            # sorting is applied per-population, matching the behavior of spikes()
            if sort_order == SortOrder.by_id:
                pop_df = pop_df.sort_values(col_node_ids)
            elif sort_order == SortOrder.by_time:
                pop_df = pop_df.sort_values(col_timestamps)

            ret_df = pop_df if ret_df is None else pd.concat((ret_df, pop_df))

        # Make sure ret_df is not None
        ret_df = _create_empty_df(with_population_col) if ret_df is None else ret_df
        return ret_df

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, **kwargs):
        """Iterate over spikes as (timestamp, population, node_id) tuples, population by population."""
        if populations is None:
            populations = self.populations
        elif np.isscalar(populations):
            populations = [populations]

        # fix: the filter depends only on the requested populations/time_window, so build it once
        # rather than once per population (also avoids shadowing the builtin 'filter').
        filter_fnc = _create_filter(populations, time_window)
        for pop_name in populations:
            pop_data = self._pops.get(pop_name, {col_node_ids: [], col_timestamps: []})
            timestamps = pop_data[col_timestamps]
            node_ids = pop_data[col_node_ids]

            if sort_order == SortOrder.by_id:
                sort_indx = np.argsort(node_ids)
            elif sort_order == SortOrder.by_time:
                sort_indx = np.argsort(timestamps)
            else:
                sort_indx = range(len(timestamps))

            for i in sort_indx:
                t = timestamps[i]
                if filter_fnc(p=pop_name, t=t):
                    yield t, pop_name, node_ids[i]

    def __len__(self):
        return len(self.to_dataframe())
class STMPIBuffer(STMemoryBuffer):
    """In-memory spike buffer for MPI-parallelized simulations.

    Each rank stores only its locally recorded spikes (storage inherited from STMemoryBuffer).
    Read methods take an ``on_rank`` option:
      * 'local' - return only this rank's data, no communication.
      * 'all'   - gather across ranks; every rank gets the combined result.
      * 'root'  - gather across ranks; only rank 0 gets the result (other ranks get None).

    NOTE(review): the 'all'/'root' code paths use MPI collectives, so every rank must call these
    methods the same number of times and in the same order, or the job can deadlock.
    """
    def __init__(self, default_population=None, store_type='array', **kwargs):
        # Allow rank/size to be overridden through kwargs (eg for testing); otherwise use the
        # module-level values imported from .core
        self.mpi_rank = kwargs.get('MPI_rank', MPI_rank)
        self.mpi_size = kwargs.get('MPI_size', MPI_size)
        super(STMPIBuffer, self).__init__(default_population=default_population, store_type=store_type, **kwargs)

    def _gatherv(self, population, on_all_ranks=True):
        """Gather one population's (node_ids, timestamps) arrays from every rank.

        :param population: name of the population to gather (may have no spikes on this rank)
        :param on_all_ranks: if True use Allgatherv so every rank receives the combined arrays;
            otherwise Gatherv to rank 0 only, and non-root ranks return (None, None)
        :return: (all_node_ids, all_timestamps) as numpy arrays (uint64 / double)
        """
        from mpi4py import MPI

        # Compute the total spike count plus each rank's offset into the combined arrays
        local_n_spikes = super(STMPIBuffer, self).n_spikes(population)
        sizes = comm.allgather(local_n_spikes)
        offsets = np.zeros(MPI_size, dtype=np.int64)
        offsets[1:] = np.cumsum(sizes)[:-1]
        all_n_spikes = np.sum(sizes)

        local_population = self._pops.get(population, {col_node_ids: [], col_timestamps: []})  # if pop not on rank

        local_node_ids = np.array(local_population[col_node_ids], dtype=np.uint64)
        all_node_ids = np.zeros(all_n_spikes, dtype=np.uint64)
        if on_all_ranks:
            comm.Allgatherv(local_node_ids, [all_node_ids, sizes, offsets, MPI.UINT64_T])
        else:
            comm.Gatherv(local_node_ids, [all_node_ids, sizes, offsets, MPI.UINT64_T], root=0)
            if MPI_rank != 0:
                all_node_ids = None  # only root keeps the gathered data

        local_timestamps = np.array(local_population[col_timestamps], dtype=np.double)
        all_timestamps = np.zeros(all_n_spikes, dtype=np.double)
        if on_all_ranks:
            comm.Allgatherv(local_timestamps, [all_timestamps, sizes, offsets, MPI.DOUBLE])
        else:
            comm.Gatherv(local_timestamps, [all_timestamps, sizes, offsets, MPI.DOUBLE], root=0)
            if MPI_rank != 0:
                all_timestamps = None

        return all_node_ids, all_timestamps

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, on_rank='all',
                     **kwargs):
        """Build a DataFrame of spikes; with on_rank='root' non-root ranks may return None."""
        comm_barrier()
        if on_rank == 'local':
            return super(STMPIBuffer, self).to_dataframe(populations=populations, sort_order=sort_order,
                                                         with_population_col=with_population_col, **kwargs)

        if on_rank not in ['local', 'all', 'root']:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if populations is None:
            selelectd_pops = list(self.get_populations(on_rank='all'))
        elif np.isscalar(populations):
            selelectd_pops = [populations]
        else:
            selelectd_pops = populations

        # Make sure the list of populations is exactly the same (including order) across all ranks so that
        # _gatherv is called in the same sequence across all ranks.
        selelectd_pops.sort()

        ret_df = None
        for pop_name in selelectd_pops:
            if on_rank == 'all':
                node_ids, timestamps = self._gatherv(population=pop_name, on_all_ranks=True)
                pop_df = pd.DataFrame({
                    col_node_ids: node_ids,
                    col_timestamps: timestamps
                })
                if with_population_col:
                    pop_df[col_population] = pop_name

                # sorting is applied per-population
                if sort_order == SortOrder.by_id:
                    pop_df = pop_df.sort_values(col_node_ids)
                elif sort_order == SortOrder.by_time:
                    pop_df = pop_df.sort_values(col_timestamps)

                if ret_df is None:
                    ret_df = pop_df
                else:
                    ret_df = pd.concat((ret_df, pop_df))
            elif on_rank == 'root':
                # all ranks must call _gatherv (collective); only rank 0 builds the frame
                node_ids, timestamps = self._gatherv(population=pop_name, on_all_ranks=False)
                if MPI_rank != 0:
                    continue

                pop_df = pd.DataFrame({
                    col_node_ids: node_ids,
                    col_timestamps: timestamps
                })
                if with_population_col:
                    pop_df[col_population] = pop_name

                if sort_order == SortOrder.by_id:
                    pop_df = pop_df.sort_values(col_node_ids)
                elif sort_order == SortOrder.by_time:
                    pop_df = pop_df.sort_values(col_timestamps)

                if ret_df is None:
                    ret_df = pop_df
                else:
                    ret_df = pd.concat((ret_df, pop_df))

        comm_barrier()
        if on_rank == 'all' or MPI_rank == 0:
            # If using 'all' or on rank 0 a dataframe is expected even if there are no spikes
            ret_df = _create_empty_df(with_population_col) if ret_df is None else ret_df

        return ret_df

    @property
    def populations(self):
        return self.get_populations(on_rank='all')

    def get_populations(self, on_rank='all'):
        """Sorted list of population names across ranks; None on non-root ranks when on_rank='root'."""
        local_pops = list(super(STMPIBuffer, self).populations)
        if on_rank == 'local':
            return local_pops

        if on_rank == 'all':
            gathered_pops = comm.allgather(local_pops)
        elif on_rank == 'root':
            gathered_pops = comm.gather(local_pops, 0)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if gathered_pops is None:
            # comm.gather returns None on non-root ranks
            return None
        else:
            all_populations = set()
            for pops in gathered_pops:
                all_populations |= set(pops)

            # WARNING: For a number of parallel applications it's important that the list of populations returned
            # is the same across all ranks (eg ranks don't iterate through each population in different sequences)
            all_populations = list(all_populations)
            all_populations.sort()
            return all_populations

    def node_ids(self, population=None, on_rank='all'):
        """Unique node_ids of a population across ranks; None on non-root ranks when on_rank='root'."""
        local_node_ids = super(STMPIBuffer, self).node_ids(population)
        if on_rank == 'local':
            return local_node_ids

        if on_rank == 'all':
            gathered_nodes = comm.allgather(local_node_ids)
        elif on_rank == 'root':
            gathered_nodes = comm.gather(local_node_ids, 0)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if gathered_nodes is None:
            return None
        else:
            return np.unique(np.concatenate(gathered_nodes)).astype(np.uint)

    def n_spikes(self, population=None, on_rank='all'):
        """Total spike count of a population; None on non-root ranks when on_rank='root'."""
        from mpi4py import MPI
        local_n = super(STMPIBuffer, self).n_spikes(population)
        if on_rank == 'local':
            return local_n
        elif on_rank == 'all':
            return comm.allreduce(local_n, MPI.SUM)
        elif on_rank == 'root':
            # comm.reduce returns the sum only on rank 0; other ranks get None
            return comm.reduce(local_n, MPI.SUM)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def get_times(self, node_id, population=None, time_window=None, on_rank='all', **kwargs):
        """Sorted spike times of one node across ranks; None on non-root ranks when on_rank='root'."""
        local_times = super(STMPIBuffer, self).get_times(node_id=node_id, population=population,
                                                         time_window=time_window, **kwargs)
        if on_rank == 'local':
            return local_times
        elif on_rank == 'all':
            all_times = comm.allgather(local_times)
        elif on_rank == 'root':
            all_times = comm.gather(local_times, 0)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if all_times is not None:
            return np.sort(np.concatenate(all_times))
        else:
            return None

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, on_rank='all', **kwargs):
        """Iterate spikes as (timestamp, population, node_id) tuples, gathered according to on_rank."""
        if on_rank == 'local':
            for i in super(STMPIBuffer, self).spikes(populations=populations, time_window=time_window,
                                                     sort_order=sort_order, **kwargs):
                yield i
            return

        if populations is None:
            populations = self.populations
        elif np.isscalar(populations):
            populations = [populations]

        # keep populations in the same order on every rank so the _gatherv collectives line up
        populations.sort()
        for pop_name in populations:
            node_ids, timestamps = self._gatherv(pop_name, on_all_ranks=(on_rank == 'all'))
            if node_ids is None:
                # on_rank='root': non-root ranks receive nothing for this population
                continue

            if sort_order == SortOrder.by_id:
                sort_indx = np.argsort(node_ids)
            elif sort_order == SortOrder.by_time:
                sort_indx = np.argsort(timestamps)
            else:
                sort_indx = range(len(timestamps))

            filter = _create_filter(populations, time_window)
            for i in sort_indx:
                t = timestamps[i]
                p = pop_name
                if filter(p=p, t=t):
                    yield t, p, node_ids[i]
class STCSVBuffer(SpikeTrainsAPI):
    """A Class for creating, storing and reading multi-population spike-trains - especially for saving the spikes of a
    large scale network simulation. Keeps a running tally of the (timestamp, population-name, node_id) for each
    individual spike.

    Uses a caching mechanism to periodically save spikes to the disk. Will incur a runtime performance penalty but
    will always have an upper bound on the maximum memory used.

    If running parallel simulations should use the STMPIBuffer adaptor instead.
    """

    def __init__(self, cache_dir=None, default_population=None, cache_name='spikes', **kwargs):
        self._default_population = default_population or pop_na

        # Keep a file handle open for writing spike information
        self._cache_dir = cache_dir or '.'
        self._cache_name = cache_name
        self._buffer_filename = self._cache_fname(self._cache_dir)
        self._buffer_handle = open(self._buffer_filename, 'w')

        self._units = kwargs.get('units', 'ms')
        # population-name --> {'node_ids': set of ids seen, 'n_spikes': count}
        self._pop_metadata = {}

        self._spike_counts = 0  # all spikes added on rank, for each individual pop spike count stored in _pop_metadata

    def _cache_fname(self, cache_dir):
        """Return the path of the csv cache file, creating cache_dir if needed."""
        # TODO: Potential problem if multiple SpikeTrains are opened at the same time, add salt to prevent collisions
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        return os.path.join(cache_dir, '.bmtk.{}.cache.csv'.format(self._cache_name))

    def add_spike(self, node_id, timestamp, population=None, **kwargs):
        """Record a single spike; written immediately to the csv cache file as 'timestamp population node_id'."""
        population = population or self._default_population

        # NOTE: I looked into using a in-memory buffer to save data and caching only when they reached a threshold,
        # however on my computer it was actually slower than just calling file.write() each time. Likely the python
        # file writer is more efficent than what I could write. However still would like to benchmark on a NSF.
        self._buffer_handle.write('{} {} {}\n'.format(timestamp, population, node_id))

        if population not in self._pop_metadata:
            self._pop_metadata[population] = {'node_ids': set(), 'n_spikes': 0}
        self._pop_metadata[population]['node_ids'].add(node_id)
        self._pop_metadata[population]['n_spikes'] += 1
        self._spike_counts += 1

    def add_spikes(self, node_ids, timestamps, population=None, **kwargs):
        """Record multiple spikes; node_ids may be a scalar (applied to every timestamp) or a sequence."""
        if np.isscalar(node_ids):
            for ts in timestamps:
                self.add_spike(node_ids, ts, population)
        else:
            for node_id, ts in zip(node_ids, timestamps):
                self.add_spike(node_id, ts, population)

    @property
    def populations(self):
        """Names of all populations that have at least one recorded spike."""
        return list(self._pop_metadata.keys())

    def units(self, population=None):
        return self._units

    def set_units(self, u, population=None):
        self._units = u

    def node_ids(self, population=None):
        """Unique node_ids that spiked in the given (or default) population; [] if unknown."""
        population = population if population is not None else self._default_population
        if population not in self._pop_metadata:
            return []
        return list(self._pop_metadata[population]['node_ids'])

    def n_spikes(self, population=None):
        """Number of spikes recorded for the given (or default) population; 0 if unknown."""
        population = population if population is not None else self._default_population
        if population not in self._pop_metadata:
            return 0
        return self._pop_metadata[population]['n_spikes']

    def time_range(self, populations=None):
        return None  # TODO: keep track of largest and smallest values
        # return np.min(self._timestamps), np.max(self._timestamps)

    def get_times(self, node_id, population=None, time_window=None, **kwargs):
        """Return an array of spike times for one node (scans the whole cache file via spikes())."""
        self.flush()
        population = population if population is not None else self._default_population
        return np.array([t[0] for t in self.spikes(populations=population, time_window=time_window) if t[2] == node_id])

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, **kwargs):
        """Read the cache file into a DataFrame, optionally filtered by population and sorted."""
        self.flush()

        # always group by population, then (optionally) by time or node_id within each population
        sorting_cols = [col_population]
        if sort_order == SortOrder.by_time:
            sorting_cols = [col_population, col_timestamps]
        elif sort_order == SortOrder.by_id:
            sorting_cols = [col_population, col_node_ids]

        ret_df = pd.read_csv(
            self._buffer_filename, sep=' ', names=[col_timestamps, col_population, col_node_ids]
        ).sort_values(sorting_cols)

        # filter by population
        if np.isscalar(populations):
            ret_df = ret_df[ret_df[col_population] == populations]
        elif populations is not None:
            ret_df = ret_df[ret_df[col_population].isin(populations)]

        if not with_population_col:
            ret_df = ret_df.drop(col_population, axis=1)

        # pandas may infer the wrong dtypes from the csv, force them
        ret_df = ret_df.astype({col_timestamps: float, col_node_ids: np.int64})
        return ret_df

    def flush(self):
        """Flush the write handle and wait until the cache file actually contains every added spike."""
        self._buffer_handle.flush()

        # Found an issue with even after flushing the csv there can be a lag before data is actually cached to the disk.
        # this can have problems with other processes on a different rank tries open the file that hasn't been
        # completely saved. This hack should hopefully ensure that each rank has fully cached their spikes to disk.
        for i in range(10):
            with open(self._buffer_filename) as fh:
                fcount = len(fh.readlines())
            if fcount == self._spike_counts:
                break
            time.sleep(0.5)
        else:
            print('Warning: spike counts on rank {} cache does not match total added.'.format(MPI_rank))

    def close(self):
        """Close and delete the csv cache file."""
        self._buffer_handle.close()
        if os.path.exists(self._buffer_filename):
            os.remove(self._buffer_filename)

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, **kwargs):
        """Iterate spikes as (timestamp, population, node_id) tuples, read back from the cache file.

        NOTE: sorting rewrites the cache file on disk (see _sort_buffer_file).
        """
        self.flush()
        self._sort_buffer_file(self._buffer_filename, sort_order)
        filter_fnc = _create_filter(populations, time_window)
        with open(self._buffer_filename, 'r') as csvfile:
            csv_reader = csv.reader(csvfile, delimiter=' ')
            for row in csv_reader:
                t = float(row[0])
                p = row[1]
                if filter_fnc(p=p, t=t):
                    yield t, p, int(row[2])
        return

    def _sort_buffer_file(self, file_name, sort_order):
        # sort a spikes cache file
        # Currently we just read "file_name" into a dataframe, sort it, and resave it to the file
        if sort_order == SortOrder.by_time:
            sort_col = 'time'
        elif sort_order == SortOrder.by_id:
            sort_col = 'node'
        else:
            return

        tmp_spikes_ds = pd.read_csv(file_name, sep=' ', names=['time', 'population', 'node'])
        tmp_spikes_ds = tmp_spikes_ds.sort_values(by=sort_col)
        tmp_spikes_ds.to_csv(file_name, sep=' ', index=False, header=False)
class STCSVMPIBuffer(STCSVBuffer):
    """MPI adaptor of STCSVBuffer: every rank streams its spikes to its own csv cache file
    (.bmtk.<name>.cache.node<rank>.csv) and the read methods combine the per-rank files.

    Read methods take an ``on_rank`` option: 'local' (this rank's data only, no communication),
    'all' (every rank scans all cache files), or 'root' (only rank 0 scans them; other ranks get
    None/empty results).
    """
    def __init__(self, cache_dir=None, default_population=None, cache_name='spikes', **kwargs):
        # allow rank/size to be overridden through kwargs (eg for testing)
        self.mpi_rank = kwargs.get('MPI_rank', MPI_rank)
        self.mpi_size = kwargs.get('MPI_size', MPI_size)
        self._cache_name = cache_name
        self._all_ranks_data = {}  # population --> {'n_spikes': int, 'node_ids': set}, filled by _gather()
        super(STCSVMPIBuffer, self).__init__(cache_dir, default_population=default_population, **kwargs)

    def _cache_fname(self, cache_dir):
        """Return this rank's csv cache-file path; rank 0 creates the cache dir first."""
        if self.mpi_rank == 0:
            if not os.path.exists(self._cache_dir):
                os.makedirs(self._cache_dir)
        comm_barrier()  # make sure the directory exists before any rank builds a path into it
        return os.path.join(self._cache_dir, '.bmtk.{}.cache.node{}.csv'.format(self._cache_name, self.mpi_rank))

    def _all_cached_files(self):
        """Paths of every rank's cache file (some may not exist on disk)."""
        # fix: use self.mpi_size (which may be overridden through kwargs) instead of the global
        # MPI_size, consistent with _cache_fname() using self.mpi_rank.
        return [os.path.join(self._cache_dir, '.bmtk.{}.cache.node{}.csv'.format(self._cache_name, r))
                for r in range(self.mpi_size)]

    def _gather(self):
        """(Re)build self._all_ranks_data by scanning every rank's cache file for per-population
        spike counts and node_ids. Caller is responsible for flush()/barriers beforehand."""
        self._all_ranks_data = {}
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                continue
            with open(fn, 'r') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=' ')
                for row in csv_reader:  # row = [timestamp, population, node_id]
                    pop = row[1]
                    if pop not in self._all_ranks_data:
                        self._all_ranks_data[pop] = {'n_spikes': 0, 'node_ids': set()}
                    self._all_ranks_data[pop]['n_spikes'] += 1
                    self._all_ranks_data[pop]['node_ids'].add(int(row[2]))

    def _gather_times(self, node_id, population):
        """Collect the spike timestamps of one node/population by scanning every rank's cache file."""
        timestamps = []
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                continue
            with open(fn, 'r') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=' ')
                for row in csv_reader:
                    pop = row[1]
                    nid = int(row[2])
                    if nid == node_id and pop == population:
                        timestamps.append(float(row[0]))
        return timestamps

    @property
    def populations(self):
        return self.get_populations(on_rank='all')

    def get_populations(self, on_rank='all'):
        """Sorted list of population names; None on non-root ranks when on_rank='root'."""
        if on_rank == 'local':
            pops = super(STCSVMPIBuffer, self).populations
            pops.sort()  # important: populations should be in the same order on all ranks
            return pops

        self.flush()
        comm_barrier()
        pops = None
        if on_rank == 'all':
            self._gather()
            pops = list(self._all_ranks_data.keys())
        elif on_rank == 'root':
            if MPI_rank == 0:
                self._gather()
                pops = list(self._all_ranks_data.keys())
        else:
            # fix: previously an invalid on_rank silently returned None
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if pops is not None:
            pops.sort()
        return pops

    def n_spikes(self, population=None, on_rank='all'):
        """Total spike count of a population; None on non-root ranks when on_rank='root'."""
        if on_rank == 'local':
            return super(STCSVMPIBuffer, self).n_spikes(population=population)

        population = population if population is not None else self._default_population
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            self._gather()
            return self._all_ranks_data[population]['n_spikes'] if population in self._all_ranks_data else 0
        elif on_rank == 'root':
            if MPI_rank == 0:
                self._gather()
                return self._all_ranks_data[population]['n_spikes'] if population in self._all_ranks_data else 0
            else:
                return None
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def node_ids(self, population=None, on_rank='all'):
        """Unique node_ids of a population across ranks; None on non-root ranks when on_rank='root'."""
        if on_rank == 'local':
            return super(STCSVMPIBuffer, self).node_ids(population=population)

        population = population if population is not None else self._default_population
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            self._gather()
            return list(self._all_ranks_data[population]['node_ids']) if population in self._all_ranks_data else []
        elif on_rank == 'root':
            if MPI_rank == 0:
                self._gather()
                return list(self._all_ranks_data[population]['node_ids']) if population in self._all_ranks_data else []
            else:
                return None
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def get_times(self, node_id, population=None, time_window=None, on_rank='all', **kwargs):
        """Spike times of one node across ranks (list); None on non-root ranks when on_rank='root'."""
        if on_rank == 'local':
            # calling super().get_times() will fail since it relies on spikes(); use super().spikes() directly
            return np.array([t[0] for t in super(STCSVMPIBuffer, self).spikes(
                populations=population, time_window=time_window) if t[2] == node_id])

        population = population if population is not None else self._default_population
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            timestamps = self._gather_times(node_id=node_id, population=population)
        elif on_rank == 'root':
            timestamps = self._gather_times(node_id, population) if MPI_rank == 0 else None
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if timestamps is not None and time_window is not None:
            # fix: guard against timestamps=None (non-root ranks with on_rank='root'), which
            # previously raised a TypeError whenever a time_window was given.
            timestamps = [t for t in timestamps if time_window[0] <= t <= time_window[1]]
        return timestamps

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, on_rank='all', **kwargs):
        """Iterate spikes as (timestamp, population, node_id) across ranks' cache files."""
        if on_rank == 'local':
            return super(STCSVMPIBuffer, self).spikes(populations=populations, time_window=time_window,
                                                      sort_order=sort_order, **kwargs)
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            return self._sort_helper(populations, time_window, sort_order)
        elif on_rank == 'root':
            if MPI_rank == 0:
                return self._sort_helper(populations, time_window, sort_order)
            else:
                return []
        else:
            # fix: previously an invalid on_rank silently returned None
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def _sort_helper(self, populations, time_window, sort_order):
        """Sort (if requested) each rank's cache file on disk, then return the merged iterator."""
        filter_fnc = _create_filter(populations, time_window)
        if sort_order == SortOrder.by_time or sort_order == SortOrder.by_id:
            for file_name in self._all_cached_files():
                if not os.path.exists(file_name):
                    continue
                self._sort_buffer_file(file_name, sort_order)
            # column 0 of a row is the timestamp, column 2 the node_id
            return self._sorted_itr(filter_fnc, 0 if sort_order == SortOrder.by_time else 2)
        else:
            return self._unsorted_itr(filter_fnc)

    def _unsorted_itr(self, filter_fnc):
        """Yield spikes file-by-file in whatever order they appear in the cache files."""
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                continue
            with open(fn, 'r') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=' ')
                for row in csv_reader:
                    t = float(row[0])
                    p = row[1]
                    if filter_fnc(p=p, t=t):
                        yield t, p, int(row[2])
        return

    def _sorted_itr(self, filter_fnc, sort_col):
        """Iterates through all the spikes on each rank, returning them in the specified order.

        Assumes every rank's cache file has already been sorted (see _sort_helper). Performs a
        classic k-way merge: push the head row of every file onto a heap keyed by row[sort_col],
        pop the smallest, replace it with the next row from the same file, repeat.
        """
        import heapq

        def next_row(csv_reader, fhandle):
            try:
                rn = next(csv_reader)
                row = [float(rn[0]), rn[1], int(rn[2])]
                return row[sort_col], row, csv_reader, fhandle
            except StopIteration:
                fhandle.close()  # fix: close the cache file once it is fully consumed
                return None
            except ValueError as ie:
                print(ie)
                exit()

        h = []
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                # fix: previously missing cache files raised a FileNotFoundError here, even
                # though _unsorted_itr/_gather skip them.
                continue
            fhandle = open(fn, 'r')
            entry = next_row(csv.reader(fhandle, delimiter=' '), fhandle)
            if entry is not None:  # Some ranks may not have produced any spikes
                heapq.heappush(h, entry)

        while h:
            _, row, csv_reader, fhandle = heapq.heappop(h)
            n = next_row(csv_reader, fhandle)
            if n:
                heapq.heappush(h, n)
            # fix: the filter was being called as filter_fnc(row[1], row[2]), ie with the node_id
            # passed as the timestamp; pass population/time explicitly instead.
            if filter_fnc(p=row[1], t=row[0]):
                yield row
class STCSVMPIBufferV2(STCSVMPIBuffer):
    """Variant of STCSVMPIBuffer whose to_dataframe() reads the per-rank cache csv files directly
    with pandas instead of iterating spikes through the parent implementation."""

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, on_rank='all',
                     **kwargs):
        """Build a DataFrame of spikes; returns None on ranks that hold no data (eg non-root with on_rank='root')."""
        if on_rank == 'local':
            # fix: sort_order was being passed the populations argument (sort_order=populations)
            return super(STCSVMPIBufferV2, self).to_dataframe(populations=populations, sort_order=sort_order,
                                                              with_population_col=with_population_col, **kwargs)

        if np.isscalar(populations):
            populations = [populations]  # so we can use dataframe.isin() later

        ret_df = None
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            cached_files = self._all_cached_files()
        elif on_rank == 'root':
            cached_files = self._all_cached_files() if MPI_rank == 0 else []
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        for file_name in cached_files:
            if not os.path.exists(file_name):
                continue
            df = pd.read_csv(file_name, sep=' ', names=[col_timestamps, col_population, col_node_ids])
            if populations is not None:
                df = df[df[col_population].isin(populations)]
            if not with_population_col:
                # fix: drop() is not in-place; the result was previously discarded so the
                # population column was never actually removed.
                df = df.drop(col_population, axis=1)
            # fix: DataFrame.append was deprecated and removed in pandas 2.0, use pd.concat
            ret_df = df if ret_df is None else pd.concat((ret_df, df))

        if ret_df is not None:
            # pandas doesn't always do a good job of reading in the correct dtype for each column
            ret_df = ret_df.astype({col_timestamps: float, col_node_ids: np.int64})
            if sort_order == SortOrder.by_time:
                ret_df = ret_df.sort_values(col_timestamps)
            elif sort_order == SortOrder.by_id:
                ret_df = ret_df.sort_values(col_node_ids)

        comm_barrier()
        return ret_df