# Source code for allensdk.ephys.ephys_features

# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2017. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import warnings
import logging
import numpy as np
import scipy.signal as signal
from scipy.optimize import curve_fit
from functools import partial

def detect_putative_spikes(v, t, start=None, end=None, filter=10., dv_cutoff=20., dvdt=None):
    """Perform initial detection of spikes and return their indexes.

    A putative spike is a positive-going crossing of `dv_cutoff` by dV/dt. A
    crossing is kept only if dV/dt went negative somewhere after the preceding
    putative spike.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of time window for spike detection (optional)
    end : end of time window for spike detection (optional)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
    dvdt : pre-calculated time-derivative of voltage for the whole trace (optional).
        When given, it is sliced to the detection window instead of being recomputed.

    Returns
    -------
    putative_spikes : numpy array of preliminary spike indexes (in the index space of `v`)

    Raises
    ------
    TypeError : if `v` or `t` is not a numpy array
    FeatureError : if `v` and `t` do not have the same shape
    """
    if not isinstance(v, np.ndarray):
        raise TypeError("v is not an np.ndarray")

    if not isinstance(t, np.ndarray):
        raise TypeError("t is not an np.ndarray")

    if v.shape != t.shape:
        raise FeatureError("Voltage and time series do not have the same dimensions")

    if start is None:
        start = t[0]

    if end is None:
        end = t[-1]

    start_index = find_time_index(t, start)
    end_index = find_time_index(t, end)

    if dvdt is None:
        v_window = v[start_index:end_index + 1]
        t_window = t[start_index:end_index + 1]
        dvdt = calculate_dvdt(v_window, t_window, filter)
    else:
        # dvdt[k] is the derivative between samples k and k + 1, so the voltage
        # window [start_index, end_index] corresponds to dvdt[start_index:end_index]
        dvdt = dvdt[start_index:end_index]

    # Find positive-going crossings of dV/dt cutoff level
    putative_spikes = np.flatnonzero(np.diff(np.greater_equal(dvdt, dv_cutoff).astype(int)) == 1)

    if len(putative_spikes) <= 1:
        # Set back to original index space (not just window)
        return np.array(putative_spikes) + start_index

    # Only keep spike times if dV/dt has dropped all the way to zero between putative spikes
    putative_spikes = [putative_spikes[0]] + [s for i, s in enumerate(putative_spikes[1:])
                                              if np.any(dvdt[putative_spikes[i]:s] < 0)]

    # Set back to original index space (not just window)
    return np.array(putative_spikes) + start_index
def find_peak_indexes(v, t, spike_indexes, end=None):
    """Find indexes of spike peaks.

    The peak of each spike is the maximum of `v` from that spike's index up to
    (but not including) the next spike's index; for the last spike the search
    runs up to the index of `end`.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    end : end of time window for spike detection (optional, default t[-1])

    Returns
    -------
    peak_indexes : numpy array of spike peak indexes
    """
    # `is None` (not truthiness) so that an explicit end time of 0. is honored
    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    spks_and_end = np.append(spike_indexes, end_index)
    peak_indexes = [np.argmax(v[spk:next_spk]) + spk
                    for spk, next_spk in zip(spks_and_end[:-1], spks_and_end[1:])]

    return np.array(peak_indexes)
def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2.,
                           min_peak=-30., filter=10., dvdt=None):
    """Filter out events that are unlikely to be spikes.

    Events are rejected (or merged) based on:
        * Voltage failing to go down between peak and the next spike's threshold
        * Height (threshold to peak)
        * Absolute peak level

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
    min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    spike_indexes : numpy array of threshold indexes
    peak_indexes : numpy array of peak indexes
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([]), np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    # A peak and the next threshold belong to separate events only if dV/dt goes
    # negative between them. Otherwise the two events are merged: the first
    # event's peak and the second event's threshold are discarded.
    separated = [np.any(dvdt[pk:nxt] < 0)
                 for pk, nxt in zip(peak_indexes[:-1], spike_indexes[1:])]
    peak_indexes = peak_indexes[np.array(separated + [True])]
    spike_indexes = spike_indexes[np.array([True] + separated)]

    # Both criteria are evaluated on the same threshold/peak pairs
    peak_v = v[peak_indexes]
    keep = (peak_v >= min_peak) & ((peak_v - v[spike_indexes]) >= min_height)

    return spike_indexes[keep], peak_indexes[keep]
def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10., dvdt=None):
    """Find indexes of maximum upstroke of spike.

    The upstroke is the sample with the largest dV/dt between a spike's
    threshold (inclusive) and its peak (exclusive).

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    upstroke_indexes : numpy array of upstroke indexes
    """
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    upstrokes = []
    for thresh_idx, peak_idx in zip(spike_indexes, peak_indexes):
        steepest_offset = np.argmax(dvdt[thresh_idx:peak_idx])
        upstrokes.append(steepest_offset + thresh_idx)
    return np.array(upstrokes)
def refine_threshold_indexes(v, t, upstroke_indexes, thresh_frac=0.05, filter=10., dvdt=None):
    """Refine threshold detection of previously-found spikes.

    For each upstroke, search backwards for the first sample whose dV/dt is at
    or below `thresh_frac` times the mean upstroke dV/dt. The search stops just
    after the previous upstroke (or index 0 for the first spike). If no sample
    qualifies, that interval start is used as the threshold.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    upstroke_indexes : numpy array of indexes of spike upstrokes (for threshold target calculation)
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if not upstroke_indexes.size:
        return np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    target = dvdt[upstroke_indexes].mean() * thresh_frac

    bounds = np.append(np.array([0]), upstroke_indexes)
    thresholds = []
    for lower, upstk in zip(bounds[:-1], bounds[1:]):
        # Walk backwards from the upstroke towards (but excluding) `lower`
        hits = np.flatnonzero(dvdt[upstk:lower:-1] <= target)
        if hits.size:
            thresholds.append(upstk - hits[0])
        else:
            # No sample reached the target: fall back to the start of the search interval
            thresholds.append(lower)
    return np.array(thresholds)
def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_indexes, end=None,
                               max_interval=0.005, thresh_frac=0.05, filter=10., dvdt=None, tol=1.0):
    """Validate thresholds and peaks for set of spikes

    Check that peaks and thresholds for consecutive spikes do not overlap.
    Spikes with overlapping thresholds and peaks will be merged.

    Check that peaks and thresholds for a given spike are not too far apart;
    such pairs are re-determined or dropped.

    The input arrays are not modified; corrected copies are returned.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    upstroke_indexes : numpy array of indexes of spike upstrokes
    end : end of time window (optional, default t[-1]); used to test whether the last spike is clipped
    max_interval : maximum allowed time between start of spike and time of peak in sec (default 0.005)
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)
    tol : tolerance for returning to threshold in mV (optional, default 1)

    Returns
    -------
    spike_indexes : numpy array of modified spike indexes
    peak_indexes : numpy array of modified spike peak indexes
    upstroke_indexes : numpy array of modified spike upstroke indexes
    clipped : numpy array of clipped status of spikes
    """
    # `is None` (not truthiness) so that an explicit end time of 0. is honored
    if end is None:
        end = t[-1]

    # Work on copies: thresholds and peaks are corrected element-wise below,
    # which would otherwise silently modify the caller's arrays
    spike_indexes = np.array(spike_indexes)
    peak_indexes = np.array(peak_indexes)
    upstroke_indexes = np.array(upstroke_indexes)

    # Merge spikes whose threshold is at or right after the previous spike's peak:
    # keep the first spike's threshold together with the second spike's peak and upstroke
    overlaps = np.flatnonzero(spike_indexes[1:] <= peak_indexes[:-1] + 1)
    if overlaps.size:
        spike_mask = np.ones_like(spike_indexes, dtype=bool)
        spike_mask[overlaps + 1] = False
        spike_indexes = spike_indexes[spike_mask]

        peak_mask = np.ones_like(peak_indexes, dtype=bool)
        peak_mask[overlaps] = False
        peak_indexes = peak_indexes[peak_mask]

        upstroke_mask = np.ones_like(upstroke_indexes, dtype=bool)
        upstroke_mask[overlaps] = False
        upstroke_indexes = upstroke_indexes[upstroke_mask]

    # Validate that peaks don't occur too long after the threshold
    # If they do, try to re-find threshold from the peak
    too_long_spikes = []
    for i, (spk, peak) in enumerate(zip(spike_indexes, peak_indexes)):
        if t[peak] - t[spk] >= max_interval:
            logging.info("Need to recalculate threshold-peak pair that exceeds maximum allowed interval ({:f} s)".format(max_interval))
            too_long_spikes.append(i)

    if too_long_spikes:
        if dvdt is None:
            dvdt = calculate_dvdt(v, t, filter)
        avg_upstroke = dvdt[upstroke_indexes].mean()
        target = avg_upstroke * thresh_frac
        drop_spikes = []
        for i in too_long_spikes:
            # First guessing that threshold is wrong and peak is right:
            # search backwards from the upstroke to `max_interval` before the peak
            peak = peak_indexes[i]
            t_0 = find_time_index(t, t[peak] - max_interval)
            below_target = np.flatnonzero(dvdt[upstroke_indexes[i]:t_0:-1] <= target)
            if not below_target.size:
                # Now try to see if threshold was right but peak was wrong
                # Find the peak in a window twice the size of our allowed window
                spike = spike_indexes[i]
                t_0 = find_time_index(t, t[spike] + 2 * max_interval)
                new_peak = np.argmax(v[spike:t_0]) + spike

                # If that peak is okay (not outside the allowed window, not past the next spike)
                # then keep it
                if t[new_peak] - t[spike] < max_interval and \
                        (i == len(spike_indexes) - 1 or t[new_peak] < t[spike_indexes[i + 1]]):
                    peak_indexes[i] = new_peak
                else:
                    # Otherwise, log and get rid of the spike
                    logging.info("Could not redetermine threshold-peak pair - dropping that pair")
                    drop_spikes.append(i)
            else:
                spike_indexes[i] = upstroke_indexes[i] - below_target[0]

        if drop_spikes:
            spike_indexes = np.delete(spike_indexes, drop_spikes)
            peak_indexes = np.delete(peak_indexes, drop_spikes)
            upstroke_indexes = np.delete(upstroke_indexes, drop_spikes)

    # Check that last spike was not cut off too early by end of stimulus
    # by checking that the membrane potential returned to at least the threshold
    # voltage - otherwise, mark it as clipped
    clipped = np.zeros_like(spike_indexes, dtype=bool)
    end_index = find_time_index(t, end)
    if len(spike_indexes) > 0 and not np.any(v[peak_indexes[-1]:end_index + 1] <= v[spike_indexes[-1]] + tol):
        logging.debug("Failed to return to threshold voltage + tolerance (%.2f) after last spike (min %.2f) - marking last spike as clipped", v[spike_indexes[-1]] + tol, v[peak_indexes[-1]:end_index + 1].min())
        clipped[-1] = True

    return spike_indexes, peak_indexes, upstroke_indexes, clipped
def find_trough_indexes(v, t, spike_indexes, peak_indexes, clipped=None, end=None):
    """Find indexes of minimum voltage (trough) between spikes.

    For every spike but the last, the trough is the minimum of `v` from its
    peak up to the next spike's threshold index. For the last spike, the
    search runs from its peak up to (not including) the index of `end`. A
    clipped last spike gets np.nan.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean array - True if spike is clipped by the end of the window (optional)
    end : end of time window (optional)

    Returns
    -------
    trough_indexes : float numpy array of trough indexes (np.nan for a clipped last spike)
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([])

    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    # Float array so that an undefined trough can be stored as np.nan
    trough_indexes = np.zeros_like(spike_indexes, dtype=float)
    trough_indexes[:-1] = [pk + v[pk:nxt].argmin()
                           for pk, nxt in zip(peak_indexes[:-1], spike_indexes[1:])]

    if clipped[-1]:
        # The last spike is cut off by the end of the window, so its trough is undefined
        trough_indexes[-1] = np.nan
    else:
        last_peak = peak_indexes[-1]
        trough_indexes[-1] = last_peak + v[last_peak:end_index].argmin()

    return trough_indexes
def find_downstroke_indexes(v, t, peak_indexes, trough_indexes, clipped=None, filter=10., dvdt=None):
    """Find indexes of maximum downstroke (most negative dV/dt) of spikes.

    For each unclipped spike, the downstroke is the sample with the minimum
    dV/dt between the peak (inclusive) and the trough (exclusive). Clipped
    spikes get np.nan.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes
    clipped : boolean array - False if spike not clipped by edge of window (optional)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    downstroke_indexes : float numpy array of downstroke indexes (np.nan for clipped spikes)

    Raises
    ------
    FeatureError : if there are more troughs than peaks
    """
    if not trough_indexes.size:
        return np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    if clipped is None:
        clipped = np.zeros_like(peak_indexes, dtype=bool)

    if len(peak_indexes) < len(trough_indexes):
        raise FeatureError("Cannot have more troughs than peaks")

    unclipped = ~clipped
    peaks_ok = peak_indexes[unclipped].astype(int)
    troughs_ok = trough_indexes[unclipped].astype(int)

    downstroke_indexes = np.zeros_like(peak_indexes) * np.nan
    downstroke_indexes[unclipped] = [pk + np.argmin(dvdt[pk:tr])
                                     for pk, tr in zip(peaks_ok, troughs_ok)]
    return downstroke_indexes
def find_widths(v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None):
    """Find widths at half-height for spikes.

    Widths are only returned when heights are defined. The half-height level
    is halfway between peak and trough. When that level falls below the
    spike's threshold voltage (e.g. a deep trough after a small spike in a
    burst), the level halfway between threshold and peak is used instead.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike (threshold) indexes
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : float numpy array of trough indexes (np.nan where undefined)
    clipped : boolean array - True if spike is clipped by the end of the window (optional)

    Returns
    -------
    widths : numpy array of spike widths in sec (np.nan where no width could be determined)

    Raises
    ------
    FeatureError : if there are more troughs than spikes
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([])

    if len(spike_indexes) < len(trough_indexes):
        raise FeatureError("Cannot have more troughs than spikes")

    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    def _first_at_or_below(segment, level):
        """Offset of the first sample in `segment` that is <= `level`, or None if none is."""
        hits = np.flatnonzero(segment <= level)
        return hits[0] if hits.size > 0 else None

    use_indexes = ~np.isnan(trough_indexes)
    use_indexes[clipped] = False

    peaks_used = peak_indexes[use_indexes]
    spikes_used = spike_indexes[use_indexes]
    troughs_used = trough_indexes[use_indexes].astype(int)

    heights = np.zeros_like(trough_indexes) * np.nan
    heights[use_indexes] = v[peaks_used] - v[troughs_used]

    width_levels = np.zeros_like(trough_indexes) * np.nan
    width_levels[use_indexes] = heights[use_indexes] / 2. + v[troughs_used]

    thresh_to_peak_levels = np.zeros_like(trough_indexes) * np.nan
    thresh_to_peak_levels[use_indexes] = (v[peaks_used] - v[spikes_used]) / 2. + v[spikes_used]

    # A trough-based level below threshold is replaced by the threshold-to-peak level
    below_threshold = width_levels < v[spike_indexes]
    width_levels[below_threshold] = thresh_to_peak_levels[below_threshold]
    levels_used = width_levels[use_indexes]

    # Rising crossing: walk backwards from the peak towards (excluding) the threshold
    starts = []
    for pk, spk, level in zip(peaks_used, spikes_used, levels_used):
        offset = _first_at_or_below(v[pk:spk:-1], level)
        starts.append(np.nan if offset is None else pk - offset)

    # Falling crossing: walk forwards from the peak towards (excluding) the trough
    ends = []
    for pk, tr, level in zip(peaks_used, troughs_used, levels_used):
        offset = _first_at_or_below(v[pk:tr], level)
        ends.append(np.nan if offset is None else pk + offset)

    width_starts = np.zeros_like(trough_indexes) * np.nan
    width_starts[use_indexes] = np.array(starts)
    width_ends = np.zeros_like(trough_indexes) * np.nan
    width_ends[use_indexes] = np.array(ends)

    missing = np.isnan(width_starts) | np.isnan(width_ends)
    present = ~missing
    widths = np.zeros_like(width_starts, dtype=np.float64)
    widths[present] = t[width_ends[present].astype(int)] - t[width_starts[present].astype(int)]
    widths[missing] = np.nan
    return widths
def analyze_trough_details(v, t, spike_indexes, peak_indexes, clipped=None, end=None, filter=10.,
                           heavy_filter=1., term_frac=0.01, adp_thresh=0.5, tol=0.5,
                           flat_interval=0.002, adp_max_delta_t=0.005, adp_max_delta_v=10., dvdt=None):
    """Analyze trough to determine if an ADP exists and whether the reset is a 'detour' or 'direct'

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean array - True if spike is clipped by the end of the window (optional).
        Clipped spikes are not analyzed. The caller's array is not modified.
    end : end of time window (optional)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (default 10)
    heavy_filter : lower cutoff frequency for 4-pole low-pass Bessel filter in kHz (default 1)
    term_frac : fraction of the downstroke dV/dt that dV/dt must recover to for the
        fast trough (end of spike) to be reached (default 0.01)
    adp_thresh: minimum dV/dt in V/s to exceed to be considered to have an ADP (optional, default 0.5)
    tol : tolerance for evaluating whether Vm drops appreciably further after end of spike (default 0.5 mV)
    flat_interval: if the trace is flat for this duration, stop looking for an ADP (default 0.002 s)
    adp_max_delta_t: max possible ADP delta t (default 0.005 s)
    adp_max_delta_v: max possible ADP delta v (default 10 mV)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    output : sequence of four numpy arrays, one entry per spike:
        isi_types : isi reset types ("direct" or "detour"); spikes that could not be
            analyzed hold np.nan (stored as the string 'nan' when mixed with type strings)
        fast_trough_indexes : indexes at the start of the trough (i.e. end of the spike)
        adp_indexes : adp indexes (np.nan if there was no ADP in that ISI)
        slow_trough_indexes : indexes at the minimum of the slow phase of the trough
            (np.nan if there was just a fast phase)
    clipped : boolean numpy array of clipped status, additionally marking spikes whose
        fast trough could not be identified
    """
    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    # Always work on a boolean copy. A non-boolean mask breaks `~clipped`
    # (e.g. ~0 == -1 for integers, which would select the last spike
    # repeatedly), and the update at the end must not touch the caller's array.
    if clipped is None:
        clipped = np.zeros_like(peak_indexes, dtype=bool)
    else:
        clipped = np.array(clipped, dtype=bool)

    # Can't evaluate for spikes that are clipped by the window
    orig_len = len(peak_indexes)
    valid_spike_indexes = spike_indexes[~clipped]
    valid_peak_indexes = peak_indexes[~clipped]

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    dvdt_hvy = calculate_dvdt(v, t, heavy_filter)

    fast_trough_indexes = []
    adp_indexes = []
    slow_trough_indexes = []
    isi_types = []
    update_clipped = []

    for peak, next_spk in zip(valid_peak_indexes, np.append(valid_spike_indexes[1:], end_index)):
        downstroke = dvdt[peak:next_spk].argmin() + peak
        target = term_frac * dvdt[downstroke]

        # The fast trough is where dV/dt recovers to `term_frac` of the downstroke
        terminated_points = np.flatnonzero(dvdt[downstroke:next_spk] >= target)
        if not terminated_points.size:
            logging.debug("Could not identify fast trough - marking spike as clipped")
            isi_types.append(np.nan)
            fast_trough_indexes.append(np.nan)
            adp_indexes.append(np.nan)
            slow_trough_indexes.append(np.nan)
            update_clipped.append(True)
            continue
        terminated = terminated_points[0] + downstroke
        update_clipped.append(False)

        adp_index = np.nan
        slow_phase_min_index = np.nan
        isi_type = "direct"

        # Could there be an ADP?
        dv_over_thresh = np.flatnonzero(dvdt_hvy[terminated:next_spk] >= adp_thresh)
        if dv_over_thresh.size:
            cross = dv_over_thresh[0] + terminated

            # only want to look for ADP before things get pretty flat
            # otherwise, could just pick up random transients long after the spike
            if t[cross] - t[terminated] < flat_interval:
                # Going back up fast, but could just be going into another spike
                # so need to check for a reversal (zero-crossing) in dV/dt
                zero_return_vals = np.flatnonzero(dvdt_hvy[cross:next_spk] <= 0)
                if zero_return_vals.size:
                    putative_adp_index = zero_return_vals[0] + cross
                    min_index = v[putative_adp_index:next_spk].argmin() + putative_adp_index
                    if (v[putative_adp_index] - v[min_index] >= tol and
                            v[putative_adp_index] - v[terminated] <= adp_max_delta_v and
                            t[putative_adp_index] - t[terminated] <= adp_max_delta_t):
                        adp_index = putative_adp_index
                        slow_phase_min_index = min_index
                        isi_type = "detour"

        if np.isnan(adp_index):
            min_index = v[terminated:next_spk].argmin() + terminated
            if v[terminated] - v[min_index] >= tol:
                # dropped further after end of spike -> detour reset
                isi_type = "detour"
                slow_phase_min_index = min_index

        isi_types.append(isi_type)
        fast_trough_indexes.append(terminated)
        adp_indexes.append(adp_index)
        # slow_phase_min_index is only set for detour resets (np.nan otherwise)
        slow_trough_indexes.append(slow_phase_min_index)

    output = [np.array(isi_types)]
    for d in (fast_trough_indexes, adp_indexes, slow_trough_indexes):
        output.append(np.array(d, dtype=float))

    # If we had to kick some spikes out before, need to add nans at the end
    # (this assumes the clipped spikes are the trailing ones)
    if orig_len > len(isi_types):
        extra = np.zeros(orig_len - len(isi_types)) * np.nan
        output = tuple(np.append(o, extra) for o in output)

    # The ADP and slow trough for the last spike in a train are not reliably
    # calculated, and usually extreme when wrong, so we will NaN them out.
    #
    # Note that this will result in a 0 value when delta V or delta T is
    # calculated, which may not be strictly accurate to the trace, but the
    # magnitude of the difference will be less than in many of the erroneous
    # cases seen otherwise
    if orig_len > 0:  # nothing to NaN out (and no last element) without spikes
        output[2][-1] = np.nan  # ADP
        output[3][-1] = np.nan  # slow trough

    clipped[~clipped] = update_clipped
    return output, clipped
def find_time_index(t, t_0):
    """Find the index value of a given time (t_0) in a time series (t).

    Returns the first index whose time is at or after `t_0`.

    Raises
    ------
    FeatureError : if every time in `t` is earlier than `t_0`
    """
    at_or_after = t >= t_0
    if not np.any(at_or_after):
        raise FeatureError("Could not find given time in time vector")
    # argmax of a boolean array is the position of its first True
    return np.argmax(at_or_after)
def calculate_dvdt(v, t, filter=None):
    """Low-pass filters (if requested) and differentiates voltage by time.

    Filtering is applied only when `filter` is truthy and `t` has a fixed
    sampling interval; otherwise the raw voltage is differentiated.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default None)

    Returns
    -------
    dvdt : numpy array of time-derivative of voltage (V/s = mV/ms). It has one
        element fewer than `v`, minus any NaN values that are dropped (e.g. 0/0
        from repeated time points).

    Raises
    ------
    ValueError : if the cutoff gives a Bessel coefficient outside [0, 1), i.e. the
        cutoff is not below the Nyquist frequency
    """
    # Check `filter` first so the sampling-uniformity scan only runs when filtering
    # was requested (it also raises IndexError on traces with a single sample)
    if filter and has_fixed_dt(t):
        delta_t = t[1] - t[0]
        sample_freq = 1. / delta_t
        filt_coeff = (filter * 1e3) / (sample_freq / 2.)  # filter kHz -> Hz, then get fraction of Nyquist frequency
        if filt_coeff < 0 or filt_coeff >= 1:
            raise ValueError("bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format(filt_coeff, sample_freq / 1e3, filter))
        b, a = signal.bessel(4, filt_coeff, "low")
        v_filt = signal.filtfilt(b, a, v, axis=0)
        dv = np.diff(v_filt)
    else:
        dv = np.diff(v)

    dt = np.diff(t)
    dvdt = 1e-3 * dv / dt  # in V/s = mV/ms

    # Remove nan values (in case any dt values == 0)
    dvdt = dvdt[~np.isnan(dvdt)]

    return dvdt
def get_isis(t, spikes):
    """Find interspike intervals in sec between spikes (as indexes).

    Returns an empty array when there are fewer than two spikes.
    """
    if len(spikes) < 2:
        return np.array([])
    spike_times = t[spikes]
    return np.diff(spike_times)
def average_voltage(v, t, start=None, end=None):
    """Calculate average voltage between start and end.

    The sample at the index of `end` is not included in the average.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of averaging window in seconds (optional, default t[0])
    end : end of averaging window in seconds (optional, default t[-1])

    Returns
    -------
    v_avg : average voltage
    """
    first = find_time_index(t, t[0] if start is None else start)
    last = find_time_index(t, t[-1] if end is None else end)
    window = v[first:last]
    return window.mean()
def adaptation_index(isis):
    """Calculate adaptation index of `isis`.

    This is the mean normalized difference of consecutive ISIs as computed by
    `norm_diff`; np.nan when no ISIs are given.
    """
    return np.nan if len(isis) == 0 else norm_diff(isis)
def latency(t, spikes, start):
    """Calculate time to the first spike.

    Measured from `start` (or t[0] when `start` is None); np.nan when there are no spikes.
    """
    if not len(spikes):
        return np.nan
    reference = t[0] if start is None else start
    return t[spikes[0]] - reference
def average_rate(t, spikes, start, end):
    """Calculate average firing rate during interval between `start` and `end`.

    Spikes exactly at `start` or `end` are counted.

    Parameters
    ----------
    t : numpy array of times in seconds
    spikes : numpy array of spike indexes
    start : start of time window (t[0] if None)
    end : end of time window (t[-1] if None)

    Returns
    -------
    avg_rate : average firing rate in spikes/sec
    """
    if start is None:
        start = t[0]
    if end is None:
        end = t[-1]

    n_in_window = sum(1 for spk in spikes if start <= t[spk] <= end)
    return n_in_window / (end - start)
def norm_diff(a):
    """Calculate average of (a[i+1] - a[i]) / (a[i+1] + a[i]).

    Pairs in which both values are zero contribute 0. NaN terms are ignored by
    the average.

    Parameters
    ----------
    a : array-like of numbers

    Returns
    -------
    avg : float; np.nan if `a` has fewer than two elements, 0. if every
        consecutive sum is (close to) zero
    """
    if len(a) <= 1:
        return np.nan

    # asarray (instead of .astype) also accepts plain sequences
    a = np.asarray(a, dtype=float)
    if np.allclose((a[1:] + a[:-1]), 0.):
        return 0.

    with np.errstate(invalid="ignore"):
        # 0/0 pairs yield NaN here; they are deliberately reset to 0 just below
        norm_diffs = (a[1:] - a[:-1]) / (a[1:] + a[:-1])
    norm_diffs[(a[1:] == 0) & (a[:-1] == 0)] = 0.

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
        avg = np.nanmean(norm_diffs)
    return avg
def norm_sq_diff(a):
    """Calculate average of (a[i] - a[i+1])^2 / (a[i] + a[i+1])^2.

    Returns np.nan when `a` has fewer than two elements.
    """
    if len(a) <= 1:
        return np.nan

    a = a.astype(float)
    earlier, later = a[:-1], a[1:]
    ratios = np.square(later - earlier) / np.square(later + earlier)
    return ratios.mean()
def has_fixed_dt(t):
    """Check that all time intervals are identical (within np.allclose tolerance)."""
    intervals = np.diff(t)
    reference = intervals[0]
    return np.allclose(intervals, reference * np.ones_like(intervals))
def fit_membrane_time_constant(v, t, start, end, min_rsme=1e-4):
    """Fit an exponential to estimate membrane time constant between start and end

    Parameters
    ----------
    v : numpy array of voltages in mV
    t : numpy array of times in seconds
    start : start of time window for exponential fit
    end : end of time window for exponential fit (the sample at `end` is excluded)
    min_rsme: error limit above which the fit is rejected (default 1e-4)

    Returns
    -------
    a, inv_tau, y0 : Coeffients of equation y0 + a * exp(-inv_tau * x)

    returns np.nan for values if fit fails
    """
    first = find_time_index(t, start)
    last = find_time_index(t, end)

    initial_guess = (v[first] - v[last], 50., v[last])
    t_fit = (t[first:last] - t[first]).astype(np.float64)
    v_fit = v[first:last].astype(np.float64)

    try:
        popt, _ = curve_fit(_exp_curve, t_fit, v_fit, p0=initial_guess)
    except RuntimeError:
        logging.info("Curve fit for membrane time constant failed")
        return np.nan, np.nan, np.nan

    residuals = _exp_curve(t_fit, *popt) - v_fit
    # NOTE(review): this is sqrt(mean(residual)), not a root-mean-square error.
    # With a free offset y0 the mean residual of a converged least-squares fit is
    # ~0, and a negative mean gives NaN (which never compares greater), so this
    # check rarely rejects anything. Switching to sqrt(mean(residual**2)) would
    # also require revisiting the 1e-4 default, so the behavior is kept as-is.
    rsme = np.sqrt(np.mean(residuals))
    if rsme > min_rsme:
        logging.debug("Curve fit for membrane time constant did not meet RSME standard")
        return np.nan, np.nan, np.nan

    return popt
def detect_pauses(isis, isi_types, cost_weight=1.0):
    """Determine which ISIs are "pauses" in ongoing firing.

    Pauses are unusually long ISIs with a "detour reset" among "direct resets".

    Candidates are all detour ISIs followed by direct ISIs longer than three
    times the median direct ISI. They are considered greedily in that order.
    A candidate joins the pause list when labelling it a pause improves the
    net benefit, i.e. the drop in coefficient of variation of the remaining
    ISIs minus a weighted cost.

    Parameters
    ----------
    isis : numpy array of interspike intervals
    isi_types : numpy array of interspike interval types ('direct' or 'detour')
    cost_weight : weight for cost function for calling an ISI a pause
        Higher cost weights lead to fewer ISIs identified as pauses. The cost function also depends on the
        difference between the duration of the "pause" ISIs and the average duration and standard deviation
        of "non-pause" ISIs.

    Returns
    -------
    pauses : numpy array of indices corresponding to pauses in `isis`

    Raises
    ------
    FeatureError : if `isis` and `isi_types` differ in length
    """
    if len(isis) != len(isi_types):
        raise FeatureError("Wrong number of ISIs")

    is_direct = isi_types == "direct"
    if not np.any(is_direct):
        # Need some direct-type firing to have pauses
        return np.array([])

    long_direct_cutoff = 3 * np.median(isis[is_direct])
    candidates = [idx for idx, kind in enumerate(isi_types) if kind == "detour"]
    candidates += [idx for idx, kind in enumerate(isi_types)
                   if kind == "direct" and isis[idx] > long_direct_cutoff]
    if not candidates:
        return np.array([])

    pause_list = np.array([], dtype=int)
    all_cv = isis.std() / isis.mean()
    best_net = 0
    for idx in candidates:
        trial = np.append(pause_list, idx)
        remaining = np.delete(isis, trial)
        paused = isis[trial]
        if len(remaining) < 2:
            break

        remaining_std = remaining.std()
        remaining_mean = remaining.mean()
        benefit = all_cv - remaining_std / remaining_mean
        cost = cost_weight * np.sum(remaining_std / np.abs(remaining_mean - paused))
        net = benefit - cost

        if 0 < net < best_net:
            break
        if net > best_net:
            best_net = net
            pause_list = np.append(pause_list, idx)

    if best_net <= 0:
        pause_list = np.array([])

    return np.sort(pause_list)
def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, thr_v,
                  tol=0.5, pause_cost=1.0):
    """Detect bursts in spike train.

    Parameters
    ----------
    isis : numpy array of n interspike intervals
    isi_types : numpy array of n interspike interval types
    fast_tr_v : numpy array of fast trough voltages for the n + 1 spikes of the train
    fast_tr_t : numpy array of fast trough times for the n + 1 spikes of the train
    slow_tr_v : numpy array of slow trough voltages for the n + 1 spikes of the train
    slow_tr_t : numpy array of slow trough times for the n + 1 spikes of the train
    thr_v : numpy array of threshold voltages for the n + 1 spikes of the train
    tol : tolerance for the difference in slow trough voltages and thresholds (default 0.5 mV)
        Used to identify "delay" interspike intervals that occur within a burst
    pause_cost : cost weight passed to `detect_pauses` (default 1.0)

    Returns
    -------
    bursts : list of bursts (an empty numpy array when no bursts are found)
        Each item in list is a tuple of the form (burst_index, start, end) where `burst_index`
        is a comparison index between the highest instantaneous rate within the burst vs
        the highest instantaneous rate outside the burst. `start` is the index of the first
        ISI of the burst, and `end` is the ISI index immediately following the burst.

    Raises
    ------
    FeatureError : if `isis` and `isi_types` differ in length, or burst boundaries are inconsistent
    """
    if len(isis) != len(isi_types):
        raise FeatureError("Wrong number of ISIs")

    if len(isis) < 2:  # can't determine burstiness for a single ISI
        return np.array([])

    fast_tr_v = fast_tr_v[:-1]
    fast_tr_t = fast_tr_t[:-1]
    slow_tr_v = slow_tr_v[:-1]
    slow_tr_t = slow_tr_t[:-1]

    # Copy as an object array: the caller's data is not changed, and the longer
    # labels assigned below ("pauselike", "midburst") are not truncated to the
    # width of a fixed-size string dtype (e.g. '<U6' for "direct"/"detour")
    isi_types = np.array(isi_types, dtype=object)

    # Burst transitions can't be at "pause"-like ISIs
    pauses = detect_pauses(isis, isi_types, cost_weight=pause_cost).astype(int)
    isi_types[pauses] = "pauselike"

    if not (np.any(isi_types == "direct") and np.any(isi_types == "detour")):
        # no candidates that could be bursts
        return np.array([])

    # Want to catch special case of detour in the middle of a large burst where
    # the slow trough value is higher than the previous spike's threshold
    isi_types[(thr_v[:-1] < (slow_tr_v + tol)) & (isi_types == "detour")] = "midburst"

    # Find transitions from direct -> detour and vice versa for burst boundaries
    into_burst = np.array([i + 1 for i, (prev, cur) in enumerate(zip(isi_types[:-1], isi_types[1:]))
                           if cur == "direct" and prev == "detour"], dtype=int)
    if isi_types[0] == "direct":
        into_burst = np.append(np.array([0]), into_burst)

    drop_into = []
    out_of_burst = []
    for j, (into, next_into) in enumerate(zip(into_burst, np.append(into_burst[1:], len(isis)))):
        for i, isi in enumerate(isi_types[into + 1:next_into]):
            if isi == "detour":
                out_of_burst.append(i + into + 1)
                break
            elif isi == "pauselike":
                drop_into.append(j)
                break

    mask = np.ones_like(into_burst, dtype=bool)
    mask[drop_into] = False
    into_burst = into_burst[mask]

    # Integer dtype even when empty, so that burst boundaries can be used as slice bounds
    out_of_burst = np.array(out_of_burst, dtype=int)
    if len(out_of_burst) == len(into_burst) - 1:
        out_of_burst = np.append(out_of_burst, len(isi_types))

    if not (into_burst.size or out_of_burst.size):
        return np.array([])

    if len(into_burst) != len(out_of_burst):
        raise FeatureError("Inconsistent burst boundary identification")

    # A list, not a bare zip(): in Python 3 the iterator would be exhausted by the
    # first scoring call and supports neither len() nor indexing, both used below
    inout_pairs = list(zip(into_burst, out_of_burst))
    delta_t = slow_tr_t - fast_tr_t

    # Greedily drop the worst-scoring burst while the mean score keeps improving
    scores = _score_burst_set(inout_pairs, isis, delta_t)
    best_score = np.mean(scores)
    worst = np.argmin(scores)
    test_bursts = list(inout_pairs)
    del test_bursts[worst]
    while len(test_bursts) > 0:
        scores = _score_burst_set(test_bursts, isis, delta_t)
        if np.mean(scores) > best_score:
            best_score = np.mean(scores)
            inout_pairs = list(test_bursts)
            worst = np.argmin(scores)
            del test_bursts[worst]
        else:
            break

    if best_score < 0:
        return np.array([])

    bursts = []
    for i, (into, outof) in enumerate(inout_pairs):
        if i == len(inout_pairs) - 1:  # last burst to evaluate
            if outof <= len(isis) - 1:  # are there spikes left after the burst?
                metric = _burstiness_index(isis[into:outof], isis[outof:])
            elif i == 0:  # was this the first one (and there weren't spikes after)?
                metric = _burstiness_index(isis[into:outof], isis[:into])
            else:
                prev_burst = inout_pairs[i - 1]
                metric = _burstiness_index(isis[into:outof], isis[prev_burst[1]:into])
        else:
            next_burst = inout_pairs[i + 1]
            metric = _burstiness_index(isis[into:outof], isis[outof:next_burst[0]])

        bursts.append((metric, into, outof))

    return bursts
def fit_prespike_time_constant(v, t, start, spike_time, dv_limit=-0.001, tau_limit=0.3):
    """Estimate the dominant time constant of the voltage rise preceding a spike.

    A single exponential is first fit in log space to get an initial guess.
    That guess seeds a double-exponential ``curve_fit``. One of the two
    fitted time constants is then selected by a set of empirical rules.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of voltage rise (seconds)
    spike_time : time of first spike (seconds)
    dv_limit : cutoff for shortening the initial fit window (default -0.001)
        The initial single-exponential fit uses ln(y0 - v), where y0 is just
        above the window's maximum voltage. It is restricted to samples before
        the first point where the derivative of that quantity (from
        ``calculate_dvdt`` with ``filter=1.0``) is <= ``dv_limit``. The
        restriction is skipped if there is no such point, or if it occurs less
        than 0.5 ms after the window start.
    tau_limit : upper bound for slow time constant (seconds, default 0.3)
        If the slower fitted time constant exceeds this limit and the slower
        component's weight is less than twice the faster component's weight,
        the faster time constant is reported. This rule applies only when the
        earlier rules below did not already decide.

    Returns
    -------
    tau : dominant time constant (seconds)
        Selection rules, applied in order:
        1. the faster tau, if its weight has the larger magnitude;
        2. the slower tau, if (slow - fast) / slow <= 0.1;
        3. the faster tau, under the ``tau_limit`` rule above;
        4. otherwise the slower tau.
        If ``curve_fit`` raises RuntimeError, the single-exponential estimate
        is returned instead.

    Raises
    ------
    FeatureError
        If the sample index of ``spike_time`` is not after that of ``start``.
    """
    first = find_time_index(t, start)
    last = find_time_index(t, spike_time)
    if last <= first:
        raise FeatureError("Start for pre-spike time constant fit cannot be after the spike time.")

    v_win = v[first:last]
    t_win = t[first:last]
    t_rel = t_win - t_win[0]

    # Log-linearize against a ceiling placed slightly above the window maximum
    # so that the log argument stays strictly positive.
    y0 = v_win.max() + 5e-6
    log_y = np.log(y0 - v_win)

    dlog = calculate_dvdt(log_y, t_win, filter=1.0)

    # Possibly shorten the window used for the single-exponential guess.
    crossings = np.flatnonzero(dlog <= dv_limit)
    min_fit_duration = 0.0005  # sec
    if crossings.size == 0 or t_win[crossings[0]] - t_win[0] < min_fit_duration:
        # never crosses, or crosses too early to leave a usable window
        fit_end = len(v_win)
    else:
        fit_end = crossings[0]

    slope, intercept = np.polyfit(t_rel[:fit_end], log_y[:fit_end], 1)
    amplitude = np.exp(intercept)
    tau_guess = -1.0 / slope

    model = partial(_dbl_exp_fit, y0)
    initial = (-amplitude / 2.0, tau_guess, -amplitude / 2.0, tau_guess)
    try:
        popt, _ = curve_fit(model, t_rel, v_win, p0=initial)
    except RuntimeError:
        # Double-exponential fit failed; fall back to the single-exponential estimate.
        return tau_guess

    w_a, tau_a, w_b, tau_b = popt
    if tau_a < tau_b:
        fast_w, fast_tau, slow_w, slow_tau = w_a, tau_a, w_b, tau_b
    else:
        fast_w, fast_tau, slow_w, slow_tau = w_b, tau_b, w_a, tau_a

    # Empirical rules for choosing the dominant component.
    if np.abs(fast_w) > np.abs(slow_w):
        return fast_tau
    if (slow_tau - fast_tau) / slow_tau <= 0.1:
        return slow_tau  # time constants are close enough; prefer the slower one
    if slow_tau > tau_limit and slow_w / fast_w < 2.0:
        return fast_tau
    return slow_tau
def estimate_adjusted_detection_parameters(v_set, t_set, interval_start, interval_end, filter=10):
    """
    Estimate adjusted values for spike detection by analyzing a period when
    the voltage changes quickly but passively (due to strong current
    stimulation), which can result in spurious spike detection results.

    Parameters
    ----------
    v_set : list of numpy arrays of voltage time series in mV
    t_set : list of numpy arrays of times in seconds
    interval_start : start of analysis interval (sec)
    interval_end : end of analysis interval (sec)
    filter : filter setting passed through to ``calculate_dvdt``,
        ``detect_putative_spikes`` and ``filter_putative_spikes`` (default 10)

    The interval is located with the time array of the first sweep
    (``t_set[0]``). The same sample indices are then applied to every sweep.

    Returns
    -------
    new_dv_cutoff : adjusted dv/dt cutoff (V/s)
        1.1 x the median, across sweeps, of the maximum dv/dt inside the interval.
    new_thresh_frac : adjusted fraction of avg upstroke to find threshold
        1.2 x the median, across sweeps, of dv/dt at ``interval_end``, divided
        by the mean upstroke dv/dt of spikes detected with ``new_dv_cutoff``.

    Raises
    ------
    FeatureError
        If ``v_set`` and ``t_set`` differ in length or are empty, or if no
        spike upstrokes are detected in any sweep with the adjusted cutoff.
    """
    if type(v_set) is not list:
        v_set = list(v_set)
    if type(t_set) is not list:
        t_set = list(t_set)

    if len(v_set) != len(t_set):
        raise FeatureError("t_set and v_set must be lists of equal size")
    if len(v_set) == 0:
        raise FeatureError("t_set and v_set are empty")

    start_index = find_time_index(t_set[0], interval_start)
    end_index = find_time_index(t_set[0], interval_end)

    maxes = []
    ends = []
    dv_set = []
    for v, t in zip(v_set, t_set):
        dv = calculate_dvdt(v, t, filter)
        dv_set.append(dv)
        maxes.append(dv[start_index:end_index].max())
        ends.append(dv[end_index])

    cutoff_adj_factor = 1.1
    thresh_frac_adj_factor = 1.2

    new_dv_cutoff = np.median(maxes) * cutoff_adj_factor
    min_thresh = np.median(ends) * thresh_frac_adj_factor

    # Collect per-sweep upstroke dv/dt values and join once, rather than
    # growing an array with np.append on every sweep.
    upstroke_dvdts = []
    for v, t, dv in zip(v_set, t_set, dv_set):
        putative_spikes = detect_putative_spikes(v, t, dv_cutoff=new_dv_cutoff, filter=filter)
        peaks = find_peak_indexes(v, t, putative_spikes)
        putative_spikes, peaks = filter_putative_spikes(v, t, putative_spikes, peaks, dvdt=dv, filter=filter)
        upstrokes = find_upstroke_indexes(v, t, putative_spikes, peaks, dvdt=dv)
        if upstrokes.size:
            upstroke_dvdts.append(dv[upstrokes])

    if not upstroke_dvdts:
        # The mean of an empty array is NaN (with only a RuntimeWarning), which
        # would silently yield an unusable threshold fraction.
        raise FeatureError("No spike upstrokes detected with the adjusted dv/dt cutoff; "
                           "cannot estimate threshold fraction")

    new_thresh_frac = min_thresh / np.concatenate(upstroke_dvdts).mean()

    return new_dv_cutoff, new_thresh_frac
def _score_burst_set(bursts, isis, delta_t, c_n=0.1, c_tx=0.01): in_burst = np.zeros_like(isis, dtype=bool) for b in bursts: in_burst[b[0]:b[1]] = True # If all ISIs are part of a burst, give it a bad score if len(isis[~in_burst]) == 0: return [-1e12] * len(bursts) delta_frac = delta_t / isis scores = [] for b in bursts: score = _burstiness_index(isis[b[0]:b[1]], isis[~in_burst]) # base score if b[1] < len(delta_t): score -= c_tx * (1. / (delta_frac[b[1]])) # cost for starting a burst if b[0] > 0: score -= c_tx * (1. / delta_frac[b[0] - 1]) # cost for ending a burst score -= c_n * (b[1] - b[0] - 1) # cost for extending a burst scores.append(score) return scores def _burstiness_index(in_burst_isis, out_burst_isis): burst_rate = 1. / in_burst_isis.min() out_rate = 1. / out_burst_isis.min() return (burst_rate - out_rate) / (burst_rate + out_rate) def _exp_curve(x, a, inv_tau, y0): return y0 + a * np.exp(-inv_tau * x) def _dbl_exp_fit(y0, x, A1, tau1, A2, tau2): penalty = 0 if tau1 < 0 or tau2 < 0: penalty = 1e6 return y0 + A1 * np.exp(-x / tau1) + A2 * np.exp(-x / tau2) + penalty
class FeatureError(Exception):
    """Error raised by the feature-detection functions in this module when
    their inputs or intermediate results do not allow them to proceed."""