Source code for allensdk.ephys.ephys_extractor
# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2015-2016. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
from pandas import DataFrame
import warnings
import logging
from collections import Counter
from . import ephys_features as ft
import six
# Constants for stimulus-specific analysis
RAMPS_START = 1.02
LONG_SQUARES_START = 1.02
LONG_SQUARES_END = 2.02
SHORT_SQUARES_WINDOW_START = 1.02
SHORT_SQUARES_WINDOW_END = 1.021
SHORT_SQUARE_TRIPLE_WINDOW_START = 2.02
SHORT_SQUARE_TRIPLE_WINDOW_END = 2.021
class EphysSweepFeatureExtractor:
"""Feature calculation for a sweep (voltage and/or current time series)."""
def __init__(self, t=None, v=None, i=None, start=None, end=None, filter=10.,
dv_cutoff=20., max_interval=0.005, min_height=2., min_peak=-30.,
thresh_frac=0.05, baseline_interval=0.1, baseline_detect_thresh=0.3,
id=None):
"""Initialize SweepFeatures object.
Parameters
----------
t : ndarray of times (seconds)
v : ndarray of voltages (mV)
i : ndarray of currents (pA)
start : start of time window for feature analysis (optional)
end : end of time window for feature analysis (optional)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
        baseline_interval : interval length for baseline voltage calculation in seconds (before start if start is defined; optional, default 0.1)
baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3)
"""
self.id = id
self.t = t
self.v = v
self.i = i
self.start = start
self.end = end
self.filter = filter
self.dv_cutoff = dv_cutoff
self.max_interval = max_interval
self.min_height = min_height
self.min_peak = min_peak
self.thresh_frac = thresh_frac
self.baseline_interval = baseline_interval
self.baseline_detect_thresh = baseline_detect_thresh
self.stimulus_amplitude_calculator = None
self._sweep_features = {}
self._affected_by_clipping = []
    def process_spikes(self):
"""Perform spike-related feature analysis"""
self._process_individual_spikes()
self._process_spike_related_features()
def _process_individual_spikes(self):
v = self.v
t = self.t
dvdt = ft.calculate_dvdt(v, t, self.filter)
# Basic features of spikes
putative_spikes = ft.detect_putative_spikes(v, t, self.start, self.end,
self.filter, self.dv_cutoff)
peaks = ft.find_peak_indexes(v, t, putative_spikes, self.end)
putative_spikes, peaks = ft.filter_putative_spikes(v, t, putative_spikes, peaks,
self.min_height, self.min_peak,
dvdt=dvdt, filter=self.filter)
if not putative_spikes.size:
# Save time if no spikes detected
self._spikes_df = DataFrame()
return
upstrokes = ft.find_upstroke_indexes(v, t, putative_spikes, peaks, self.filter, dvdt)
thresholds = ft.refine_threshold_indexes(v, t, upstrokes, self.thresh_frac,
self.filter, dvdt)
thresholds, peaks, upstrokes, clipped = ft.check_thresholds_and_peaks(v, t, thresholds, peaks,
upstrokes, self.end, self.max_interval)
if not thresholds.size:
# Save time if no spikes detected
self._spikes_df = DataFrame()
return
# Spike list and thresholds have been refined - now find other features
upstrokes = ft.find_upstroke_indexes(v, t, thresholds, peaks, self.filter, dvdt)
troughs = ft.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end)
downstrokes = ft.find_downstroke_indexes(v, t, peaks, troughs, clipped, self.filter, dvdt)
trough_details, clipped = ft.analyze_trough_details(v, t, thresholds, peaks, clipped, self.end,
self.filter, dvdt=dvdt)
widths = ft.find_widths(v, t, thresholds, peaks, trough_details[1], clipped)
base_clipped_list = []
# Points where we care about t, v, and i if available
vit_data_indexes = {
"threshold": thresholds,
"peak": peaks,
"trough": troughs,
}
base_clipped_list += ["trough"]
# Points where we care about t and dv/dt
dvdt_data_indexes = {
"upstroke": upstrokes,
"downstroke": downstrokes
}
base_clipped_list += ["downstroke"]
# Trough details
isi_types = trough_details[0]
trough_detail_indexes = dict(zip(["fast_trough", "adp", "slow_trough"], trough_details[1:]))
base_clipped_list += ["fast_trough", "adp", "slow_trough"]
# Redundant, but ensures that DataFrame has right number of rows
# Any better way to do it?
spikes_df = DataFrame(data=thresholds, columns=["threshold_index"])
spikes_df["clipped"] = clipped
for k, all_vals in six.iteritems(vit_data_indexes):
valid_ind = ~np.isnan(all_vals)
vals = all_vals[valid_ind].astype(int)
spikes_df[k + "_index"] = np.nan
spikes_df[k + "_t"] = np.nan
spikes_df[k + "_v"] = np.nan
if len(vals) > 0:
spikes_df.ix[valid_ind, k + "_index"] = vals
spikes_df.ix[valid_ind, k + "_t"] = t[vals]
spikes_df.ix[valid_ind, k + "_v"] = v[vals]
if self.i is not None:
spikes_df[k + "_i"] = np.nan
if len(vals) > 0:
spikes_df.ix[valid_ind, k + "_i"] = self.i[vals]
if k in base_clipped_list:
self._affected_by_clipping += [
k + "_index",
k + "_t",
k + "_v",
k + "_i",
]
for k, all_vals in six.iteritems(dvdt_data_indexes):
valid_ind = ~np.isnan(all_vals)
vals = all_vals[valid_ind].astype(int)
spikes_df[k + "_index"] = np.nan
spikes_df[k] = np.nan
if len(vals) > 0:
spikes_df.ix[valid_ind, k + "_index"] = vals
spikes_df.ix[valid_ind, k + "_t"] = t[vals]
spikes_df.ix[valid_ind, k + "_v"] = v[vals]
spikes_df.ix[valid_ind, k] = dvdt[vals]
if k in base_clipped_list:
self._affected_by_clipping += [
k + "_index",
k + "_t",
k + "_v",
k,
]
spikes_df["isi_type"] = isi_types
self._affected_by_clipping += ["isi_type"]
for k, all_vals in six.iteritems(trough_detail_indexes):
valid_ind = ~np.isnan(all_vals)
vals = all_vals[valid_ind].astype(int)
spikes_df[k + "_index"] = np.nan
spikes_df[k + "_t"] = np.nan
spikes_df[k + "_v"] = np.nan
if len(vals) > 0:
spikes_df.ix[valid_ind, k + "_index"] = vals
spikes_df.ix[valid_ind, k + "_t"] = t[vals]
spikes_df.ix[valid_ind, k + "_v"] = v[vals]
if self.i is not None:
spikes_df[k + "_i"] = np.nan
if len(vals) > 0:
spikes_df.ix[valid_ind, k + "_i"] = self.i[vals]
if k in base_clipped_list:
self._affected_by_clipping += [
k + "_index",
k + "_t",
k + "_v",
k + "_i",
]
spikes_df["width"] = widths
self._affected_by_clipping += ["width"]
spikes_df["upstroke_downstroke_ratio"] = spikes_df["upstroke"] / -spikes_df["downstroke"]
self._affected_by_clipping += ["upstroke_downstroke_ratio"]
self._spikes_df = spikes_df
def _process_spike_related_features(self):
t = self.t
if len(self._spikes_df) == 0:
self._sweep_features["avg_rate"] = 0
return
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(t, thresholds)
with warnings.catch_warnings():
# ignore mean of empty slice warnings here
warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
sweep_level_features = {
"adapt": ft.adaptation_index(isis),
"latency": ft.latency(t, thresholds, self.start),
"isi_cv": (isis.std() / isis.mean()) if len(isis) >= 1 else np.nan,
"mean_isi": isis.mean() if len(isis) > 0 else np.nan,
"median_isi": np.median(isis),
"first_isi": isis[0] if len(isis) >= 1 else np.nan,
"avg_rate": ft.average_rate(t, thresholds, self.start, self.end),
}
for k, v in six.iteritems(sweep_level_features):
self._sweep_features[k] = v
def _process_pauses(self, cost_weight=1.0):
        # Pauses are unusually long ISIs with a "detour reset" among "direct resets"
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(self.t, thresholds)
isi_types = self._spikes_df["isi_type"][:-1].values
return ft.detect_pauses(isis, isi_types, cost_weight)
    def pause_metrics(self):
        """Estimate the number of pauses and the fraction of time spent in a pause.
        Pauses are unusually long ISIs with a "detour reset" among "direct resets".
        Returns
        -------
        n_pauses : number of pauses detected
        pause_frac : fraction of the interval (between start and end) spent in a pause
        """
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(self.t, thresholds)
weight = 1.0
pause_list = self._process_pauses(weight)
if len(pause_list) == 0:
return 0, 0.
n_pauses = len(pause_list)
pause_frac = isis[pause_list].sum()
pause_frac /= self.end - self.start
return n_pauses, pause_frac
def _process_bursts(self, tol=0.5, pause_cost=1.0):
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(self.t, thresholds)
isi_types = self._spikes_df["isi_type"][:-1].values
fast_tr_v = self._spikes_df["fast_trough_v"].values
fast_tr_t = self._spikes_df["fast_trough_t"].values
slow_tr_v = self._spikes_df["slow_trough_v"].values
slow_tr_t = self._spikes_df["slow_trough_t"].values
thr_v = self._spikes_df["threshold_v"].values
bursts = ft.detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t,
thr_v, tol, pause_cost)
return np.array(bursts)
    def burst_metrics(self):
"""Find bursts and return max "burstiness" index (normalized max rate in burst vs out).
Returns
-------
max_burstiness_index : max "burstiness" index across detected bursts
num_bursts : number of bursts detected
"""
burst_info = self._process_bursts()
if burst_info.shape[0] > 0:
return burst_info[:, 0].max(), burst_info.shape[0]
else:
return 0., 0
    def delay_metrics(self):
"""Calculates ratio of latency to dominant time constant of rise before spike
Returns
-------
delay_ratio : ratio of latency to tau (higher means more delay)
tau : dominant time constant of rise before spike
"""
if len(self._spikes_df) == 0:
logging.info("No spikes available for delay calculation")
return 0., 0.
start = self.start
spike_time = self._spikes_df["threshold_t"].values[0]
tau = ft.fit_prespike_time_constant(self.v, self.t, start, spike_time)
latency = spike_time - start
delay_ratio = latency / tau
return delay_ratio, tau
def _get_baseline_voltage(self):
v = self.v
t = self.t
filter_frequency = 1. # in kHz
# Look at baseline interval before start if start is defined
if self.start is not None:
return ft.average_voltage(v, t, self.start - self.baseline_interval, self.start)
# Otherwise try to find an interval where things are pretty flat
dv = ft.calculate_dvdt(v, t, filter_frequency)
        non_flat_points = np.flatnonzero(np.abs(dv) >= self.baseline_detect_thresh)
flat_intervals = t[non_flat_points[1:]] - t[non_flat_points[:-1]]
long_flat_intervals = np.flatnonzero(flat_intervals >= self.baseline_interval)
if long_flat_intervals.size > 0:
interval_index = long_flat_intervals[0] + 1
baseline_end_time = t[non_flat_points[interval_index]]
return ft.average_voltage(v, t, baseline_end_time - self.baseline_interval,
baseline_end_time)
else:
logging.info("Could not find sufficiently flat interval for automatic baseline voltage", RuntimeWarning)
return np.nan
    def voltage_deflection(self, deflect_type=None):
"""Measure deflection (min or max, between start and end if specified).
Parameters
----------
deflect_type : measure minimal ('min') or maximal ('max') voltage deflection
If not specified, it will check to see if the current (i) is positive or negative
between start and end, then choose 'max' or 'min', respectively
If the current is not defined, it will default to 'min'.
Returns
-------
        deflect_v : voltage at the peak deflection
deflect_index : index of peak deflection
"""
deflect_dispatch = {
"min": np.argmin,
"max": np.argmax,
}
start = self.start
if not start:
start = 0
start_index = ft.find_time_index(self.t, start)
end = self.end
if not end:
end = self.t[-1]
end_index = ft.find_time_index(self.t, end)
if deflect_type is None:
if self.i is not None:
halfway_index = ft.find_time_index(self.t, (end - start) / 2. + start)
if self.i[halfway_index] >= 0:
deflect_type = "max"
else:
deflect_type = "min"
else:
deflect_type = "min"
deflect_func = deflect_dispatch[deflect_type]
v_window = self.v[start_index:end_index]
deflect_index = deflect_func(v_window) + start_index
return self.v[deflect_index], deflect_index
    def stimulus_amplitude(self):
        """Return the stimulus amplitude from the registered calculator, or np.nan if none is set."""
if self.stimulus_amplitude_calculator is not None:
return self.stimulus_amplitude_calculator(self)
else:
return np.nan
    def estimate_time_constant(self):
"""Calculate the membrane time constant by fitting the voltage response with a
single exponential.
Returns
-------
tau : membrane time constant in seconds
"""
# Assumes this is being done on a hyperpolarizing step
v_peak, peak_index = self.voltage_deflection("min")
v_baseline = self.sweep_feature("v_baseline")
if self.start:
start_index = ft.find_time_index(self.t, self.start)
else:
start_index = 0
frac = 0.1
search_result = np.flatnonzero(self.v[start_index:] <= frac * (v_peak - v_baseline) + v_baseline)
if not search_result.size:
raise ft.FeatureError("could not find interval for time constant estimate")
fit_start = self.t[search_result[0] + start_index]
fit_end = self.t[peak_index]
a, inv_tau, y0 = ft.fit_membrane_time_constant(self.v, self.t, fit_start, fit_end)
return 1. / inv_tau
    def estimate_sag(self, peak_width=0.005):
"""Calculate the sag in a hyperpolarizing voltage response.
Parameters
----------
peak_width : window width to get more robust peak estimate in sec (default 0.005)
Returns
-------
sag : fraction that membrane potential relaxes back to baseline
"""
t = self.t
v = self.v
start = self.start
if not start:
start = 0
end = self.end
if not end:
end = self.t[-1]
v_peak, peak_index = self.voltage_deflection("min")
v_peak_avg = ft.average_voltage(v, t, start=t[peak_index] - peak_width / 2.,
end=t[peak_index] + peak_width / 2.)
v_baseline = self.sweep_feature("v_baseline")
v_steady = ft.average_voltage(v, t, start=end - self.baseline_interval, end=end)
sag = (v_peak_avg - v_steady) / (v_peak_avg - v_baseline)
return sag
    def spikes(self):
"""Get all features for each spike as a list of records."""
return self._spikes_df.to_dict('records')
    def spike_feature(self, key, include_clipped=False, force_exclude_clipped=False):
"""Get specified feature for every spike.
Parameters
----------
key : feature name
        include_clipped : return values for every identified spike, even when clipping means they will be incorrect/undefined
        force_exclude_clipped : drop values from clipped spikes for every feature, not just the features affected by clipping
Returns
-------
spike_feature_values : ndarray of features for each spike
"""
if not hasattr(self, "_spikes_df"):
raise AttributeError("EphysSweepFeatureExtractor instance attribute with spike information does not exist yet - have spikes been processed?")
if len(self._spikes_df) == 0:
return np.array([])
if key not in self._spikes_df.columns:
raise KeyError("requested feature '{:s}' not available".format(key))
values = self._spikes_df[key].values
if include_clipped and force_exclude_clipped:
raise ValueError("include_clipped and force_exclude_clipped cannot both be true")
if not include_clipped and self.is_spike_feature_affected_by_clipping(key):
values = values[~self._spikes_df["clipped"].values]
elif force_exclude_clipped:
values = values[~self._spikes_df["clipped"].values]
return values
    def is_spike_feature_affected_by_clipping(self, key):
return key in self._affected_by_clipping
    def spike_feature_keys(self):
"""Get list of every available spike feature."""
return self._spikes_df.columns.values.tolist()
    def sweep_feature(self, key, allow_missing=False):
"""Get sweep-level feature (`key`).
Parameters
----------
key : name of sweep-level feature
allow_missing : return np.nan if key is missing for sweep (default False)
Returns
-------
sweep_feature : sweep-level feature value
"""
on_request_dispatch = {
"v_baseline": self._get_baseline_voltage,
"tau": self.estimate_time_constant,
"sag": self.estimate_sag,
"peak_deflect": self.voltage_deflection,
"stim_amp": self.stimulus_amplitude,
}
if allow_missing and key not in self._sweep_features and key not in on_request_dispatch:
return np.nan
elif key not in self._sweep_features and key not in on_request_dispatch:
raise KeyError("requested feature '{:s}' not available".format(key))
if key not in self._sweep_features and key in on_request_dispatch:
fn = on_request_dispatch[key]
if fn is not None:
self._sweep_features[key] = fn()
else:
raise KeyError("requested feature '{:s}' not defined".format(key))
return self._sweep_features[key]
    def process_new_spike_feature(self, feature_name, feature_func, affected_by_clipping=False):
"""Add new spike-level feature calculation function
The function should take this sweep extractor as its argument. Its results
can be accessed by calling the method spike_feature(<feature_name>).
"""
if feature_name in self._spikes_df.columns:
raise KeyError("Feature {:s} already exists for sweep".format(feature_name))
self._spikes_df[feature_name] = feature_func(self)
if affected_by_clipping:
self._affected_by_clipping.append(feature_name)
    def process_new_sweep_feature(self, feature_name, feature_func):
"""Add new sweep-level feature calculation function
The function should take this sweep extractor as its argument. Its results
can be accessed by calling the method sweep_feature(<feature_name>).
"""
if feature_name in self._sweep_features:
raise KeyError("Feature {:s} already exists for sweep".format(feature_name))
self._sweep_features[feature_name] = feature_func(self)
    def set_stimulus_amplitude_calculator(self, function):
self.stimulus_amplitude_calculator = function
    def sweep_feature_keys(self):
        """Get list of every available sweep-level feature."""
        return list(self._sweep_features.keys())
    def as_dict(self):
"""Create dict of features and spikes."""
output_dict = self._sweep_features.copy()
output_dict["spikes"] = self.spikes()
if self.id is not None:
output_dict["id"] = self.id
return output_dict
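# Usage sketch (an illustrative addition, not part of the original module):
# a minimal single-sweep workflow, assuming `t`, `v`, and `i` are 1-D numpy
# arrays (seconds, mV, pA) for a sweep whose step runs from 1.0 s to 2.0 s.
def _example_single_sweep_usage(t, v, i):
    sweep = EphysSweepFeatureExtractor(t=t, v=v, i=i, start=1.0, end=2.0)
    sweep.process_spikes()  # builds the per-spike DataFrame
    spike_times = sweep.spike_feature("threshold_t")  # one value per spike
    avg_rate = sweep.sweep_feature("avg_rate")  # spikes / s over the window
    # registering a custom sweep-level feature (hypothetical feature name):
    sweep.process_new_sweep_feature("v_range", lambda swp: swp.v.max() - swp.v.min())
    return spike_times, avg_rate, sweep.sweep_feature("v_range")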
class EphysSweepSetFeatureExtractor:
def __init__(self, t_set=None, v_set=None, i_set=None, start=None, end=None,
filter=10., dv_cutoff=20., max_interval=0.005, min_height=2.,
min_peak=-30., thresh_frac=0.05, baseline_interval=0.1,
baseline_detect_thresh=0.3, id_set=None):
"""Initialize EphysSweepSetFeatureExtractor object.
Parameters
----------
t_set : list of ndarray of times in seconds
v_set : list of ndarray of voltages in mV
i_set : list of ndarray of currents in pA
start : start of time window for feature analysis (optional, can be list)
end : end of time window for feature analysis (optional, can be list)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
        baseline_interval : interval length for baseline voltage calculation in seconds (before start if start is defined; optional, default 0.1)
baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3)
"""
if t_set is not None and v_set is not None:
self._set_sweeps(t_set, v_set, i_set, start, end, filter, dv_cutoff, max_interval,
min_height, min_peak, thresh_frac, baseline_interval,
baseline_detect_thresh, id_set)
else:
self._sweeps = None
    @classmethod
def from_sweeps(cls, sweep_list):
"""Initialize EphysSweepSetFeatureExtractor object with a list of pre-existing
sweep feature extractor objects.
"""
obj = cls()
obj._sweeps = sweep_list
return obj
def _set_sweeps(self, t_set, v_set, i_set, start, end, filter, dv_cutoff, max_interval,
min_height, min_peak, thresh_frac, baseline_interval,
baseline_detect_thresh, id_set):
if type(t_set) != list:
raise ValueError("t_set must be a list")
if type(v_set) != list:
raise ValueError("v_set must be a list")
if i_set is not None and type(i_set) != list:
raise ValueError("i_set must be a list")
if len(t_set) != len(v_set):
raise ValueError("t_set and v_set must have the same number of items")
if i_set and len(t_set) != len(i_set):
raise ValueError("t_set and i_set must have the same number of items")
if id_set is None:
id_set = range(len(t_set))
if len(id_set) != len(t_set):
raise ValueError("t_set and id_set must have the same number of items")
        if i_set is None:
            i_set = [None] * len(t_set)
        if type(start) is not list:
            start = [start] * len(t_set)
            end = [end] * len(t_set)
        # loop variables renamed to s, e to avoid shadowing the start/end lists
        sweeps = [EphysSweepFeatureExtractor(t, v, i, s, e,
                                             filter=filter, dv_cutoff=dv_cutoff,
                                             max_interval=max_interval,
                                             min_height=min_height, min_peak=min_peak,
                                             thresh_frac=thresh_frac,
                                             baseline_interval=baseline_interval,
                                             baseline_detect_thresh=baseline_detect_thresh,
                                             id=sid)
                  for t, v, i, s, e, sid in zip(t_set, v_set, i_set, start, end, id_set)]
self._sweeps = sweeps
    def sweeps(self):
        """Get list of EphysSweepFeatureExtractor objects."""
        return self._sweeps
    def process_spikes(self):
"""Analyze spike features for all sweeps."""
for sweep in self._sweeps:
sweep.process_spikes()
    def sweep_features(self, key, allow_missing=False):
        """Get ndarray of sweep-level feature (`key`) for all sweeps
Parameters
----------
key : name of sweep-level feature
allow_missing : return np.nan if key is missing for sweep (default False)
Returns
-------
        sweep_feature : ndarray of sweep-level feature values
"""
return np.array([swp.sweep_feature(key, allow_missing) for swp in self._sweeps])
    def spike_feature_averages(self, key):
        """Get ndarray of average spike-level feature (`key`) for all sweeps"""
return np.array([swp.spike_feature(key).mean() for swp in self._sweeps])
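# Usage sketch (illustrative, not part of the original module): extracting
# features across several sweeps at once, assuming `t_list`, `v_list`, and
# `i_list` are equal-length lists of 1-D numpy arrays.
def _example_sweep_set_usage(t_list, v_list, i_list):
    sweep_set = EphysSweepSetFeatureExtractor(t_list, v_list, i_list,
                                              start=1.0, end=2.0)
    sweep_set.process_spikes()
    rates = sweep_set.sweep_features("avg_rate")  # ndarray, one value per sweep
    mean_widths = sweep_set.spike_feature_averages("width")  # per-sweep means
    return rates, mean_widths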
class EphysCellFeatureExtractor:
# Class constants for specific processing
SUBTHRESH_MAX_AMP = 0
SAG_TARGET = -100.
def __init__(self, ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp=-100):
"""Initialize EphysCellFeatureExtractor object from EphysSweepSetExtractors for
ramp, short square, and long square sweeps.
Parameters
----------
ramps_ext : EphysSweepSetFeatureExtractor prepared with ramp sweeps
short_squares_ext : EphysSweepSetFeatureExtractor prepared with short square sweeps
long_squares_ext : EphysSweepSetFeatureExtractor prepared with long square sweeps
"""
self._ramps_ext = ramps_ext
self._short_squares_ext = short_squares_ext
self._long_squares_ext = long_squares_ext
self._subthresh_min_amp = subthresh_min_amp
self._features = {
"ramps": {},
"short_squares": {},
"long_squares": {},
}
self._spiking_long_squares_ext = None
self._subthreshold_long_squares_ext = None
self._subthreshold_membrane_property_ext = None
    def process(self, keys=None):
        """Process cell features. Can take a specific key (or list of keys) to run a subset of the analyses."""
dispatch = {
"ramps": self._analyze_ramps,
"short_squares": self._analyze_short_squares,
"long_squares": self._analyze_long_squares,
"long_squares_spiking": self._analyze_long_squares_spiking,
}
if keys is None:
keys = list(dispatch.keys())
        if type(keys) is not list:
            keys = [keys]
for k in [j for j in keys if j in dispatch]:
dispatch[k]()
def _analyze_ramps(self):
ext = self._ramps_ext
ext.process_spikes()
self._all_ramps_ext = ext
# pull out the spiking sweeps
spiking_sweeps = [ sweep for sweep in self._ramps_ext.sweeps() if sweep.sweep_feature("avg_rate") > 0 ]
ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps)
self._ramps_ext = ext
self._features["ramps"]["spiking_sweeps"] = ext.sweeps()
    def ramps_features(self, all=False):
        """Get the ramp sweep-set extractor (all ramps if `all` is True, otherwise spiking ramps only)."""
if all:
return self._all_ramps_ext
else:
return self._ramps_ext
def _analyze_short_squares(self):
ext = self._short_squares_ext
ext.process_spikes()
# Need to count how many had spikes at each amplitude; find most; ties go to lower amplitude
spiking_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0]
if len(spiking_sweeps) == 0:
raise ft.FeatureError("No spiking short square sweeps, cannot compute cell features.")
most_common = Counter(map(_short_step_stim_amp, spiking_sweeps)).most_common()
common_amp, common_count = most_common[0]
for c in most_common[1:]:
if c[1] < common_count:
break
if c[0] < common_amp:
common_amp = c[0]
self._features["short_squares"]["stimulus_amplitude"] = common_amp
ext = EphysSweepSetFeatureExtractor.from_sweeps([sweep for sweep in spiking_sweeps if _short_step_stim_amp(sweep) == common_amp])
self._short_squares_ext = ext
self._features["short_squares"]["common_amp_sweeps"] = ext.sweeps()
for s in self._features["short_squares"]["common_amp_sweeps"]:
s.set_stimulus_amplitude_calculator(_short_step_stim_amp)
def _analyze_long_squares(self):
self._analyze_long_squares_spiking()
self._analyze_long_squares_subthreshold()
def _analyze_long_squares_spiking(self, force_reprocess=False):
if not force_reprocess and self._spiking_long_squares_ext:
return
ext = self._long_squares_ext
ext.process_spikes()
self._features["long_squares"]["sweeps"] = ext.sweeps()
for s in self._features["long_squares"]["sweeps"]:
s.set_stimulus_amplitude_calculator(_step_stim_amp)
spiking_indexes = np.flatnonzero(ext.sweep_features("avg_rate"))
if len(spiking_indexes) == 0:
raise ft.FeatureError("No spiking long square sweeps, cannot compute cell features.")
amps = ext.sweep_features("stim_amp")#self.long_squares_stim_amps()
min_index = np.argmin(amps[spiking_indexes])
rheobase_index = spiking_indexes[min_index]
rheobase_i = _step_stim_amp(ext.sweeps()[rheobase_index])
self._features["long_squares"]["rheobase_extractor_index"] = rheobase_index
self._features["long_squares"]["rheobase_i"] = rheobase_i
self._features["long_squares"]["rheobase_sweep"] = ext.sweeps()[rheobase_index]
spiking_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0]
self._spiking_long_squares_ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps)
self._features["long_squares"]["spiking_sweeps"] = self._spiking_long_squares_ext.sweeps()
self._features["long_squares"]["fi_fit_slope"] = fit_fi_slope(self._spiking_long_squares_ext)
def _analyze_long_squares_subthreshold(self):
ext = self._long_squares_ext
subthresh_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") == 0]
subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(subthresh_sweeps)
self._subthreshold_long_squares_ext = subthresh_ext
if len(subthresh_ext.sweeps()) == 0:
raise ft.FeatureError("No subthreshold long square sweeps, cannot evaluate cell features.")
peaks = subthresh_ext.sweep_features("peak_deflect")
sags = subthresh_ext.sweep_features("sag")
sag_eval_levels = np.array([sweep.voltage_deflection()[0] for sweep in subthresh_ext.sweeps()])
target_level = self.SAG_TARGET
closest_index = np.argmin(np.abs(sag_eval_levels - target_level))
self._features["long_squares"]["sag"] = sags[closest_index]
self._features["long_squares"]["vm_for_sag"] = sag_eval_levels[closest_index]
self._features["long_squares"]["subthreshold_sweeps"] = subthresh_ext.sweeps()
for s in self._features["long_squares"]["subthreshold_sweeps"]:
s.set_stimulus_amplitude_calculator(_step_stim_amp)
logging.debug("subthresh_sweeps: %d", len(subthresh_sweeps))
calc_subthresh_sweeps = [sweep for sweep in subthresh_sweeps if
sweep.sweep_feature("stim_amp") < self.SUBTHRESH_MAX_AMP and
sweep.sweep_feature("stim_amp") > self._subthresh_min_amp]
logging.debug("calc_subthresh_sweeps: %d", len(calc_subthresh_sweeps))
calc_subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(calc_subthresh_sweeps)
self._subthreshold_membrane_property_ext = calc_subthresh_ext
self._features["long_squares"]["subthreshold_membrane_property_sweeps"] = calc_subthresh_ext.sweeps()
self._features["long_squares"]["input_resistance"] = input_resistance(calc_subthresh_ext)
self._features["long_squares"]["tau"] = membrane_time_constant(calc_subthresh_ext)
self._features["long_squares"]["v_baseline"] = np.nanmean(ext.sweep_features("v_baseline"))
    def long_squares_features(self, option=None):
option_table = {
"spiking": self._spiking_long_squares_ext,
"subthreshold": self._subthreshold_long_squares_ext,
"subthreshold_membrane_property": self._subthreshold_membrane_property_ext,
}
if option:
return option_table[option]
return self._long_squares_ext
    def long_squares_stim_amps(self, option=None):
option_table = {
"spiking": self._spiking_long_squares_ext,
"subthreshold": self._subthreshold_long_squares_ext,
"subthreshold_membrane_property": self._subthreshold_membrane_property_ext,
}
if option:
ext = option_table[option]
else:
ext = self._long_squares_ext
        return np.array(list(map(_step_stim_amp, ext.sweeps())))
    def as_dict(self):
"""Create dict of cell features."""
# get shallow copies of the sub-type dictionaries
out = {
"long_squares": self._features["long_squares"].copy(),
"short_squares": self._features["short_squares"].copy(),
"ramps": self._features["ramps"].copy(),
}
        # convert the lists of sweep extractors into lists of per-sweep feature dictionaries
ls_sweeps = [ s.as_dict() for s in out["long_squares"]["sweeps"] ]
ls_spike_sweeps = [ s.as_dict() for s in out["long_squares"]["spiking_sweeps"] ]
rheo_sweep = out["long_squares"]["rheobase_sweep"].as_dict()
ls_sub_sweeps = [ s.as_dict() for s in out["long_squares"]["subthreshold_sweeps"] ]
ls_sub_mem_sweeps = [ s.as_dict() for s in out["long_squares"]["subthreshold_membrane_property_sweeps"] ]
ss_sweeps = [ s.as_dict() for s in out["short_squares"]["common_amp_sweeps"] ]
ramp_sweeps = [ s.as_dict() for s in out["ramps"]["spiking_sweeps"] ]
out["long_squares"]["sweeps"] = ls_sweeps
out["long_squares"]["spiking_sweeps"] = ls_spike_sweeps
out["long_squares"]["subthreshold_sweeps"] = ls_sub_sweeps
out["long_squares"]["subthreshold_membrane_property_sweeps"] = ls_sub_mem_sweeps
out["long_squares"]["rheobase_sweep"] = rheo_sweep
out["short_squares"]["common_amp_sweeps"] = ss_sweeps
out["ramps"]["spiking_sweeps"] = ramp_sweeps
return out
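# Usage sketch (illustrative, not part of the original module): the cell-level
# workflow, assuming the three arguments are EphysSweepSetFeatureExtractor
# objects prepared with ramp, short square, and long square sweeps, as
# cell_extractor_for_nwb() below does for NWB data.
def _example_cell_feature_usage(ramps_ext, short_squares_ext, long_squares_ext):
    cell_ext = EphysCellFeatureExtractor(ramps_ext, short_squares_ext, long_squares_ext)
    cell_ext.process()  # run ramp, short square, and long square analyses
    features = cell_ext.as_dict()
    return features["long_squares"]["rheobase_i"], features["long_squares"]["tau"]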
def input_resistance(ext):
"""Estimate input resistance in MOhms, assuming all sweeps in passed extractor
are hyperpolarizing responses."""
sweeps = ext.sweeps()
if not sweeps:
raise ft.FeatureError("no sweeps available for input resistance calculation")
v_vals = []
i_vals = []
for sweep in sweeps:
if sweep.i is None:
raise ft.FeatureError("cannot calculate input resistance: i not defined for a sweep")
v_peak, min_index = sweep.voltage_deflection('min')
v_vals.append(v_peak)
i_vals.append(sweep.i[min_index])
v = np.array(v_vals)
i = np.array(i_vals)
if len(v) == 1:
# If there's just one sweep, we'll have to use its own baseline to estimate
# the input resistance
v = np.append(v, sweeps[0].sweep_feature("v_baseline"))
i = np.append(i, 0.)
A = np.vstack([i, np.ones_like(i)]).T
    m, c = np.linalg.lstsq(A, v, rcond=None)[0]
return m * 1e3
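# Worked sketch of the fit above (illustrative): input_resistance() solves
# v = m * i + c by least squares; with i in pA and v in mV the slope m is in
# GOhm, so m * 1e3 converts to MOhm. Synthetic data with a known 100 MOhm
# (0.1 mV/pA) resistance:
def _example_input_resistance_fit():
    i = np.array([-110., -90., -70., -50.])  # injected currents in pA
    v = -75. + 0.1 * i                       # responses in mV; slope 0.1 mV/pA
    A = np.vstack([i, np.ones_like(i)]).T
    m, c = np.linalg.lstsq(A, v, rcond=None)[0]
    return m * 1e3                           # ~100 MOhm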
def membrane_time_constant(ext):
"""Average the membrane time constant values estimated from each sweep in passed extractor."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
avg_tau = np.nanmean(ext.sweep_features("tau"))
return avg_tau
def fit_fi_slope(ext):
"""Fit the rate and stimulus amplitude to a line and return the slope of the fit."""
if len(ext.sweeps()) < 2:
raise ft.FeatureError("Cannot fit f-I curve slope with less than two suprathreshold sweeps")
x = np.array(list(map(_step_stim_amp, ext.sweeps())))
y = ext.sweep_features("avg_rate")
A = np.vstack([x, np.ones_like(x)]).T
    m, c = np.linalg.lstsq(A, y, rcond=None)[0]
return m
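# Worked sketch of the f-I fit above (illustrative): fit_fi_slope() regresses
# average firing rate (spikes/s) on stimulus amplitude (pA), so the returned
# slope is in Hz/pA. Synthetic data with a known slope of 0.2 Hz/pA:
def _example_fi_slope_fit():
    x = np.array([50., 100., 150., 200.])  # stimulus amplitudes in pA
    y = 0.2 * (x - 50.)                    # firing rates in Hz; rheobase 50 pA
    A = np.vstack([x, np.ones_like(x)]).T
    m, c = np.linalg.lstsq(A, y, rcond=None)[0]
    return m                               # ~0.2 Hz/pA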
def reset_long_squares_start(when):
    """Shift the long square analysis window to start at `when` (in seconds), preserving its duration."""
    global LONG_SQUARES_START, LONG_SQUARES_END
delta = LONG_SQUARES_END - LONG_SQUARES_START
LONG_SQUARES_START = when
LONG_SQUARES_END = when + delta
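# For example, a recording whose long square steps begin at t = 0.27 s (a
# hypothetical offset) would shift the module-level analysis window with:
#     reset_long_squares_start(0.27)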
def cell_extractor_for_nwb(dataset, ramps, short_squares, long_squares, subthresh_min_amp=-100):
"""Initialize EphysCellFeatureExtractor object from NWB data set
Parameters
----------
dataset : NwbDataSet
ramps : list of sweep numbers of ramp sweeps
short_squares : list of sweep numbers of short square sweeps
    long_squares : list of sweep numbers of long square sweeps
    subthresh_min_amp : minimum stimulus amplitude (pA) of subthreshold sweeps used for membrane property estimates (optional, default -100)
    """
if len(short_squares) == 0:
raise ft.FeatureError("no short square sweep numbers provided")
if len(ramps) == 0:
raise ft.FeatureError("no ramp sweep numbers provided")
if len(long_squares) == 0:
raise ft.FeatureError("no long_square sweep numbers provided")
ramps_ext = extractor_for_nwb_sweeps(dataset, ramps, fixed_start=RAMPS_START)
temp_short_sq_ext = extractor_for_nwb_sweeps(dataset, short_squares)
t_set = [s.t for s in temp_short_sq_ext.sweeps()]
v_set = [s.v for s in temp_short_sq_ext.sweeps()]
cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters(v_set, t_set,
SHORT_SQUARES_WINDOW_START,
SHORT_SQUARES_WINDOW_END)
thresh_frac = max(thresh_frac, 0.1)
short_squares_ext = extractor_for_nwb_sweeps(dataset, short_squares,
dv_cutoff=cutoff, thresh_frac=thresh_frac)
long_squares_ext = extractor_for_nwb_sweeps(dataset, long_squares,
fixed_start=LONG_SQUARES_START,
fixed_end=LONG_SQUARES_END)
return EphysCellFeatureExtractor(ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp)
def extractor_for_nwb_sweeps(dataset, sweep_numbers,
fixed_start=None, fixed_end=None,
dv_cutoff=20., thresh_frac=0.05):
v_set = []
t_set = []
i_set = []
start = []
end = []
for sweep_number in sweep_numbers:
data = dataset.get_sweep(sweep_number)
v = data['response'] * 1e3 # mV
i = data['stimulus'] * 1e12 # pA
hz = data['sampling_rate']
dt = 1. / hz
t = np.arange(0, len(v)) * dt # sec
s, e = dt * np.array(data['index_range'])
v_set.append(v)
i_set.append(i)
t_set.append(t)
start.append(s)
end.append(e)
if fixed_start and not fixed_end:
start = [fixed_start] * len(end)
elif fixed_start and fixed_end:
start = fixed_start
end = fixed_end
return EphysSweepSetFeatureExtractor(t_set, v_set, i_set, start=start, end=end,
dv_cutoff=dv_cutoff, thresh_frac=thresh_frac,
id_set=sweep_numbers)
def _step_stim_amp(sweep):
t_index = ft.find_time_index(sweep.t, sweep.start)
return sweep.i[t_index + 1]
def _short_step_stim_amp(sweep):
t_index = ft.find_time_index(sweep.t, sweep.start)
return sweep.i[t_index + 1:].max()
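# End-to-end sketch (illustrative, not part of the original module), assuming
# `dataset` is an allensdk NwbDataSet; the sweep numbers here are hypothetical
# placeholders that the caller would choose from the experiment's stimulus log.
def _example_nwb_workflow(dataset):
    cell_ext = cell_extractor_for_nwb(dataset,
                                      ramps=[10, 11],
                                      short_squares=[20, 21, 22],
                                      long_squares=[30, 31, 32])
    cell_ext.process()
    return cell_ext.as_dict()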