Source code for bmtk.utils.reports.spike_trains.spike_trains

# Copyright 2020. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
from six import string_types

from .core import find_file_type, MPI_size
from .spike_train_readers import load_sonata_file, CSVSTReader, NWBSTReader
from .spike_train_buffer import STMemoryBuffer, STCSVBuffer, STMPIBuffer, STCSVMPIBufferV2
from bmtk.utils.sonata.utils import get_node_ids
from scipy.stats import gamma
import warnings

[docs]class SpikeTrains(object): """A class for creating and reading spike files. """ def __init__(self, spikes_adaptor=None, **kwargs): # There are a number of strategies for reading and writing spike trains, depending on memory limitations, if # MPI is being used, or if there if read-only from disk. I'm using a decorator/adaptor pattern and moving # the actual functionality to the Buffered/ReadOnly classes that implement SpikeTrainsAPI. if spikes_adaptor is not None: self.adaptor = spikes_adaptor else: # TODO: Check that comm has gather, reduce, etc methods; if not can't use STMPIBuffer use_mpi = MPI_size > 1 cache_to_disk = 'cache_dir' in kwargs and kwargs.get('cache_to_disk', True) if use_mpi and cache_to_disk: self.adaptor = STCSVMPIBufferV2(**kwargs) elif cache_to_disk: self.adaptor = STCSVBuffer(**kwargs) elif use_mpi: self.adaptor = STMPIBuffer(**kwargs) else: self.adaptor = STMemoryBuffer(**kwargs)
[docs] @classmethod def from_csv(cls, path, **kwargs): return cls(spikes_adaptor=CSVSTReader(path, **kwargs))
[docs] @classmethod def from_sonata(cls, path, **kwargs): sonata_adaptor = load_sonata_file(path, **kwargs) return cls(spikes_adaptor=sonata_adaptor)
[docs] @classmethod def from_nwb(cls, path, **kwargs): return cls(spikes_adaptor=NWBSTReader(path, **kwargs))
[docs] @classmethod def load(cls, path, file_type=None, **kwargs): file_type = file_type.lower() if file_type else find_file_type(path) if file_type == 'h5' or file_type == 'sonata': return cls.from_sonata(path, **kwargs) elif file_type == 'nwb': return cls.from_nwb(path, **kwargs) elif file_type == 'csv': return cls.from_csv(path, **kwargs)
def __getattr__(self, item): return getattr(self.adaptor, item) def __setattr__(self, key, value): if key == 'adaptor': self.__dict__[key] = value else: self.adaptor.__dict__[key] = value def __len__(self): return self.adaptor.__len__() def __eq__(self, other): return self.adaptor.__eq__(other) def __lt__(self, other): return self.adaptor.__lt__(other) def __le__(self, other): return self.adaptor.__le__(other) def __gt__(self, other): return self.adaptor.__gt__(other) def __ge__(self, other): return self.adaptor.__ge__(other) def __ne__(self, other): return self.adaptor.__ne__(other)
[docs]class SpikeGenerator(SpikeTrains): def __init__(self, population=None, seed=None, output_units='ms', **kwargs): max_spikes_per_node = 10000000 if population is not None and 'default_population' not in kwargs: kwargs['default_population']= population if seed: np.random.seed(seed) super(SpikeGenerator, self).__init__(units=output_units, **kwargs) # self.units = units if output_units.lower() in ['ms', 'millisecond', 'milliseconds']: self._units = 'ms' self.output_conversion = 1000.0 elif output_units.lower() in ['s', 'second', 'seconds']: self._units = 's' self.output_conversion = 1.0 else: raise AttributeError('Unknown output_units value {}'.format(output_units))
[docs]class PoissonSpikeGenerator(SpikeGenerator): """ A Class for generating spike-trains with a homogeneous and inhomogeneous Poisson distribution. Uses the methods describe in Dayan and Abbott, 2001. """ def __init__(self, population=None, seed=None, output_units='ms', **kwargs): super(PoissonSpikeGenerator, self).__init__(population, seed, output_units, **kwargs)
[docs] def add(self, node_ids, firing_rate, population=None, times=(0.0, 1.0), abs_ref=0, tau_ref=0): """ :param firing_rate: Scalar stationary firing rate or array of values for inhomogeneous (Hz) :param times: Start and end time for spike train (s) :param abs_ref: Absolute refractory period (s) :param tau_ref: Relative refractory period time constant for exponential recovery (s) """ if tau_ref < 0: raise ValueError('Refractory period time constant (sec) cannot be negative.') if abs_ref < 0: raise ValueError('Absolute refractory period (sec) cannot be negative.') if isinstance(node_ids, string_types): # if user passes in path to nodes.h5 file count number of nodes node_ids = get_node_ids(node_ids, population) if np.isscalar(node_ids): # In case user passes in single node_id node_ids = [node_ids] if np.isscalar(firing_rate): self._build_fixed_fr(node_ids, population, firing_rate, times, abs_ref, tau_ref) elif isinstance(firing_rate, (list, np.ndarray)): self._build_inhomogeneous_fr(node_ids, population, firing_rate, times, abs_ref, tau_ref)
[docs] def time_range(self, population=None): df = self.to_dataframe(populations=population, with_population_col=False) timestamps = df['timestamps'] return np.min(timestamps), np.max(timestamps)
def _build_fixed_fr(self, node_ids, population, fr, times, abs_ref, tau_ref): if np.isscalar(times) and times > 0.0: tstart = 0.0 tstop = times else: tstart = times[0] tstop = times[-1] if tstart >= tstop: raise ValueError('Invalid start and stop times.') if fr < 0: raise ValueError('Firing rates must not be negative.') # If there are refractory properties, correct starting firing rate and check firing limit if abs_ref!=0 or tau_ref!=0: max_fr_lim, p = fr_corr(abs_ref, tau_ref) if fr > max_fr_lim: raise ValueError(f'Cannot achieve firing rate above {max_fr_lim} with these absolute' f' and relative refractory properties. Also consider using' f' the GammaSpikeGenerator instead.') fr = p(fr) #rs2 = np.random.RandomState(0) count = 0 for node_id in node_ids: c_time = tstart while True: interval = -np.log(1.0 - np.random.uniform()) / fr preceding_time = c_time c_time += interval if c_time > tstop: break if tau_ref != 0: w = 1 - np.exp(-(interval-abs_ref)/tau_ref) else: w = 1 # To avoid divide by zero warning if abs_ref != 0: w = w*(interval>abs_ref) #if (w == 1) or (rs2.uniform() < w): if (w == 1) or (np.random.uniform() < w): self.add_spike(node_id=node_id, population=population, timestamp=c_time*self.output_conversion) count = count+1 def _build_inhomogeneous_fr(self, node_ids, population, fr, times, abs_ref, tau_ref): if np.min(fr) < 0: raise ValueError('Firing rates must not be negative') if len(fr) != len(times): raise ValueError('If using a time series for firing rate, times must be an array of equal length') max_fr = np.max(fr) max_fr_lim, p = fr_corr(abs_ref, tau_ref) if max_fr > max_fr_lim: raise ValueError(f'Cannot achieve firing rate above {max_fr_lim} with these absolute and relative refractory properties') fr = p(fr) times = times tstart = times[0] tstop = times[-1] for node_id in node_ids: c_time = tstart time_indx = 0 while True: # Using the pruning method, see Dayan and Abbott Ch 2 interval = -np.log(1.0 - np.random.uniform()) / max_fr preceding_time = c_time c_time += interval if c_time > tstop: break if tau_ref != 0: w = 1 - np.exp(-(interval-abs_ref)/tau_ref) else: w = 1 # To avoid divide by zero warning if abs_ref != 0: w = w * (interval > abs_ref) # A spike occurs at t_i, find index j st times[j-1] < t_i < times[j], and interpolate the firing rates # using fr[j-1] and fr[j] while times[time_indx] <= c_time: time_indx += 1 fr_i = _interpolate_fr(c_time, times[time_indx-1], times[time_indx], fr[time_indx-1], fr[time_indx]) if not fr_i/max_fr*w < np.random.uniform(): self.add_spike(node_id=node_id, population=population, timestamp=c_time*self.output_conversion)
[docs]class GammaSpikeGenerator(SpikeGenerator): """ A Class for generating spike-trains based on a gamma-distributed renewal process. """ def __init__(self, population=None, seed=None, output_units='ms', **kwargs): super(GammaSpikeGenerator, self).__init__(population, seed, output_units, **kwargs)
[docs] def add(self, node_ids, firing_rate, a, population=None, times=(0.0, 1.0)): """ :param firing_rate: Stationary firing rate (Hz) :param a: Shape parameter (a>0). For a=1, this becomes a Poisson distribution. :param times: Start and end time for spike train (s) """ if isinstance(node_ids, string_types): # if user passes in path to nodes.h5 file count number of nodes node_ids = get_node_ids(node_ids, population) if np.isscalar(node_ids): # In case user passes in single node_id node_ids = [node_ids] if np.isscalar(firing_rate): self._build_fixed_fr(node_ids, population, firing_rate, a, times) elif isinstance(firing_rate, (list, np.ndarray)): raise Exception('Firing rate must be stationary for GammaSpikeGenerator')
[docs] def time_range(self, population=None): df = self.to_dataframe(populations=population, with_population_col=False) timestamps = df['timestamps'] return np.min(timestamps), np.max(timestamps)
def _build_fixed_fr(self, node_ids, population, fr, a, times): if np.isscalar(times) and times > 0.0: tstart = 0.0 tstop = times else: tstart = times[0] tstop = times[-1] if tstart >= tstop: raise ValueError('Invalid start and stop times.') if fr < 0: raise Exception('Firing rates must not be negative.') if a < 0: raise ValueError('Shape parameter `a` cannot be negative.') for node_id in node_ids: c_time = tstart n_spikes_avg = (tstop-tstart)*fr intervals = gamma.rvs(a, loc=0, scale=1 / (a * fr), size=round(n_spikes_avg * 1.5)) for i in range(len(intervals)): preceding_time = c_time c_time += intervals[i] if c_time > tstop: break self.add_spike(node_id=node_id, population=population, timestamp=c_time*self.output_conversion)
def _interpolate_fr(t, t0, t1, fr0, fr1): # Used to interpolate the firing rate at time t from a discrete list of firing rates return fr0 + (fr1 - fr0)*(t - t0)/(t1 - t0)
[docs]def fr_corr(abs_ref, tau_ref): # Firing rate correction for lost spikes if tau_ref != 0: def f(fr_before, abs_ref=abs_ref, tau_ref=tau_ref): return (np.exp(-fr_before * abs_ref) - \ fr_before * np.exp(abs_ref / tau_ref) * tau_ref / (fr_before * tau_ref + 1) * \ (np.exp(-(fr_before + 1 / tau_ref) * abs_ref))) * fr_before else: def f(fr_before, abs_ref=abs_ref): return np.exp(-fr_before * abs_ref) * fr_before fr_befores = np.logspace(0, 3, 40) fr_afters = f(fr_befores) max_fr_lim = np.max(fr_afters) max_ind = np.argmax(fr_afters) z = np.polyfit(fr_afters[:max_ind], fr_befores[:max_ind], 10) p = np.poly1d(z) return max_fr_lim, p