# Source code for bmtk.simulator.filternet.default_setters.cell_loaders

import numpy as np
from sympy.abc import x as symbolic_x
from sympy.abc import y as symbolic_y
from six import string_types

from bmtk.simulator.filternet.filters import TemporalFilterCosineBump, GaussianSpatialFilter, SpatioTemporalFilter, \
    WaveletFilter, SpectroTemporalFilter, GaborFilter
from bmtk.simulator.filternet.cell_models import TwoSubfieldLinearCell, OnUnit, OffUnit, LGNOnOffCell
from bmtk.simulator.filternet.transfer_functions import ScalarTransferFunction, MultiTransferFunction
from bmtk.simulator.filternet.utils import get_data_metrics_for_each_subclass, get_tcross_from_temporal_kernel
from bmtk.simulator.filternet.pyfunction_cache import py_modules


def create_two_sub_cell(dom_lf, non_dom_lf, dom_spont, non_dom_spont, onoff_axis_angle, subfield_separation,
                        dom_location):
    """Combine a dominant and a non-dominant linear filter into a TwoSubfieldLinearCell.

    The cell's transfer function is the sum of two rectified terms,
    ``Heaviside(x+dom_spont)*(x+dom_spont) + Heaviside(y+non_dom_spont)*(y+non_dom_spont)``,
    where the symbol x is paired with the dominant offset and y with the non-dominant offset.

    :param dom_lf: linear filter of the dominant subfield.
    :param non_dom_lf: linear filter of the non-dominant subfield.
    :param dom_spont: offset added to x inside the transfer function (inserted via ``str()``).
    :param non_dom_spont: offset added to y inside the transfer function (inserted via ``str()``).
    :param onoff_axis_angle: forwarded to TwoSubfieldLinearCell as ``onoff_axis_angle``.
    :param subfield_separation: forwarded to TwoSubfieldLinearCell as ``subfield_separation``.
    :param dom_location: forwarded to TwoSubfieldLinearCell as ``dominant_subfield_location``.
    :return: the constructed TwoSubfieldLinearCell.
    """
    rectified_term = 'Heaviside({v}+{s})*({v}+{s})'
    expression = '+'.join((
        rectified_term.format(v='x', s=str(dom_spont)),
        rectified_term.format(v='y', s=str(non_dom_spont)),
    ))
    transfer_fn = MultiTransferFunction((symbolic_x, symbolic_y), expression)
    return TwoSubfieldLinearCell(
        dom_lf,
        non_dom_lf,
        subfield_separation=subfield_separation,
        onoff_axis_angle=onoff_axis_angle,
        dominant_subfield_location=dom_location,
        transfer_function=transfer_fn,
    )
def createOneUnitOfTwoSubunitFilter(weights, kpeaks, delays, ttp_exp):
    """Build one temporal subunit of a two-subunit cell, with its delays shifted to a target latency.

    A TemporalFilterCosineBump is built from ``weights``/``kpeaks``/``delays``. The index returned by
    ``get_tcross_from_temporal_kernel`` for its kernel (presumably the kernel's crossing point -- confirm in
    utils) is compared with ``ttp_exp``. Every delay is shifted by ``ttp_exp - tcross_ind`` and the filter
    is rebuilt.

    :param weights: cosine-bump weights.
    :param kpeaks: cosine-bump peak parameters.
    :param delays: cosine-bump delays (converted to a numpy array).
    :param ttp_exp: target value for the crossing index, in the same units as the kernel index.
    :return: ``(filt_new, filt_sum)``.
        - ``filt_new`` is the delay-shifted TemporalFilterCosineBump.
        - ``filt_sum`` is the sum of the UNSHIFTED kernel over ``[:tcross_ind]``.
    :raises Exception: if ``tcross_ind`` exceeds ``ttp_exp``. Delays are only ever increased here.
    """
    delays = np.array(delays)
    filt = TemporalFilterCosineBump(weights, kpeaks, delays)

    # Build the kernel once; it is needed both for the crossing index and for the partial sum.
    kernel = filt.get_kernel(threshold=-1.0).kernel
    tcross_ind = get_tcross_from_temporal_kernel(kernel)
    filt_sum = kernel[:tcross_ind].sum()

    # Calculate delay offset needed to match response latency with data and rebuild temporal filter
    del_offset = ttp_exp - tcross_ind
    if del_offset < 0:
        raise Exception('del_offset < 0 (ttp_exp={}, tcross_ind={})'.format(ttp_exp, tcross_ind))

    filt_new = TemporalFilterCosineBump(weights, kpeaks, delays + del_offset)
    return filt_new, filt_sum
def get_tf_params(node, dynamics_params, non_dom_props=False):
    """Resolve the temporal-filter parameters (weights, kpeaks, delays) of a node.

    Each value comes from the node when the node's attribute is not None. Otherwise it falls back to
    ``dynamics_params``.

    Dominant subunit (``non_dom_props=False``):
        - Reads ``node.weights``, ``node.kpeaks`` and ``node.delays``.
        - Falls back to ``dynamics_params['opt_wts'|'opt_kpeaks'|'opt_delays']``.
        - A missing key raises KeyError.

    Non-dominant subunit (``non_dom_props=True``):
        - Reads ``node.weights_non_dom``, ``node.kpeaks_non_dom`` and ``node.delays_non_dom``.
        - Falls back to ``dynamics_params.get(...)``, which gives None when the key is missing or
          ``dynamics_params`` is falsy.

    When ``node.predefined_jitter`` is truthy, each non-None parameter list is replaced by an array. Every
    element v becomes ``np.random.uniform(v*node.jitter[0], v*node.jitter[1])``. Draws happen in the
    order weights, kpeaks, delays, element by element.

    :return: tuple ``(weights, kpeaks, delays)``.
    """
    if non_dom_props:
        fallback = dynamics_params or {}

        def resolve(attr, key):
            value = getattr(node, attr + '_non_dom')
            return value if value is not None else fallback.get(key, None)
    else:
        def resolve(attr, key):
            value = getattr(node, attr)
            return value if value is not None else dynamics_params[key]

    weights = resolve('weights', 'opt_wts')
    kpeaks = resolve('kpeaks', 'opt_kpeaks')
    delays = resolve('delays', 'opt_delays')

    if node.predefined_jitter:
        def jitter(values):
            return np.array([np.random.uniform(v * node.jitter[0], v * node.jitter[1]) for v in values])

        # Generator unpacking evaluates left to right, so RNG draws keep the weights/kpeaks/delays order.
        weights, kpeaks, delays = (None if p is None else jitter(p) for p in (weights, kpeaks, delays))

    return weights, kpeaks, delays
def get_sigma(node, dynamics_params):
    """Return the (x, y) Gaussian spatial-filter sigma of a node.

    The spatial size is looked up in this order:
        1. ``node['spatial_size']``
        2. ``node['sigma']``
        3. ``dynamics_params['spatial_size']``
        4. ``dynamics_params['sigma']``
        5. the default ``(1.0, 1.0)``

    A scalar is expanded to ``(value, value)``. Both components are then divided by 3.0 (original comment:
    "convert from degree to SD").

    :param node: node supporting ``'key' in node`` and ``node['key']``.
    :param dynamics_params: mapping of fallback parameters; None is treated as empty.
    :return: tuple ``(sigma_x / 3.0, sigma_y / 3.0)``.
    """
    params = dynamics_params or {}
    if 'spatial_size' in node:
        sigma = node['spatial_size']
    elif 'sigma' in node:
        sigma = node['sigma']
    elif 'spatial_size' in params:
        sigma = params['spatial_size']
    elif 'sigma' in params:
        # Previously read params['spatial_size'] here, which raised KeyError because that key is absent.
        sigma = params['sigma']
    else:
        # TODO: Raise warning
        sigma = (1.0, 1.0)

    if np.isscalar(sigma):
        sigma = (sigma, sigma)

    # convert from degree to SD
    return sigma[0]/3.0, sigma[1]/3.0
def get_wavelet_params(node, dynamics_params):
    """Collect the parameters of the spectro-temporal wavelet filter for an auditory node.

    Every parameter is taken from the node attribute of the same name when it is not None. Otherwise it
    comes from ``dynamics_params[<name>]``.

    :param node: node exposing t_mod_freq, sp_mod_freq, sigma_f, b_t, order_t, amplitude, psi, delay and
        direction attributes (None means "use dynamics_params").
    :param dynamics_params: mapping of fallback values keyed by the same names.
    :return: tuple ``(Lambda, sigma_f, b_t, order_t, theta, psi, delay, amplitude, direction)``, where:
        - ``Lambda = 1 / ||(t_mod_freq, sp_mod_freq)||``.
        - ``theta = arctan(sp_mod_freq / t_mod_freq)``, or pi/2 when t_mod_freq == 0.
        - ``psi`` is numeric; a string psi is evaluated as an expression.
        - ``direction`` is -1, 0 or 1 after mapping 'down'/'up'.
    :raises Exception: if either modulation frequency is negative, or direction is not recognised.
    """
    t_mod_freq = node.t_mod_freq if node.t_mod_freq is not None else dynamics_params['t_mod_freq']
    sp_mod_freq = node.sp_mod_freq if node.sp_mod_freq is not None else dynamics_params['sp_mod_freq']
    if (t_mod_freq < 0) or (sp_mod_freq < 0):
        raise Exception("Temporal modulation frequency (t_mod_freq) and spectral modulation frequency "
                        "(sp_mod_freq) must be non-negative.")

    # Wavelength of oscillatory component.
    # NOTE(review): if both frequencies are 0 the norm is 0 and this yields inf (numpy division) -- confirm
    # whether that case is meaningful for callers.
    Lambda = 1/np.linalg.norm([t_mod_freq, sp_mod_freq])

    sigma_f = node.sigma_f if node.sigma_f is not None else dynamics_params['sigma_f']
    b_t = node.b_t if node.b_t is not None else dynamics_params['b_t']
    order_t = node.order_t if node.order_t is not None else dynamics_params['order_t']
    amplitude = node.amplitude if node.amplitude is not None else dynamics_params['amplitude']

    if t_mod_freq != 0:
        theta = np.arctan(sp_mod_freq / t_mod_freq)
    else:
        theta = np.pi / 2

    psi = node.psi if node.psi is not None else dynamics_params['psi']
    if isinstance(psi, string_types):
        # Phase given as an expression such as 'pi/2' or 'np.pi/2'. Binding both names directly avoids the
        # str.replace('pi', 'np.pi') approach, which turned a valid 'np.pi' into 'np.np.pi'.
        # SECURITY NOTE: this is still eval(); psi must only come from trusted configuration files.
        psi = eval(psi, {'np': np, 'pi': np.pi})

    delay = node.delay if node.delay is not None else dynamics_params['delay']
    direction = node.direction if node.direction is not None else dynamics_params['direction']
    if direction == 'up':
        direction = 1
    elif direction == 'down':
        direction = -1
    elif direction not in [-1, 0, 1]:
        raise Exception("'Direction' filter parameter must be 'up' (or 1) for upward frequency modulation, "
                        "or 'down' (or -1) for downward modulation, or 0 if not applicable.")

    return Lambda, sigma_f, b_t, order_t, theta, psi, delay, amplitude, direction
def _build_two_subunit_cell(node, translate, sigma, origin, rotation, t_weights, t_kpeaks, t_delays,
                            ttp_non_dom, ttp_dom, spont, max_roff, max_ron, off_scale=1.0):
    """Build a two-subunit cell from an ON (non-dominant) and an OFF (dominant) subunit.

    The subunits are built as follows:
        - The dominant OFF subunit uses the node's own temporal parameters (``t_*``).
        - The non-dominant ON subunit uses ``get_tf_params(node, node.non_dom_params, non_dom_props=True)``.
        - Each temporal filter is delay-shifted to its target latency (``ttp_non_dom`` / ``ttp_dom``).
        - The ON amplitude is 1.0.
        - The OFF amplitude is derived from the max-rate constants, the kernel sums and ``spont``.
        - ``off_scale`` multiplies the first term of the OFF amplitude.
    """
    # sON (non-dominant) temporal filter
    t_weights_nd, t_kpeaks_nd, t_delays_nd = get_tf_params(node, node.non_dom_params, non_dom_props=True)
    on_filt, on_sum = createOneUnitOfTwoSubunitFilter(t_weights_nd, t_kpeaks_nd, t_delays_nd, ttp_non_dom)
    # OFF (dominant) temporal filter
    off_filt, off_sum = createOneUnitOfTwoSubunitFilter(t_weights, t_kpeaks, t_delays, ttp_dom)

    amp_on = 1.0  # set the non-dominant subunit amplitude to unity
    amp_off = -off_scale*(max_roff/max_ron)*(on_sum/off_sum)*amp_on - (spont*(max_roff - max_ron))/(max_ron*off_sum)

    # Create ON subunit:
    spatial_filter_on = GaussianSpatialFilter(translate=translate, sigma=sigma, origin=origin, rotation=rotation)
    linear_filter_on = SpatioTemporalFilter(spatial_filter_on, on_filt, amplitude=amp_on)

    # Create OFF subunit:
    spatial_filter_off = GaussianSpatialFilter(translate=translate, sigma=sigma, origin=origin, rotation=rotation)
    linear_filter_off = SpatioTemporalFilter(spatial_filter_off, off_filt, amplitude=amp_off)

    sf_sep = node.sf_sep
    if node.predefined_jitter:
        sf_sep = np.random.uniform(node.jitter[0]*sf_sep, node.jitter[1]*sf_sep)

    return create_two_sub_cell(linear_filter_off, linear_filter_on, 0.5 * spont, 0.5 * spont,
                               node.tuning_angle, sf_sep, translate)


def _build_lgn_onoff_cell(node, origin, translate):
    """Build an LGNOnOffCell whose ON/OFF filters share one temporal filter read from node columns."""
    wts = [node['weight_dom_0'], node['weight_dom_1']]
    kpeaks = [node['kpeaks_dom_0'], node['kpeaks_dom_1']]
    delays = [node['delay_dom_0'], node['delay_dom_1']]
    temporal_filter = TemporalFilterCosineBump(wts, kpeaks, delays)

    spatial_filter_on = GaussianSpatialFilter(sigma=node['sigma_on'], origin=origin, translate=translate)
    on_linear_filter = SpatioTemporalFilter(spatial_filter_on, temporal_filter, amplitude=20)

    spatial_filter_off = GaussianSpatialFilter(sigma=node['sigma_off'], origin=origin, translate=translate)
    off_linear_filter = SpatioTemporalFilter(spatial_filter_off, temporal_filter, amplitude=-20)

    return LGNOnOffCell(on_linear_filter, off_linear_filter)


def _build_single_unit_cell(node, model_name, translate, sigma, origin, rotation, t_weights, t_kpeaks, t_delays):
    """Build a single rectified OnUnit/OffUnit for model names '<cell_type>' or '<cell_type>_<tf_str>'."""
    type_split = model_name.split('_')
    if len(type_split) == 1:
        cell_type, tf_str = model_name, 'TF8'
    else:
        cell_type, tf_str = type_split[0], type_split[1]

    # Get spontaneous firing rate, either from the cell property or calculated from experimental data
    if 'spont_fr' in node:
        spont_fr = node['spont_fr']
    else:
        exp_prs_dict = get_data_metrics_for_each_subclass(cell_type)
        spont_fr = exp_prs_dict[tf_str]['spont_exp'][0]

    # Get filters
    transfer_function = ScalarTransferFunction('Heaviside(s+{})*(s+{})'.format(spont_fr, spont_fr))
    temporal_filter = TemporalFilterCosineBump(t_weights, t_kpeaks, t_delays)
    spatial_filter = GaussianSpatialFilter(translate=translate, sigma=sigma, origin=origin, rotation=rotation)

    # 'ON' is checked before 'OFF'.
    if 'ON' in cell_type:
        linear_filter = SpatioTemporalFilter(spatial_filter, temporal_filter, amplitude=1.0)
        return OnUnit(linear_filter, transfer_function)
    if 'OFF' in cell_type:
        linear_filter = SpatioTemporalFilter(spatial_filter, temporal_filter, amplitude=-1.0)
        return OffUnit(linear_filter, transfer_function)

    # Previously fell through to `return cell` with `cell` unbound (UnboundLocalError).
    raise ValueError("Cannot determine ON/OFF polarity of LGN cell type '{}' (model '{}').".format(
        cell_type, model_name))


def _build_lgn_cell(node, template_name, dynamics_params):
    """Build an LGN cell: two-subunit sONsOFF/sONtOFF, LGNOnOFFCell, or a single ON/OFF unit."""
    # Spatial filter parameters shared by all LGN variants
    origin = (0.0, 0.0)
    translate = (node['x'], node['y'])
    sigma = get_sigma(node, dynamics_params)
    rotation = node['spatial_rotation'] if 'spatial_rotation' in node else 0.0

    # Dominant temporal parameters are always resolved here, even for models that ignore them, so any
    # jitter-related np.random draws happen in the same order as before.
    t_weights, t_kpeaks, t_delays = get_tf_params(node, dynamics_params)

    model_name = template_name[1] if template_name else node['pop_name']

    if model_name in ('sONsOFF_001', 'sONsOFF'):
        return _build_two_subunit_cell(node, translate, sigma, origin, rotation, t_weights, t_kpeaks, t_delays,
                                       ttp_non_dom=121.0, ttp_dom=115.0, spont=4.0, max_roff=35.0, max_ron=21.0)
    if model_name in ('sONtOFF_001', 'sONtOFF'):
        return _build_two_subunit_cell(node, translate, sigma, origin, rotation, t_weights, t_kpeaks, t_delays,
                                       ttp_non_dom=93.5, ttp_dom=64.8, spont=5.5, max_roff=46.0, max_ron=31.0,
                                       off_scale=0.7)
    if model_name == 'LGNOnOFFCell':
        return _build_lgn_onoff_cell(node, origin, translate)
    return _build_single_unit_cell(node, model_name, translate, sigma, origin, rotation,
                                   t_weights, t_kpeaks, t_delays)


def _build_aud_cell(node, dynamics_params):
    """Build an auditory OnUnit from a WaveletFilter-based SpectroTemporalFilter."""
    # Currently tying y to center freq (note: a scalar, not a tuple).
    translate = node['y']
    Lambda, sigma_f, b_t, order_t, theta, psi, delay, amplitude, direction = get_wavelet_params(node, dynamics_params)
    spectrotemporal_filter = WaveletFilter(translate, sigma_f, b_t, order_t, theta, Lambda, psi, delay, amplitude,
                                           direction)

    # Spontaneous firing rate: node property if present, otherwise a fixed default of 2.
    spont_fr = node['spont_fr'] if 'spont_fr' in node else 2
    transfer_function = ScalarTransferFunction('Heaviside(s+{})*(s+{})'.format(spont_fr, spont_fr))
    linear_filter = SpectroTemporalFilter(spectrotemporal_filter, amplitude=1.0)
    return OnUnit(linear_filter, transfer_function)


def default_cell_loader(node, template_name, dynamics_params):
    """Build the filternet cell model described by a node.

    :param node: node object. It must support ``'key' in node`` and ``node['key']`` lookups. It must also
        support the attribute access used by the helpers, e.g. ``node.weights`` and
        ``node.predefined_jitter``.
    :param template_name: None, or a sequence whose first item is the model family ('lgnmodel' or
        'audmodel') and whose second item is the LGN model name. When None, the LGN family is used and
        the model name is taken from ``node['pop_name']``.
    :param dynamics_params: mapping of fallback parameters used when the node does not define them.
    :return: the constructed cell. This is the object returned by ``create_two_sub_cell``, or an
        LGNOnOffCell, OnUnit or OffUnit.
    :raises ValueError: if the template family is not recognised, or an LGN single-unit cell type
        contains neither 'ON' nor 'OFF'.
    """
    if template_name is None or template_name[0] == 'lgnmodel':
        return _build_lgn_cell(node, template_name, dynamics_params)
    if template_name[0] == 'audmodel':
        return _build_aud_cell(node, dynamics_params)

    # Previously fell through to `return cell` with `cell` unbound (UnboundLocalError).
    raise ValueError("Unrecognized cell template '{}': expected 'lgnmodel' or 'audmodel'.".format(
        template_name[0]))
# Register default_cell_loader as the cell processor under both the 'default' and 'preset_params' names.
# overwrite=False presumably keeps any processor already registered under the same name -- TODO confirm
# against pyfunction_cache.
py_modules.add_cell_processor('default', default_cell_loader, overwrite=False)
py_modules.add_cell_processor('preset_params', default_cell_loader, overwrite=False)