# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2017. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import logging
import numpy as np
import simplejson as json
import allensdk.core.json_utilities as ju
import copy
try:
from glif_neuron_methods import GlifNeuronMethod, METHOD_LIBRARY
except:
from .glif_neuron_methods import GlifNeuronMethod, METHOD_LIBRARY
class GlifBadResetException(Exception):
    """ Exception raised when voltage is still above threshold after a reset rule is applied.

    Parameters
    ----------
    message : str
        human-readable description of the bad reset
    dv : float
        stored as ``self.dv``; presumably the amount by which the reset voltage exceeds
        the reset threshold (TODO confirm against callers)
    """

    def __init__(self, message, dv):
        # Name this class (not Exception) in super() so the MRO is walked starting after
        # GlifBadResetException; super(Exception, self) skipped Exception.__init__ entirely.
        super(GlifBadResetException, self).__init__(message)
        self.dv = dv
class GlifNeuron(object):
    """ Implements the current-based Mihalas-Niebur GLIF neuron. Simulations model the voltage,
    threshold, and afterspike currents of a neuron given an input stimulus. A set of modular dynamics
    rules are applied until voltage crosses threshold, at which point a set of modular reset rules are
    applied. See glif_neuron_methods.py for a list of what options there are for voltage, threshold, and
    afterspike current dynamics and reset rules.

    Parameters
    ----------
    El : float
        resting potential
    dt : float
        duration between time steps
    asc_tau_array : np.ndarray
        afterspike current time constants (presumably), one element per afterspike current.
        Their reciprocals are stored as ``self.k``.
    R_input : float
        input resistance. Its reciprocal is stored as the conductance ``self.G``.
    C : float
        capacitance
    asc_amp_array : np.ndarray
        afterspike current vector. one element per element of asc_tau_array.
    spike_cut_length : int
        how many time steps to replace with NaNs when a spike occurs.
    th_inf : float
        instantaneous threshold
    th_adapt : float
        presumably a threshold adaptation parameter; this class only stores and serializes it
        (TODO confirm how the threshold methods use it).
    coeffs : dict
        dictionary coefficients premultiplied to neuron properties during simulation. used for
        optimization. Keys that are not supplied default to 1 ('asc_amp_array' defaults to a
        vector of ones). None or an empty dict keeps all defaults.
    AScurrent_dynamics_method : dict
        dictionary containing the 'name' of the afterspike current dynamics method to use and a 'params' dictionary parameters to pass to that function.
    voltage_dynamics_method : dict
        dictionary containing the 'name' of the voltage dynamics method to use and a 'params' dictionary parameters to pass to that function.
    threshold_dynamics_method : dict
        dictionary containing the 'name' of the threshold dynamics method to use and a 'params' dictionary parameters to pass to that function.
    AScurrent_reset_method : dict
        dictionary containing the 'name' of the afterspike current reset method to use and a 'params' dictionary parameters to pass to that function.
    voltage_reset_method : dict
        dictionary containing the 'name' of the voltage reset method to use and a 'params' dictionary parameters to pass to that function.
    threshold_reset_method : dict
        dictionary containing the 'name' of the threshold reset method to use and a 'params' dictionary parameters to pass to that function.
    init_voltage : float
        initial voltage value
    init_threshold : float
        initial spike threshold value
    init_AScurrents : np.ndarray
        initial afterspike current vector. one element per element of asc_tau_array.
    **kwargs
        extra keyword arguments are accepted and ignored.
    """

    TYPE = "GLIF"

    def __init__(self, El, dt, asc_tau_array, R_input, C, asc_amp_array, spike_cut_length, th_inf, th_adapt, coeffs,
                 AScurrent_dynamics_method, voltage_dynamics_method, threshold_dynamics_method,
                 AScurrent_reset_method, voltage_reset_method, threshold_reset_method,
                 init_voltage, init_threshold, init_AScurrents, **kwargs):
        """ Initialize the neuron. See the class docstring for parameter descriptions. """
        self.type = GlifNeuron.TYPE
        self.El = El
        self.dt = dt
        self.asc_tau_array = np.array(asc_tau_array)
        self.R_input = R_input
        self.C = C
        self.asc_amp_array = np.array(asc_amp_array)
        self.spike_cut_length = int(spike_cut_length)
        self.th_inf = th_inf
        self.th_adapt = th_adapt
        self.threshold_components = None

        self.init_voltage = init_voltage
        self.init_threshold = init_threshold
        self.init_AScurrents = init_AScurrents

        # Shape sanity checks (assert, so they are skipped under `python -O`). The messages report
        # lengths rather than the arrays themselves: '%d' cannot format an array, and doing so
        # turned the intended AssertionError into a confusing TypeError.
        assert len(asc_tau_array) == len(asc_amp_array), \
            "After-spike current vector must have same length as asc_tau_array (%d vs %d)" % (len(asc_amp_array), len(asc_tau_array))
        assert len(self.init_AScurrents) == len(self.asc_tau_array), \
            "init_AScurrents length (%d) must have same length as asc_tau_array (%d)" % (len(self.init_AScurrents), len(self.asc_tau_array))

        # values computed based on inputs
        self.k = 1.0 / self.asc_tau_array
        self.G = 1.0 / self.R_input

        # Values that can be fit: they scale the input values. A default of 1 leaves the
        # corresponding property unscaled until an optimizer supplies a value via `coeffs`.
        self.coeffs = {
            'th_inf': 1,
            'C': 1,
            'G': 1,
            'b': 1,
            'a': 1,
            'asc_amp_array': np.ones(len(self.asc_tau_array))
        }
        if coeffs:
            self.coeffs.update(coeffs)

        logging.debug('spike cut length: %d', self.spike_cut_length)

        # initialize dynamics methods
        self.AScurrent_dynamics_method = self.configure_library_method('AScurrent_dynamics_method', AScurrent_dynamics_method)
        self.voltage_dynamics_method = self.configure_library_method('voltage_dynamics_method', voltage_dynamics_method)
        self.threshold_dynamics_method = self.configure_library_method('threshold_dynamics_method', threshold_dynamics_method)

        # initialize reset methods
        self.AScurrent_reset_method = self.configure_library_method('AScurrent_reset_method', AScurrent_reset_method)
        self.voltage_reset_method = self.configure_library_method('voltage_reset_method', voltage_reset_method)
        self.threshold_reset_method = self.configure_library_method('threshold_reset_method', threshold_reset_method)

    def __str__(self):
        """ Return the neuron's ``to_dict()`` representation as indented JSON. Non-JSON values are
        handled by ``ju.json_handler`` (presumably numpy arrays and scalars). """
        return json.dumps(self.to_dict(), default=ju.json_handler, indent=2)

    @property
    def tau_m(self):
        """ Membrane time constant, computed as R_input * C. """
        return self.R_input * self.C

    @classmethod
    def from_dict(cls, d):
        """ Construct a neuron from a dictionary such as the one produced by ``to_dict``.

        'coeffs' is optional (defaults to an empty dict); keys this method does not read,
        such as 'type' and 'El_reference', are ignored.
        """
        return cls(El=d['El'],
                   dt=d['dt'],
                   asc_tau_array=d['asc_tau_array'],
                   R_input=d['R_input'],
                   C=d['C'],
                   asc_amp_array=d['asc_amp_array'],
                   spike_cut_length=d['spike_cut_length'],
                   th_inf=d['th_inf'],
                   th_adapt=d['th_adapt'],
                   coeffs=d.get('coeffs', {}),
                   AScurrent_dynamics_method=d['AScurrent_dynamics_method'],
                   voltage_dynamics_method=d['voltage_dynamics_method'],
                   threshold_dynamics_method=d['threshold_dynamics_method'],
                   voltage_reset_method=d['voltage_reset_method'],
                   AScurrent_reset_method=d['AScurrent_reset_method'],
                   threshold_reset_method=d['threshold_reset_method'],
                   init_voltage=d['init_voltage'],
                   init_threshold=d['init_threshold'],
                   init_AScurrents=d['init_AScurrents'])

    def to_dict(self):
        """ Convert the neuron to a serializable dictionary.

        Arrays, coefficient dicts, and configured methods are deep-copied so that mutating the
        result does not affect the neuron. 'El_reference' duplicates 'El'.
        """
        return {
            'type': self.type,
            'El': self.El,
            'dt': self.dt,
            'asc_tau_array': copy.deepcopy(self.asc_tau_array),
            'R_input': self.R_input,
            'C': self.C,
            'asc_amp_array': copy.deepcopy(self.asc_amp_array),
            'spike_cut_length': self.spike_cut_length,
            'th_inf': self.th_inf,
            'th_adapt': self.th_adapt,
            'coeffs': copy.deepcopy(self.coeffs),
            'AScurrent_dynamics_method': copy.deepcopy(self.AScurrent_dynamics_method),
            'voltage_dynamics_method': copy.deepcopy(self.voltage_dynamics_method),
            'threshold_dynamics_method': copy.deepcopy(self.threshold_dynamics_method),
            'AScurrent_reset_method': copy.deepcopy(self.AScurrent_reset_method),
            'voltage_reset_method': copy.deepcopy(self.voltage_reset_method),
            'threshold_reset_method': copy.deepcopy(self.threshold_reset_method),
            'init_voltage': self.init_voltage,
            'init_threshold': self.init_threshold,
            'init_AScurrents': copy.deepcopy(self.init_AScurrents),
            'El_reference': self.El
        }

    def dynamics(self, voltage_t0, threshold_t0, AScurrents_t0, inj, time_step, spike_time_steps):
        """ Update the voltage, threshold, and afterspike currents of the neuron for a single time step.

        Parameters
        ----------
        voltage_t0 : float
            the current voltage of the neuron
        threshold_t0 : float
            the current spike threshold level of the neuron
        AScurrents_t0 : np.ndarray
            the current state of the afterspike currents in the neuron
        inj : float
            the current value of the current injection into the neuron
        time_step : int
            the current time step of the neuron simulation
        spike_time_steps : list
            a list of all of the time steps of spikes in the neuron

        Returns
        -------
        tuple
            voltage_t1 (voltage at next time step), threshold_t1 (threshold at next time step), AScurrents_t1 (afterspike currents at next time step)
        """
        AScurrents_t1 = self.AScurrent_dynamics_method(self, AScurrents_t0, time_step, spike_time_steps)
        # Voltage and threshold are advanced from the start-of-step state (AScurrents_t0 and
        # voltage_t0), not from the freshly updated afterspike currents.
        voltage_t1 = self.voltage_dynamics_method(self, voltage_t0, AScurrents_t0, inj)
        threshold_t1 = self.threshold_dynamics_method(self, threshold_t0, voltage_t0, AScurrents_t0, inj)

        return voltage_t1, threshold_t1, AScurrents_t1

    def reset(self, voltage_t0, threshold_t0, AScurrents_t0):
        """ Apply reset rules to the neuron's voltage, threshold, and afterspike currents assuming a spike has occurred (voltage is above threshold).

        Parameters
        ----------
        voltage_t0 : float
            the current voltage of the neuron
        threshold_t0 : float
            the current spike threshold level of the neuron
        AScurrents_t0 : np.ndarray
            the current state of the afterspike currents in the neuron

        Returns
        -------
        tuple
            voltage_t1 (voltage at next time step), threshold_t1 (threshold at next time step),
            AScurrents_t1 (afterspike currents at next time step), bad_reset_flag (True when the
            reset voltage is still above the reset threshold)
        """
        AScurrents_t1 = self.AScurrent_reset_method(self, AScurrents_t0)
        voltage_t1 = self.voltage_reset_method(self, voltage_t0)
        # the threshold reset sees the already-reset voltage
        threshold_t1 = self.threshold_reset_method(self, threshold_t0, voltage_t1)

        bad_reset_flag = False
        if voltage_t1 > threshold_t1:
            bad_reset_flag = True
            # TODO put this back in eventually but would rather debug right now
            # raise GlifBadResetException("Voltage reset above threshold: voltage_t1 (%f) threshold_t1 (%f), voltage_t0 (%f) threshold_t0 (%f) AScurrents_t0 (%s)" % ( voltage_t1, threshold_t1, voltage_t0, threshold_t0, repr(AScurrents_t0)), voltage_t1 - threshold_t1)

        return voltage_t1, threshold_t1, AScurrents_t1, bad_reset_flag

    def run(self, stim):
        """ Run neuron simulation over a given stimulus. This steps through the stimulus applying dynamics equations.
        After each step it checks if voltage is above threshold. If so, self.spike_cut_length NaNs are inserted
        into the output voltages, reset rules are applied to the voltage, threshold, and afterspike currents, and the
        simulation resumes. If a reset leaves voltage above threshold, the reset state is written into up to five
        following steps and the simulation stops early (remaining steps stay NaN).

        Parameters
        ----------
        stim : np.ndarray
            vector of scalar current values

        Returns
        -------
        dict
            a dictionary containing:
                'voltage': simulated voltage values,
                'threshold': threshold values during the simulation,
                'AScurrents': afterspike current values during the simulation,
                'grid_spike_times': spike times (time_step * self.dt, i.e. in the same time units as self.dt) aligned to simulation time steps,
                'interpolated_spike_times': spike times (same units as self.dt) linearly interpolated within the spiking time step,
                'spike_time_steps': the indices of grid spike times,
                'interpolated_spike_voltage': voltage of the simulation at interpolated spike times,
                'interpolated_spike_threshold': threshold of the simulation at interpolated spike times
        """
        bad_reset_flag = False

        # initialize the voltage, threshold, and afterspike current values
        voltage_t0 = self.init_voltage
        threshold_t0 = self.init_threshold
        AScurrents_t0 = self.init_AScurrents

        self.threshold_components = None  # get rid of lingering method data

        num_time_steps = len(stim)
        num_AScurrents = len(AScurrents_t0)

        # pre-allocate the outputs NaN-filled, so steps that are never written (spike cuts,
        # early termination after a bad reset) read as NaN
        voltage_out = np.full(num_time_steps, np.nan)
        threshold_out = np.full(num_time_steps, np.nan)
        AScurrents_out = np.full((num_time_steps, num_AScurrents), np.nan)

        # per-spike records
        spike_time_steps = []
        grid_spike_times = []
        interpolated_spike_times = []
        interpolated_spike_voltage = []
        interpolated_spike_threshold = []

        time_step = 0
        while time_step < num_time_steps:
            if time_step % 10000 == 0:
                logging.info("time step %d / %d", time_step, num_time_steps)

            # compute voltage, threshold, and ascurrents at current time step
            (voltage_t1, threshold_t1, AScurrents_t1) = self.dynamics(voltage_t0, threshold_t0, AScurrents_t0, stim[time_step], time_step, spike_time_steps)

            # if the voltage is bigger than the threshold record the spike and reset the values
            if voltage_t1 > threshold_t1:

                # spike_time_steps are stimulus indices when voltage surpassed threshold
                spike_time_steps.append(time_step)
                grid_spike_times.append(time_step * self.dt)

                # compute higher fidelity spike time/voltage/threshold by linearly interpolating
                spike_time = interpolate_spike_time(self.dt, time_step, threshold_t0, threshold_t1, voltage_t0, voltage_t1)
                interpolated_spike_times.append(spike_time)

                # Offset of the crossing from time_step * dt -- the same origin interpolate_spike_time
                # builds on -- so it is normally within [0, dt] and the values below are interpolated
                # between the *_t0 and *_t1 samples (measuring from (time_step - 1) * dt added an extra
                # dt and extrapolated past *_t1).
                interpolated_spike_time_offset = spike_time - time_step * self.dt
                interpolated_spike_voltage.append(interpolate_spike_value(self.dt, interpolated_spike_time_offset, voltage_t0, voltage_t1))
                interpolated_spike_threshold.append(interpolate_spike_value(self.dt, interpolated_spike_time_offset, threshold_t0, threshold_t1))

                # reset voltage, threshold, and afterspike currents; these become the state that
                # integration resumes from
                (voltage_t0, threshold_t0, AScurrents_t0, bad_reset_flag) = self.reset(voltage_t1, threshold_t1, AScurrents_t1)

                if self.spike_cut_length > 0:
                    # Not integrating during the spike (which includes this step): NaN-fill the cut,
                    # record the reset state on the step right after it, then jump past it.
                    # TODO: consider whether one more NaN should be inserted before the reset value
                    # is recorded, for consistency.
                    n = self.spike_cut_length
                    cut_past_end = (time_step + n) >= len(voltage_out)

                    if cut_past_end:
                        n = len(voltage_out) - time_step

                    voltage_out[time_step:time_step+n] = np.nan
                    threshold_out[time_step:time_step+n] = np.nan
                    AScurrents_out[time_step:time_step+n, :] = np.nan

                    if not cut_past_end:
                        voltage_out[time_step+n] = voltage_t0
                        threshold_out[time_step+n] = threshold_t0
                        AScurrents_out[time_step+n, :] = AScurrents_t0

                    time_step += self.spike_cut_length + 1
                else:
                    # no spike cut: the reset state is recorded on the spiking step itself
                    voltage_out[time_step] = voltage_t0
                    threshold_out[time_step] = threshold_t0
                    AScurrents_out[time_step, :] = AScurrents_t0

                    time_step += 1

                if bad_reset_flag:
                    # the reset left voltage above threshold: hold the reset state for up to five
                    # steps and stop simulating
                    voltage_out[time_step:time_step+5] = voltage_t0
                    threshold_out[time_step:time_step+5] = threshold_t0
                    AScurrents_out[time_step:time_step+5] = AScurrents_t0
                    break

            else:
                # there was no spike, store the next voltages
                voltage_out[time_step] = voltage_t1
                threshold_out[time_step] = threshold_t1
                AScurrents_out[time_step, :] = AScurrents_t1

                voltage_t0 = voltage_t1
                threshold_t0 = threshold_t1
                AScurrents_t0 = AScurrents_t1

                time_step += 1

        return {
            'voltage': voltage_out,
            'threshold': threshold_out,
            'AScurrents': AScurrents_out,
            'grid_spike_times': np.array(grid_spike_times),
            'interpolated_spike_times': np.array(interpolated_spike_times),
            'spike_time_steps': np.array(spike_time_steps),
            'interpolated_spike_voltage': np.array(interpolated_spike_voltage),
            'interpolated_spike_threshold': np.array(interpolated_spike_threshold)
        }

    # TODO: DEPRECATE
    # def get_threshold_components(self):
    #     if self.threshold_components is None:
    #         self.threshold_components = { 'spike': [0], 'voltage': [0] }
    #
    #     return self.threshold_components

    def append_threshold_components(self, spike, voltage):
        """ Append one value each to the 'spike' and 'voltage' threshold component lists.

        Requires ``self.threshold_components`` to already be a dict holding 'spike' and 'voltage'
        lists. ``run`` resets it to None, so whoever calls this (presumably a threshold dynamics
        method) must initialize it first -- TODO confirm.
        """
        self.threshold_components['spike'].append(spike)
        self.threshold_components['voltage'].append(voltage)

    # TODO: DEPRECATE
    # def reset_threshold_components(self):
    #     self.threshold_components = None
def interpolate_spike_time(dt, time_step, threshold_t0, threshold_t1, voltage_t0, voltage_t1):
    """ Estimate when voltage crossed threshold within a single time step.

    Voltage and threshold are each treated as a straight line over an interval of length
    ``dt``; the x offset of their intersection (from ``line_crossing_x``) is added to the
    start time ``time_step * dt``.

    Returns
    -------
    float
        ``time_step * dt`` plus the intersection offset
    """
    step_start = time_step * dt
    crossing_offset = line_crossing_x(dt, voltage_t0, voltage_t1, threshold_t0, threshold_t1)
    return step_start + crossing_offset
def interpolate_spike_value(dt, interpolated_spike_time_offset, v0, v1):
    """ Linearly interpolate a quantity sampled at two adjacent time steps.

    ``v0`` is taken to be the value at offset 0 and ``v1`` the value at offset ``dt``;
    the result is the value on that line at ``interpolated_spike_time_offset``.
    """
    scaled_change = (v1 - v0) * interpolated_spike_time_offset
    return v0 + scaled_change / dt
def line_crossing_x(dx, a0, a1, b0, b1):
    """ Find the x value of the intersection of two lines.

    Line ``a`` runs from ``a0`` (at x = 0) to ``a1`` (at x = dx); line ``b`` runs from
    ``b0`` to ``b1`` over the same span. Returns the x offset, measured from 0, where they
    meet. Plain ``int`` endpoints are rejected by an assert (presumably to avoid Python 2
    integer division). Parallel lines make the denominator zero.
    """
    endpoints = (a0, a1, b0, b1)
    assert all(type(value) != int for value in endpoints), Exception("Do not pass integers into this function!")
    numerator = dx * (b0 - a0)
    slope_gap = (a1 - a0) - (b1 - b0)
    return numerator / slope_gap
def line_crossing_y(dx, a0, a1, b0, b1):
    """ Find the y value of the intersection of two lines.

    Uses the same line definitions as ``line_crossing_x`` (``a`` from a0 to a1, ``b`` from
    b0 to b1 over a span ``dx``) and evaluates line ``b`` at the crossing. ``dx`` cancels out
    of the formula and is accepted only for a parallel signature.
    """
    b_rise = b1 - b0
    start_gap = b0 - a0
    slope_gap = (a1 - a0) - (b1 - b0)
    return b0 + b_rise * start_gap / slope_gap