Source code for bmtk.utils.reports.spike_trains.spike_train_buffer

# Copyright 2020. Allen Institute. All rights reserved
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
import os
import numpy as np
import pandas as pd
import csv
import time
from array import array

from .core import SortOrder, pop_na, comm, MPI_size, MPI_rank, comm_barrier
from .core import col_node_ids, col_population, col_timestamps
from .spike_trains_api import SpikeTrainsAPI

def _spikes_filter1(p, t, time_window, populations):
    return p in populations and time_window[0] <= t <= time_window[1]

def _spikes_filter2(p, t, populations):
    return p in populations

def _spikes_filter3(p, t, time_window):
    return time_window[0] <= t <= time_window[1]

def _create_filter(populations, time_window):
    from functools import partial

    if populations is None and time_window is None:
        return lambda p, t: True
    if populations is None:
        return partial(_spikes_filter3, time_window=time_window)

    populations = [populations] if np.isscalar(populations) else populations
    if time_window is None:
        return partial(_spikes_filter2, populations=populations)
        return partial(_spikes_filter1, populations=populations, time_window=time_window)

def _create_empty_df(with_population_col=True):
    """Return a DataFrame with no rows that only carries the spike-train column headers.

    :param with_population_col: when True the population column is placed between the timestamp and node-id columns.
    """
    if with_population_col:
        header = [col_timestamps, col_population, col_node_ids]
    else:
        header = [col_timestamps, col_node_ids]
    return pd.DataFrame(columns=header)

class STMemoryBuffer(SpikeTrainsAPI):
    """A class for creating, storing and reading multi-population spike-trains - especially for saving the spikes of
    a large scale network simulation. Keeps a running tally of the (timestamp, population-name, node_id) for each
    individual spike.

    The spikes are stored in memory, so very large and/or epileptic simulations may run into memory issues. Not
    designed to work with parallel simulations.
    """

    def __init__(self, default_population=None, store_type='array', **kwargs):
        """
        :param default_population: population used whenever a method is called without one; falls back to the
            'population' keyword argument and then to ``pop_na``.
        :param store_type: 'array' or 'list', selects the per-population container (see ``_create_store``).
        :param kwargs: 'units' sets the time units (defaults to 'ms').
        """
        self._default_population = default_population or kwargs.get('population', None) or pop_na
        self._store_type = store_type
        self._units = kwargs.get('units', 'ms')  # for backwards compatibility default to milliseconds
        self._pops = {}  # population-name -> {col_node_ids: [...], col_timestamps: [...]}

    def add_spike(self, node_id, timestamp, population=None, **kwargs):
        """Record a single spike of ``node_id`` at ``timestamp``."""
        population = population or self._default_population
        if population not in self._pops:
            self._create_store(population)
        self._pops[population][col_node_ids].append(node_id)
        self._pops[population][col_timestamps].append(timestamp)

    def add_spikes(self, node_ids, timestamps, population=None, **kwargs):
        """Record many spikes at once.

        :param node_ids: either one node id (applied to every timestamp) or a sequence as long as ``timestamps``.
        :param timestamps: sequence of spike times.
        :raises ValueError: if ``node_ids`` and ``timestamps`` differ in length.
        """
        population = population or self._default_population
        if np.isscalar(node_ids):
            node_ids = [node_ids]*len(timestamps)

        if len(node_ids) != len(timestamps):
            raise ValueError('node_ids and timestamps must by of the same length')

        if population not in self._pops:
            self._create_store(population)
        pop_data = self._pops[population]
        pop_data[col_node_ids].extend(node_ids)
        pop_data[col_timestamps].extend(timestamps)

    def _create_store(self, population):
        """Helper for creating storage data struct of a population, so add_spike/add_spikes is consistent."""
        # Benchmark Notes:
        #   Tested with numpy, lists and arrays. np.concatenate/append is too slow to consider. A regular list is
        #   ~2-3x faster than array, but requires 2-4x the amount of memory. For larger and parallelized applications
        #   (> 100 million spikes) use array since the amount of memory required can exceed the amount available.
        #   But if memory is not an issue use list.
        if self._store_type == 'list':
            self._pops[population] = {col_node_ids: [], col_timestamps: []}
        elif self._store_type == 'array':
            self._pops[population] = {col_node_ids: array('I'), col_timestamps: array('d')}
        else:
            raise AttributeError('Uknown store type {} for SpikeTrains'.format(self._store_type))

    def import_spikes(self, obj, **kwargs):
        pass

    def flush(self):
        pass  # not necessary since everything is stored in memory

    def close(self):
        pass  # don't need to do anything

    @property
    def populations(self):
        """Names of all populations that have at least one stored spike."""
        return list(self._pops.keys())

    def node_ids(self, population=None):
        """Return the unique node ids that spiked in ``population`` (an empty list if the population is unknown)."""
        population = population if population is not None else self._default_population
        if population not in self._pops:
            return []
        return np.unique(self._pops[population][col_node_ids]).astype(np.uint)

    def units(self, population=None):
        return self._units

    def set_units(self, v, population=None):
        self._units = v

    def n_spikes(self, population=None):
        """Return the number of spikes stored for ``population`` (0 if the population is unknown)."""
        population = population if population is not None else self._default_population
        if population not in self._pops:
            return 0
        return len(self._pops[population][col_timestamps])

    def get_times(self, node_id, population=None, time_window=None, **kwargs):
        """Return an array of the spike times of ``node_id``, optionally restricted to an inclusive time window.

        An unknown population yields an empty array, mirroring node_ids() and n_spikes(), instead of a KeyError.
        """
        population = population if population is not None else self._default_population
        if population not in self._pops:
            return np.array([], dtype=np.double)
        pop = self._pops[population]

        # filter by node_id and (if specified) by time.
        mask = np.array(pop[col_node_ids]) == node_id
        ts = np.array(pop[col_timestamps])
        if time_window:
            mask &= (time_window[0] <= ts) & (ts <= time_window[1])

        return ts[mask]

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, **kwargs):
        """Return the stored spikes as a DataFrame.

        :param populations: one name, a collection of names, or None for every population.
        :param sort_order: sorting applied within each population (populations themselves are not interleaved).
        :param with_population_col: add a column holding the population name.
        :return: a DataFrame; empty (headers only) when no population was selected.
        """
        if populations is None:
            selected_pops = list(self.populations)
        elif np.isscalar(populations):
            selected_pops = [populations]
        else:
            selected_pops = populations

        pop_dfs = []
        for pop_name in selected_pops:
            pop_data = self._pops.get(pop_name, {col_node_ids: [], col_timestamps: []})
            pop_df = pd.DataFrame({
                col_node_ids: pop_data[col_node_ids],
                col_timestamps: pop_data[col_timestamps]
            })
            if with_population_col:
                pop_df[col_population] = pop_name

            if sort_order == SortOrder.by_id:
                pop_df = pop_df.sort_values(col_node_ids)
            elif sort_order == SortOrder.by_time:
                pop_df = pop_df.sort_values(col_timestamps)
            pop_dfs.append(pop_df)

        if not pop_dfs:
            return _create_empty_df(with_population_col)
        # Concatenate once at the end rather than re-copying the accumulated frame for every population.
        return pop_dfs[0] if len(pop_dfs) == 1 else pd.concat(pop_dfs)

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, **kwargs):
        """Iterate over the stored spikes, yielding ``(timestamp, population, node_id)`` tuples."""
        if populations is None:
            populations = self.populations
        elif np.isscalar(populations):
            populations = [populations]

        # Only the requested populations are iterated below, so the filter just has to check the time window.
        spike_filter = _create_filter(None, time_window)
        for pop_name in populations:
            pop_data = self._pops.get(pop_name, {col_node_ids: [], col_timestamps: []})
            timestamps = pop_data[col_timestamps]
            node_ids = pop_data[col_node_ids]

            if sort_order == SortOrder.by_id:
                sort_indx = np.argsort(node_ids)
            elif sort_order == SortOrder.by_time:
                sort_indx = np.argsort(timestamps)
            else:
                sort_indx = range(len(timestamps))

            for i in sort_indx:
                t = timestamps[i]
                if spike_filter(p=pop_name, t=t):
                    yield t, pop_name, node_ids[i]

    def __len__(self):
        return len(self.to_dataframe())
class STMPIBuffer(STMemoryBuffer):
    """In-memory spike buffer whose read methods can combine the spikes held by every MPI rank.

    Most read methods take an ``on_rank`` option: 'local' (this rank's spikes only), 'all' (combined result on
    every rank) or 'root' (combined result on rank 0, None elsewhere). The 'all' and 'root' modes use collective
    MPI calls and therefore must be invoked on every rank.
    """

    def __init__(self, default_population=None, store_type='array', **kwargs):
        self.mpi_rank = kwargs.get('MPI_rank', MPI_rank)
        self.mpi_size = kwargs.get('MPI_size', MPI_size)
        super(STMPIBuffer, self).__init__(default_population=default_population, store_type=store_type, **kwargs)

    def _gatherv(self, population, on_all_ranks=True):
        """Gather the node ids and timestamps of ``population`` from every rank.

        :param on_all_ranks: True to deliver the result to every rank, False to deliver it to rank 0 only.
        :return: ``(node_ids, timestamps)`` arrays; ``(None, None)`` on non-root ranks when ``on_all_ranks`` is False.
        """
        from mpi4py import MPI

        local_n_spikes = super(STMPIBuffer, self).n_spikes(population)
        sizes = comm.allgather(local_n_spikes)
        offsets = np.zeros(len(sizes), dtype=np.int64)
        offsets[1:] = np.cumsum(sizes)[:-1]
        all_n_spikes = np.sum(sizes)

        # The population may not have produced any spikes on this rank
        local_population = self._pops.get(population, {col_node_ids: [], col_timestamps: []})

        def _collect(local_values, np_dtype, mpi_dtype):
            send_buf = np.array(local_values, dtype=np_dtype)
            recv_buf = np.zeros(all_n_spikes, dtype=np_dtype)
            if on_all_ranks:
                comm.Allgatherv(send_buf, [recv_buf, sizes, offsets, mpi_dtype])
                return recv_buf
            comm.Gatherv(send_buf, [recv_buf, sizes, offsets, mpi_dtype], root=0)
            return recv_buf if MPI_rank == 0 else None

        all_node_ids = _collect(local_population[col_node_ids], np.uint64, MPI.UINT64_T)
        all_timestamps = _collect(local_population[col_timestamps], np.double, MPI.DOUBLE)
        return all_node_ids, all_timestamps

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, on_rank='all',
                     **kwargs):
        """Return spikes as a DataFrame, gathered according to ``on_rank``.

        :return: a DataFrame. With on_rank='root' every rank other than 0 gets None.
        :raises ValueError: if ``on_rank`` is not 'local', 'all' or 'root'.
        """
        comm_barrier()
        if on_rank == 'local':
            return super(STMPIBuffer, self).to_dataframe(populations=populations, sort_order=sort_order,
                                                         with_population_col=with_population_col, **kwargs)

        if on_rank not in ['local', 'all', 'root']:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if populations is None:
            selected_pops = self.get_populations(on_rank='all')
        elif np.isscalar(populations):
            selected_pops = [populations]
        else:
            selected_pops = populations

        # Make sure the list of populations is exactly the same (including order) across all ranks so that
        # _gatherv is called in the same sequence across all ranks. sorted() copies, so the caller's collection is
        # neither mutated nor required to be a list.
        selected_pops = sorted(selected_pops)

        gather_on_all = on_rank == 'all'
        pop_dfs = []
        for pop_name in selected_pops:
            node_ids, timestamps = self._gatherv(population=pop_name, on_all_ranks=gather_on_all)
            if not gather_on_all and MPI_rank != 0:
                # on_rank='root': only rank 0 received data, but every rank had to take part in the gather.
                continue

            pop_df = pd.DataFrame({
                col_node_ids: node_ids,
                col_timestamps: timestamps
            })
            if with_population_col:
                pop_df[col_population] = pop_name

            if sort_order == SortOrder.by_id:
                pop_df = pop_df.sort_values(col_node_ids)
            elif sort_order == SortOrder.by_time:
                pop_df = pop_df.sort_values(col_timestamps)
            pop_dfs.append(pop_df)

        comm_barrier()

        if pop_dfs:
            return pop_dfs[0] if len(pop_dfs) == 1 else pd.concat(pop_dfs)
        if on_rank == 'all' or MPI_rank == 0:
            # If using 'all' or on rank 0 a dataframe is expected even if there are no spikes
            return _create_empty_df(with_population_col)
        return None

    @property
    def populations(self):
        return self.get_populations(on_rank='all')

    def get_populations(self, on_rank='all'):
        """Return the sorted population names, combined across ranks unless ``on_rank`` is 'local'."""
        local_pops = list(super(STMPIBuffer, self).populations)
        if on_rank == 'local':
            return local_pops

        if on_rank == 'all':
            gathered_pops = comm.allgather(local_pops)
        elif on_rank == 'root':
            gathered_pops = comm.gather(local_pops, 0)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if gathered_pops is None:
            return None

        all_populations = set()
        for pops in gathered_pops:
            all_populations |= set(pops)

        # WARNING: For a number of parallel applications it's important that the list of populations returned
        #  is the same across all ranks (eg ranks don't iterate through each population in different sequences)
        return sorted(all_populations)

    def node_ids(self, population=None, on_rank='all'):
        """Return the unique node ids that spiked in ``population``, combined according to ``on_rank``."""
        local_node_ids = super(STMPIBuffer, self).node_ids(population)
        if on_rank == 'local':
            return local_node_ids

        if on_rank == 'all':
            gathered_nodes = comm.allgather(local_node_ids)
        elif on_rank == 'root':
            gathered_nodes = comm.gather(local_node_ids, 0)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if gathered_nodes is None:
            return None
        return np.unique(np.concatenate(gathered_nodes)).astype(np.uint)

    def n_spikes(self, population=None, on_rank='all'):
        """Return the number of spikes in ``population``, summed according to ``on_rank``."""
        from mpi4py import MPI

        local_n = super(STMPIBuffer, self).n_spikes(population)
        if on_rank == 'local':
            return local_n
        elif on_rank == 'all':
            return comm.allreduce(local_n, MPI.SUM)
        elif on_rank == 'root':
            return comm.reduce(local_n, MPI.SUM)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def get_times(self, node_id, population=None, time_window=None, on_rank='all', **kwargs):
        """Return the spike times of ``node_id``; sorted when combined across ranks."""
        local_pop = population if population is not None else self._default_population
        if local_pop in self._pops:
            local_times = super(STMPIBuffer, self).get_times(node_id=node_id, population=population,
                                                             time_window=time_window, **kwargs)
        else:
            # This rank holds no spikes for the population. It must still reach the collective call below, so
            # contribute an empty array instead of failing on the missing key.
            local_times = np.array([], dtype=np.double)

        if on_rank == 'local':
            return local_times
        elif on_rank == 'all':
            all_times = comm.allgather(local_times)
        elif on_rank == 'root':
            all_times = comm.gather(local_times, 0)
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        if all_times is None:
            return None
        return np.sort(np.concatenate(all_times))

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, on_rank='all', **kwargs):
        """Iterate over spikes, yielding ``(timestamp, population, node_id)`` tuples.

        With on_rank='local' only this rank's spikes are produced. Otherwise the spikes of each population are
        gathered first; ranks that receive no data yield nothing.
        """
        if on_rank == 'local':
            for spike in super(STMPIBuffer, self).spikes(populations=populations, time_window=time_window,
                                                         sort_order=sort_order, **kwargs):
                yield spike
            return

        if populations is None:
            populations = self.populations
        elif np.isscalar(populations):
            populations = [populations]
        # Same population order on every rank; sorted() leaves the caller's collection untouched.
        populations = sorted(populations)

        # Only the requested populations are iterated below, so the filter just has to check the time window.
        spike_filter = _create_filter(None, time_window)
        for pop_name in populations:
            node_ids, timestamps = self._gatherv(pop_name, on_all_ranks=(on_rank == 'all'))
            if node_ids is None:
                continue

            if sort_order == SortOrder.by_id:
                sort_indx = np.argsort(node_ids)
            elif sort_order == SortOrder.by_time:
                sort_indx = np.argsort(timestamps)
            else:
                sort_indx = range(len(timestamps))

            for i in sort_indx:
                t = timestamps[i]
                if spike_filter(p=pop_name, t=t):
                    yield t, pop_name, node_ids[i]
class STCSVBuffer(SpikeTrainsAPI):
    """A class for creating, storing and reading multi-population spike-trains - especially for saving the spikes of
    a large scale network simulation. Keeps a running tally of the (timestamp, population-name, node_id) for each
    individual spike.

    Uses a caching mechanism to periodically save spikes to the disk. Will incur a runtime performance penalty but
    will always have an upper bound on the maximum memory used.

    If running parallel simulations should use the STMPIBuffer adaptor instead.
    """

    def __init__(self, cache_dir=None, default_population=None, cache_name='spikes', **kwargs):
        """
        :param cache_dir: directory of the cache file (created if missing); defaults to the current directory.
        :param default_population: population used whenever a method is called without one (``pop_na`` if None).
        :param cache_name: name fragment used in the cache file name.
        :param kwargs: 'units' sets the time units (defaults to 'ms').
        """
        self._default_population = default_population or pop_na

        # Keep a file handle open for writing spike information
        self._cache_dir = cache_dir or '.'
        self._cache_name = cache_name
        self._buffer_filename = self._cache_fname(self._cache_dir)
        self._buffer_handle = open(self._buffer_filename, 'w')

        self._units = kwargs.get('units', 'ms')
        self._pop_metadata = {}  # population-name -> {'node_ids': set(), 'n_spikes': int}
        # all spikes added on rank, for each individual pop spike count stored in _pop_metadata
        self._spike_counts = 0

    def _cache_fname(self, cache_dir):
        """Return the path of the cache file, creating ``cache_dir`` when it does not exist."""
        # TODO: Potential problem if multiple SpikeTrains are opened at the same time, add salt to prevent collisions
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
        return os.path.join(cache_dir, '.bmtk.{}.cache.csv'.format(self._cache_name))

    def add_spike(self, node_id, timestamp, population=None, **kwargs):
        """Append a single spike to the cache file and update the per-population bookkeeping."""
        population = population or self._default_population

        # NOTE: I looked into using a in-memory buffer to save data and caching only when they reached a threshold,
        #  however on my computer it was actually slower than just calling file.write() each time. Likely the python
        #  file writer is more efficient than what I could write. However still would like to benchmark on a NSF.
        self._buffer_handle.write('{} {} {}\n'.format(timestamp, population, node_id))

        if population not in self._pop_metadata:
            self._pop_metadata[population] = {'node_ids': set(), 'n_spikes': 0}
        self._pop_metadata[population]['node_ids'].add(node_id)
        self._pop_metadata[population]['n_spikes'] += 1
        self._spike_counts += 1

    def add_spikes(self, node_ids, timestamps, population=None, **kwargs):
        """Append many spikes; ``node_ids`` is either one id for every timestamp or a parallel sequence."""
        if np.isscalar(node_ids):
            for ts in timestamps:
                self.add_spike(node_ids, ts, population)
        else:
            for node_id, ts in zip(node_ids, timestamps):
                self.add_spike(node_id, ts, population)

    @property
    def populations(self):
        return list(self._pop_metadata.keys())

    def units(self, population=None):
        return self._units

    def set_units(self, u, population=None):
        self._units = u

    def node_ids(self, population=None):
        """Return the node ids that spiked in ``population`` (an empty list if the population is unknown)."""
        population = population if population is not None else self._default_population
        if population not in self._pop_metadata:
            return []
        return list(self._pop_metadata[population]['node_ids'])

    def n_spikes(self, population=None):
        """Return the number of spikes added for ``population`` (0 if the population is unknown)."""
        population = population if population is not None else self._default_population
        if population not in self._pop_metadata:
            return 0
        return self._pop_metadata[population]['n_spikes']

    def time_range(self, populations=None):
        return None  # TODO: keep track of largest and smallest values

    def get_times(self, node_id, population=None, time_window=None, **kwargs):
        """Return an array of the spike times of ``node_id``, optionally restricted to an inclusive time window."""
        self.flush()
        population = population if population is not None else self._default_population
        return np.array([t[0] for t in self.spikes(populations=population, time_window=time_window)
                         if t[2] == node_id])

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, **kwargs):
        """Read the cache file into a DataFrame.

        Rows are always grouped by population and, within a population, ordered according to ``sort_order``.

        :return: a DataFrame; empty (headers only) when no spike has been written yet.
        """
        self.flush()

        sorting_cols = [col_population]
        if sort_order == SortOrder.by_time:
            sorting_cols = [col_population, col_timestamps]
        elif sort_order == SortOrder.by_id:
            sorting_cols = [col_population, col_node_ids]

        try:
            ret_df = pd.read_csv(self._buffer_filename, sep=' ', names=[col_timestamps, col_population, col_node_ids])
        except pd.errors.EmptyDataError:
            # The cache file exists from construction but stays empty until the first spike is added.
            return _create_empty_df(with_population_col)
        ret_df = ret_df.sort_values(sorting_cols)

        # filter by population
        if np.isscalar(populations):
            ret_df = ret_df[ret_df[col_population] == populations]
        elif populations is not None:
            ret_df = ret_df[ret_df[col_population].isin(populations)]

        if not with_population_col:
            ret_df = ret_df.drop(col_population, axis=1)

        ret_df = ret_df.astype({col_timestamps: float, col_node_ids: np.int64})
        return ret_df

    def flush(self):
        """Flush pending writes and wait (up to 10 checks, 0.5 s apart) until the file holds every added spike."""
        self._buffer_handle.flush()

        # Found an issue with even after flushing the csv there can be a lag before data is actually cached to the
        # disk. This can cause problems when a process on a different rank tries to open a file that hasn't been
        # completely saved. This hack should hopefully ensure that each rank has fully cached their spikes to disk.
        for _ in range(10):
            with open(self._buffer_filename) as fh:
                # Count lines lazily instead of loading the whole cache file into memory.
                fcount = sum(1 for _line in fh)
            if fcount == self._spike_counts:
                break
            time.sleep(0.5)
        else:
            print('Warning: spike counts on rank {} cache does not match total added.'.format(MPI_rank))

    def close(self):
        """Close the write handle and delete the cache file."""
        self._buffer_handle.close()
        if os.path.exists(self._buffer_filename):
            os.remove(self._buffer_filename)

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, **kwargs):
        """Iterate over the cached spikes, yielding ``(timestamp, population, node_id)`` tuples.

        When a sort order is requested the cache file itself is re-written in that order before being read.
        """
        self.flush()
        self._sort_buffer_file(self._buffer_filename, sort_order)
        filter_fnc = _create_filter(populations, time_window)
        with open(self._buffer_filename, 'r') as csvfile:
            csv_reader = csv.reader(csvfile, delimiter=' ')
            for row in csv_reader:
                t = float(row[0])
                p = row[1]
                if filter_fnc(p=p, t=t):
                    yield t, p, int(row[2])

    def _sort_buffer_file(self, file_name, sort_order):
        """Sort a spikes cache file in place; does nothing for an unsorted order or an empty file."""
        # Currently we just read "file_name" into a dataframe, sort it, and resave it to the file
        if sort_order == SortOrder.by_time:
            sort_col = 'time'
        elif sort_order == SortOrder.by_id:
            sort_col = 'node'
        else:
            return

        try:
            tmp_spikes_ds = pd.read_csv(file_name, sep=' ', names=['time', 'population', 'node'])
        except pd.errors.EmptyDataError:
            return  # no spikes written yet, nothing to sort
        tmp_spikes_ds = tmp_spikes_ds.sort_values(by=sort_col)
        tmp_spikes_ds.to_csv(file_name, sep=' ', index=False, header=False)
class STCSVMPIBuffer(STCSVBuffer):
    """CSV-cached spike buffer for MPI runs: every rank writes its own cache file inside a shared cache directory.

    Most read methods take an ``on_rank`` option: 'local' (this rank's spikes only), 'all' (every rank reads all
    cache files) or 'root' (only rank 0 reads them, other ranks get None or an empty result).
    """

    def __init__(self, cache_dir=None, default_population=None, cache_name='spikes', **kwargs):
        self.mpi_rank = kwargs.get('MPI_rank', MPI_rank)
        self.mpi_size = kwargs.get('MPI_size', MPI_size)
        self._cache_name = cache_name
        self._all_ranks_data = {}
        # cache_name has to be forwarded explicitly, otherwise the parent resets it to its own default.
        super(STCSVMPIBuffer, self).__init__(cache_dir, default_population=default_population,
                                             cache_name=cache_name, **kwargs)

    def _cache_fname(self, cache_dir):
        """Return this rank's cache file path; rank 0 creates the directory before the ranks synchronize."""
        if self.mpi_rank == 0:
            if not os.path.exists(self._cache_dir):
                os.makedirs(self._cache_dir)
        comm_barrier()

        return os.path.join(self._cache_dir, '.bmtk.{}.cache.node{}.csv'.format(self._cache_name, self.mpi_rank))

    def _all_cached_files(self):
        """Return the cache file paths of every rank (the files may not all exist)."""
        # Uses self.mpi_size so the listing agrees with _cache_fname(), which names files after self.mpi_rank.
        return [os.path.join(self._cache_dir, '.bmtk.{}.cache.node{}.csv'.format(self._cache_name, r))
                for r in range(self.mpi_size)]

    def _gather(self):
        """Re-read every rank's cache file and rebuild the per-population spike counts and node-id sets."""
        self._all_ranks_data = {}
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                continue

            with open(fn, 'r') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=' ')
                for row in csv_reader:
                    pop = row[1]
                    if pop not in self._all_ranks_data:
                        self._all_ranks_data[pop] = {'n_spikes': 0, 'node_ids': set()}
                    self._all_ranks_data[pop]['n_spikes'] += 1
                    self._all_ranks_data[pop]['node_ids'].add(int(row[2]))

    def _gather_times(self, node_id, population):
        """Return the list of spike times of ``node_id`` in ``population`` found in any rank's cache file."""
        timestamps = []
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                continue

            with open(fn, 'r') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=' ')
                for row in csv_reader:
                    pop = row[1]
                    nid = int(row[2])
                    if nid == node_id and pop == population:
                        timestamps.append(float(row[0]))

        return timestamps

    @property
    def populations(self):
        return self.get_populations(on_rank='all')

    def get_populations(self, on_rank='all'):
        """Return the sorted population names; None on non-root ranks for 'root' or for an unknown ``on_rank``."""
        if on_rank == 'local':
            pops = super(STCSVMPIBuffer, self).populations
            pops.sort()  # important: populations are in the same order on all ranks
            return pops

        self.flush()
        comm_barrier()
        pops = None
        if on_rank == 'all':
            self._gather()
            pops = list(self._all_ranks_data.keys())
        elif on_rank == 'root':
            if MPI_rank == 0:
                self._gather()
                pops = list(self._all_ranks_data.keys())

        if pops is not None:
            pops.sort()

        return pops

    def n_spikes(self, population=None, on_rank='all'):
        """Return the number of spikes in ``population``; None on non-root ranks when on_rank='root'."""
        if on_rank == 'local':
            return super(STCSVMPIBuffer, self).n_spikes(population=population)

        population = population if population is not None else self._default_population
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            self._gather()
            return self._all_ranks_data[population]['n_spikes'] if population in self._all_ranks_data else 0
        elif on_rank == 'root':
            if MPI_rank == 0:
                self._gather()
                return self._all_ranks_data[population]['n_spikes'] if population in self._all_ranks_data else 0
            else:
                return None
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def node_ids(self, population=None, on_rank='all'):
        """Return the node ids that spiked in ``population``; None on non-root ranks when on_rank='root'."""
        if on_rank == 'local':
            return super(STCSVMPIBuffer, self).node_ids(population=population)

        population = population if population is not None else self._default_population
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            self._gather()
            return list(self._all_ranks_data[population]['node_ids']) if population in self._all_ranks_data else []
        elif on_rank == 'root':
            if MPI_rank == 0:
                self._gather()
                return list(self._all_ranks_data[population]['node_ids']) if population in self._all_ranks_data \
                    else []
            else:
                return None
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def get_times(self, node_id, population=None, time_window=None, on_rank='all', **kwargs):
        """Return the spike times of ``node_id``.

        :return: an array for on_rank='local', otherwise a list; None on non-root ranks when on_rank='root'.
        :raises ValueError: if ``on_rank`` is not 'local', 'all' or 'root'.
        """
        if on_rank == 'local':
            # calling super.get_times() will fail since it relies on spikes()
            return np.array([t[0] for t in super(STCSVMPIBuffer, self).spikes(
                populations=population, time_window=time_window) if t[2] == node_id])

        population = population if population is not None else self._default_population
        self.flush()
        comm_barrier()
        if on_rank == 'all':
            timestamps = self._gather_times(node_id=node_id, population=population)
        elif on_rank == 'root':
            timestamps = self._gather_times(node_id, population) if MPI_rank == 0 else None
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        # timestamps is None on non-root ranks for on_rank='root'; there is nothing to filter in that case.
        if timestamps is not None and time_window is not None:
            timestamps = [t for t in timestamps if time_window[0] <= t <= time_window[1]]

        return timestamps

    def spikes(self, populations=None, time_window=None, sort_order=SortOrder.none, on_rank='all', **kwargs):
        """Return an iterable of spikes read from the cache files.

        :return: an iterator over the selected spikes; an empty list on non-root ranks when on_rank='root'.
        :raises ValueError: if ``on_rank`` is not 'local', 'all' or 'root'.
        """
        if on_rank == 'local':
            return super(STCSVMPIBuffer, self).spikes(populations=populations, time_window=time_window,
                                                      sort_order=sort_order, **kwargs)

        self.flush()
        comm_barrier()
        if on_rank == 'all':
            return self._sort_helper(populations, time_window, sort_order)
        elif on_rank == 'root':
            if MPI_rank == 0:
                return self._sort_helper(populations, time_window, sort_order)
            else:
                return []
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

    def _sort_helper(self, populations, time_window, sort_order):
        """Return an iterator over the spikes of all cache files, merged in ``sort_order`` when one is requested."""
        filter_fnc = _create_filter(populations, time_window)
        if sort_order == SortOrder.by_time or sort_order == SortOrder.by_id:
            # NOTE(review): with on_rank='all' every rank re-writes every cache file here, apparently without any
            #  coordination between ranks - confirm this cannot race.
            for file_name in self._all_cached_files():
                # Ranks that produced no spikes leave a missing or empty file; there is nothing to sort in either.
                if not os.path.exists(file_name) or os.path.getsize(file_name) == 0:
                    continue
                self._sort_buffer_file(file_name, sort_order)

            return self._sorted_itr(filter_fnc, 0 if sort_order == SortOrder.by_time else 2)
        else:
            return self._unsorted_itr(filter_fnc)

    def _unsorted_itr(self, filter_fnc):
        """Yield ``(timestamp, population, node_id)`` tuples file by file, in the order they were written."""
        for fn in self._all_cached_files():
            if not os.path.exists(fn):
                continue

            with open(fn, 'r') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=' ')
                for row in csv_reader:
                    t = float(row[0])
                    p = row[1]
                    if filter_fnc(p=p, t=t):
                        yield t, p, int(row[2])

    def _sorted_itr(self, filter_fnc, sort_col):
        """Iterates through all the spikes on each rank, returning them in the specified order.

        :param filter_fnc: predicate called with the population (``p``) and timestamp (``t``) of each spike.
        :param sort_col: index of the merge key within a row: 0 for timestamp, 2 for node id.
        :return: yields ``[timestamp, population, node_id]`` lists.
        :raises ValueError: if a row holds a timestamp or node id that cannot be parsed.
        """
        import heapq

        def next_entry(file_indx, csv_reader):
            try:
                rn = next(csv_reader)
            except StopIteration:
                return None
            row = [float(rn[0]), rn[1], int(rn[2])]
            # file_indx breaks ties between identical rows of different files, so the heap never has to compare
            # csv reader objects (which are not orderable).
            return row[sort_col], row, file_indx, csv_reader

        # Assumes all the ranked cached files have already been sorted. Pop the top row off of each rank onto the
        # heap, pull next spike off the heap and replace. Repeat until all spikes on all ranks have been popped.
        handles = []
        try:
            heap = []
            for file_indx, fn in enumerate(self._all_cached_files()):
                if not os.path.exists(fn):
                    continue
                fh = open(fn, 'r')
                handles.append(fh)
                entry = next_entry(file_indx, csv.reader(fh, delimiter=' '))
                if entry is not None:  # Some ranks may not have produced any spikes
                    heapq.heappush(heap, entry)

            while heap:
                _, row, file_indx, csv_reader = heapq.heappop(heap)
                entry = next_entry(file_indx, csv_reader)
                if entry is not None:
                    heapq.heappush(heap, entry)

                # row is [timestamp, population, node_id]; the filter expects the population and the timestamp.
                if filter_fnc(p=row[1], t=row[0]):
                    yield row
        finally:
            for fh in handles:
                fh.close()
class STCSVMPIBufferV2(STCSVMPIBuffer):
    """STCSVMPIBuffer variant whose to_dataframe() can combine the cache files of every rank."""

    def to_dataframe(self, populations=None, sort_order=SortOrder.none, with_population_col=True, on_rank='all',
                     **kwargs):
        """Return spikes as a DataFrame.

        :param populations: one name, a collection of names, or None for every population.
        :param sort_order: SortOrder.by_time or SortOrder.by_id sorts the combined table; otherwise file order.
        :param with_population_col: keep the population column in the result.
        :param on_rank: 'local' (this rank's file only), 'all' (every rank reads all files) or 'root' (only rank 0
            reads the files).
        :return: a DataFrame, or None when no cache file with data was read (e.g. non-root ranks for 'root').
        :raises ValueError: if ``on_rank`` is not 'local', 'all' or 'root'.
        """
        if on_rank == 'local':
            return super(STCSVMPIBufferV2, self).to_dataframe(populations=populations, sort_order=sort_order,
                                                              with_population_col=with_population_col, **kwargs)

        if np.isscalar(populations):
            populations = [populations]  # so we can use dataframe.isin() later

        self.flush()
        comm_barrier()

        if on_rank == 'all':
            cached_files = self._all_cached_files()
        elif on_rank == 'root':
            cached_files = self._all_cached_files() if MPI_rank == 0 else []
        else:
            raise ValueError('Invalid option "{}" for mpi on_rank parameter'.format(on_rank))

        rank_dfs = []
        for file_name in cached_files:
            if not os.path.exists(file_name):
                continue

            try:
                df = pd.read_csv(file_name, sep=' ', names=[col_timestamps, col_population, col_node_ids])
            except pd.errors.EmptyDataError:
                continue  # this rank did not write any spikes

            if populations is not None:
                df = df[df[col_population].isin(populations)]

            if not with_population_col:
                # drop() returns a new frame, so the result has to be kept
                df = df.drop(col_population, axis=1)

            rank_dfs.append(df)

        ret_df = None
        if rank_dfs:
            ret_df = rank_dfs[0] if len(rank_dfs) == 1 else pd.concat(rank_dfs)

            # pandas doesn't always do a good job of reading in the correct dtype for each column
            ret_df = ret_df.astype({col_timestamps: float, col_node_ids: np.int64})
            if sort_order == SortOrder.by_time:
                ret_df = ret_df.sort_values(col_timestamps)
            elif sort_order == SortOrder.by_id:
                ret_df = ret_df.sort_values(col_node_ids)

        comm_barrier()
        return ret_df