Source code for bmtk.simulator.filternet.sonata_adaptors

import os
import json
import types
import numpy as np
from six import string_types

from bmtk.simulator.core.sonata_reader import NodeAdaptor, SonataBaseNode, EdgeAdaptor, SonataBaseEdge


class FilterNode(SonataBaseNode):
    """Node wrapper whose filter parameters are resolved through a property adaptor.

    Most properties forward to the same-named method on the adaptor, passing
    the wrapped node. ``sf_sep`` and the predefined jitter bounds are read
    straight from the node.
    """

    def __init__(self, node, prop_adaptor):
        super(FilterNode, self).__init__(node, prop_adaptor)
        # Used by ``jitter`` whenever the adaptor reports no predefined jitter.
        self._jitter = (0.0, 0.0)

    @property
    def non_dom_params(self):
        """Result of the adaptor's ``non_dom_params`` for this node."""
        return self._prop_adaptor.non_dom_params(self._node)

    @property
    def tuning_angle(self):
        """Result of the adaptor's ``tuning_angle`` for this node."""
        return self._prop_adaptor.tuning_angle(self._node)

    @property
    def predefined_jitter(self):
        """The adaptor's ``predefined_jitter`` attribute (not called, read as-is)."""
        return self._prop_adaptor.predefined_jitter

    @property
    def jitter(self):
        """Jitter bounds as a 2-tuple.

        Returns the value last assigned through the setter (initially
        ``(0.0, 0.0)``) unless ``predefined_jitter`` is truthy, in which case
        ``(node['jitter_lower'], node['jitter_upper'])`` is returned.
        """
        if not self.predefined_jitter:
            return self._jitter
        node = self._node
        return (node['jitter_lower'], node['jitter_upper'])

    @jitter.setter
    def jitter(self, val):
        self._jitter = val

    @property
    def sf_sep(self):
        """The ``sf_sep`` value stored on the node."""
        return self._node['sf_sep']

    @property
    def weights(self):
        """Result of the adaptor's ``weights`` for this node."""
        return self._prop_adaptor.weights(self._node)

    @property
    def kpeaks(self):
        """Result of the adaptor's ``kpeaks`` for this node."""
        return self._prop_adaptor.kpeaks(self._node)

    @property
    def delays(self):
        """Result of the adaptor's ``delays`` for this node."""
        return self._prop_adaptor.delays(self._node)

    @property
    def weights_non_dom(self):
        """Result of the adaptor's ``weights_non_dom`` for this node."""
        return self._prop_adaptor.weights_non_dom(self._node)

    @property
    def kpeaks_non_dom(self):
        """Result of the adaptor's ``kpeaks_non_dom`` for this node."""
        return self._prop_adaptor.kpeaks_non_dom(self._node)

    @property
    def delays_non_dom(self):
        """Result of the adaptor's ``delays_non_dom`` for this node."""
        return self._prop_adaptor.delays_non_dom(self._node)

    @property
    def sigma_f(self):
        """Result of the adaptor's ``sigma_f`` for this node."""
        return self._prop_adaptor.sigma_f(self._node)

    @property
    def b_t(self):
        """Result of the adaptor's ``b_t`` for this node."""
        return self._prop_adaptor.b_t(self._node)

    @property
    def order_t(self):
        """Result of the adaptor's ``order_t`` for this node."""
        return self._prop_adaptor.order_t(self._node)

    @property
    def t_mod_freq(self):
        """Result of the adaptor's ``t_mod_freq`` for this node."""
        return self._prop_adaptor.t_mod_freq(self._node)

    @property
    def sp_mod_freq(self):
        """Result of the adaptor's ``sp_mod_freq`` for this node."""
        return self._prop_adaptor.sp_mod_freq(self._node)

    @property
    def psi(self):
        """Result of the adaptor's ``psi`` for this node."""
        return self._prop_adaptor.psi(self._node)

    @property
    def delay(self):
        """Result of the adaptor's ``delay`` for this node."""
        return self._prop_adaptor.delay(self._node)

    @property
    def amplitude(self):
        """Result of the adaptor's ``amplitude`` for this node."""
        return self._prop_adaptor.amplitude(self._node)

    @property
    def direction(self):
        """Result of the adaptor's ``direction`` for this node."""
        return self._prop_adaptor.direction(self._node)
class FilterNodeAdaptor(NodeAdaptor):
    """Node adaptor that hands out ``FilterNode`` objects and wires up their getters."""

    def get_node(self, sonata_node):
        """Wrap ``sonata_node`` in a ``FilterNode`` bound to this adaptor."""
        return FilterNode(sonata_node, self)

    @staticmethod
    def preprocess_node_types(network, node_population):
        """Replace ``non_dom_params`` file names in the node-types table with parsed JSON.

        Runs the base ``NodeAdaptor.preprocess_node_types`` first. Then, if the
        types table has a ``non_dom_params`` column, every node type whose
        value is neither ``None`` nor already a dict is treated as a file name
        relative to the network's ``filter_models_dir`` component. The file is
        loaded as JSON and stored back on the node type.

        :param network: object providing ``get_component`` and ``io.log_exception``.
        :param node_population: object providing ``type_ids`` and ``types_table``.
        """
        NodeAdaptor.preprocess_node_types(network, node_population)

        node_type_ids = np.unique(node_population.type_ids)
        node_types_table = node_population.types_table
        if 'non_dom_params' not in node_types_table.columns:
            return

        for nt_id in node_type_ids:
            node_type = node_types_table[nt_id]
            dynamics_params = node_type['non_dom_params']
            # Nothing to load: either already parsed or not specified.
            if dynamics_params is None or isinstance(dynamics_params, dict):
                continue

            params_dir = network.get_component('filter_models_dir')
            params_path = os.path.join(params_dir, dynamics_params)
            try:
                # Context manager guarantees the handle is closed even when
                # JSON parsing fails (the bare open() call leaked it).
                with open(params_path, 'r') as params_file:
                    node_type['non_dom_params'] = json.load(params_file)
            except Exception:
                # TODO: Check dynamics_params before
                network.io.log_exception('Could not find node dynamics_params file {}.'.format(params_path))

    @classmethod
    def patch_adaptor(cls, adaptor, node_group, network):
        """Attach filter-specific getters to the adaptor, chosen from the group's columns.

        Delegates to ``NodeAdaptor.patch_adaptor`` first, then binds
        ``non_dom_params`` and ``tuning_angle`` methods, sets the
        ``predefined_jitter`` flag, and calls the module-level ``find_*``
        helpers to bind the remaining getters.

        :param adaptor: adaptor instance forwarded to ``NodeAdaptor.patch_adaptor``.
        :param node_group: object exposing ``all_columns``.
        :param network: forwarded to ``NodeAdaptor.patch_adaptor``.
        :return: the patched adaptor returned by ``NodeAdaptor.patch_adaptor``.
        :raises Exception: if exactly one of ``jitter_lower`` / ``jitter_upper``
            is present in the columns.
        """
        node_adaptor = NodeAdaptor.patch_adaptor(adaptor, node_group, network)

        if 'non_dom_params' in node_group.all_columns:
            node_adaptor.non_dom_params = types.MethodType(non_dom_params, node_adaptor)
        else:
            node_adaptor.non_dom_params = types.MethodType(return_none, node_adaptor)

        jitter_lower = 'jitter_lower' in node_group.all_columns
        jitter_upper = 'jitter_upper' in node_group.all_columns
        if jitter_lower and jitter_upper:
            node_adaptor.predefined_jitter = True
        elif jitter_lower or jitter_upper:
            # Only one bound supplied: the pair is unusable.
            raise Exception('Need to define both jitter_lower and jitter_upper (or leave both empty)')
        else:
            node_adaptor.predefined_jitter = False

        if 'tuning_angle' in node_group.all_columns:
            node_adaptor.tuning_angle = types.MethodType(tuning_angle_preset, node_adaptor)
        else:
            node_adaptor.tuning_angle = types.MethodType(tuning_angle_rand, node_adaptor)

        find_weight_params(node_group, node_adaptor)
        find_kpeaks_params(node_group, node_adaptor)
        find_delays_params(node_group, node_adaptor)

        find_nondom_weight_params(node_group, node_adaptor)
        find_nondom_kpeaks_params(node_group, node_adaptor)
        find_nondom_delays_params(node_group, node_adaptor)

        find_spectrotemporal_params(node_group, node_adaptor)

        return node_adaptor
def non_dom_params(self, node):
    """Return the ``non_dom_params`` entry of ``node`` (``self`` is unused)."""
    params = node['non_dom_params']
    return params
def return_none(self, node):
    """Fallback getter: ignore both arguments and return ``None``."""
    return None
def tuning_angle_preset(self, node):
    """Return the ``tuning_angle`` entry of ``node`` (``self`` is unused)."""
    angle = node['tuning_angle']
    return angle
def tuning_angle_rand(self, node):
    """Draw a value from ``np.random.uniform`` over ``[0.0, 360.0)``.

    Both arguments are ignored; each call consumes one draw from NumPy's
    global random state.
    """
    return np.random.uniform(low=0.0, high=360.0)
def weights(self, node):
    """Return the ``weights`` entry of ``node`` (``self`` is unused)."""
    node_weights = node['weights']
    return node_weights
def find_weight_params(node_group, node_adaptor):
    """Bind a ``weights(node)`` method onto ``node_adaptor``.

    The first match against ``node_group.all_columns`` wins, in this order:
    ``weights``, ``opt_wts``, ``weights_dom`` (value returned as stored), then
    the pair ``weight_dom_0`` + ``weight_dom_1`` (returned as a 2-item list).
    If nothing matches the bound method is ``return_none``.
    """
    available = node_group.all_columns

    def _read_column(column):
        # Factory so the getter captures a fixed column name.
        return lambda self, node: node[column]

    def _read_pair(first, second):
        return lambda self, node: [node[first], node[second]]

    for candidate in ('weights', 'opt_wts', 'weights_dom'):
        if candidate in available:
            getter = _read_column(candidate)
            break
    else:
        if 'weight_dom_0' in available and 'weight_dom_1' in available:
            getter = _read_pair('weight_dom_0', 'weight_dom_1')
        else:
            getter = return_none

    node_adaptor.weights = types.MethodType(getter, node_adaptor)
def find_kpeaks_params(node_group, node_adaptor):
    """Bind a ``kpeaks(node)`` method onto ``node_adaptor``.

    The first match against ``node_group.all_columns`` wins, in this order:
    ``kpeaks``, ``opt_kpeaks``, ``kpeaks_dom`` (value returned as stored), then
    the pair ``kpeaks_dom_0`` + ``kpeaks_dom_1`` (returned as a 2-item list).
    If nothing matches the bound method is ``return_none``.
    """
    available = node_group.all_columns

    def _read_column(column):
        # Factory so the getter captures a fixed column name.
        return lambda self, node: node[column]

    def _read_pair(first, second):
        return lambda self, node: [node[first], node[second]]

    for candidate in ('kpeaks', 'opt_kpeaks', 'kpeaks_dom'):
        if candidate in available:
            getter = _read_column(candidate)
            break
    else:
        if 'kpeaks_dom_0' in available and 'kpeaks_dom_1' in available:
            getter = _read_pair('kpeaks_dom_0', 'kpeaks_dom_1')
        else:
            getter = return_none

    node_adaptor.kpeaks = types.MethodType(getter, node_adaptor)
def find_delays_params(node_group, node_adaptor):
    """Bind a ``delays(node)`` method onto ``node_adaptor``.

    The first match against ``node_group.all_columns`` wins, in this order:
    ``delays``, ``opt_delays``, ``delays_dom`` (value returned as stored), then
    the pair ``delays_dom_0`` + ``delays_dom_1``, then the pair
    ``delay_dom_0`` + ``delay_dom_1`` (pairs are returned as 2-item lists).
    If nothing matches the bound method is ``return_none``.
    """
    available = node_group.all_columns

    def _read_column(column):
        # Factory so the getter captures a fixed column name.
        return lambda self, node: node[column]

    def _read_pair(first, second):
        return lambda self, node: [node[first], node[second]]

    getter = None
    for candidate in ('delays', 'opt_delays', 'delays_dom'):
        if candidate in available:
            getter = _read_column(candidate)
            break

    if getter is None:
        pair_candidates = (
            ('delays_dom_0', 'delays_dom_1'),
            ('delay_dom_0', 'delay_dom_1'),
        )
        for first, second in pair_candidates:
            if first in available and second in available:
                getter = _read_pair(first, second)
                break
        else:
            getter = return_none

    node_adaptor.delays = types.MethodType(getter, node_adaptor)
def find_nondom_weight_params(node_group, node_adaptor):
    """Bind a ``weights_non_dom(node)`` method onto ``node_adaptor``.

    Uses the ``weights_non_dom`` column if ``node_group.all_columns`` has it.
    Otherwise uses the pair ``weight_non_dom_0`` + ``weight_non_dom_1``
    (returned as a 2-item list), and falls back to ``return_none``.
    """
    available = node_group.all_columns
    has_pair = 'weight_non_dom_0' in available and 'weight_non_dom_1' in available

    if 'weights_non_dom' in available:
        def getter(self, node):
            return node['weights_non_dom']
    elif has_pair:
        def getter(self, node):
            return [node['weight_non_dom_0'], node['weight_non_dom_1']]
    else:
        getter = return_none

    node_adaptor.weights_non_dom = types.MethodType(getter, node_adaptor)
def find_nondom_kpeaks_params(node_group, node_adaptor):
    """Bind a ``kpeaks_non_dom(node)`` method onto ``node_adaptor``.

    Uses the ``kpeaks_non_dom`` column if ``node_group.all_columns`` has it.
    Otherwise uses the pair ``kpeaks_non_dom_0`` + ``kpeaks_non_dom_1``
    (returned as a 2-item list), and falls back to ``return_none``.
    """
    available = node_group.all_columns
    has_pair = 'kpeaks_non_dom_0' in available and 'kpeaks_non_dom_1' in available

    if 'kpeaks_non_dom' in available:
        def getter(self, node):
            return node['kpeaks_non_dom']
    elif has_pair:
        def getter(self, node):
            return [node['kpeaks_non_dom_0'], node['kpeaks_non_dom_1']]
    else:
        getter = return_none

    node_adaptor.kpeaks_non_dom = types.MethodType(getter, node_adaptor)
def find_nondom_delays_params(node_group, node_adaptor):
    """Bind a ``delays_non_dom(node)`` method onto ``node_adaptor``.

    Uses the ``delays_non_dom`` column if ``node_group.all_columns`` has it.
    Otherwise uses the pair ``delay_non_dom_0`` + ``delay_non_dom_1``
    (returned as a 2-item list), and falls back to ``return_none``.
    """
    available = node_group.all_columns
    has_pair = 'delay_non_dom_0' in available and 'delay_non_dom_1' in available

    if 'delays_non_dom' in available:
        def getter(self, node):
            return node['delays_non_dom']
    elif has_pair:
        def getter(self, node):
            return [node['delay_non_dom_0'], node['delay_non_dom_1']]
    else:
        getter = return_none

    node_adaptor.delays_non_dom = types.MethodType(getter, node_adaptor)
def find_spectrotemporal_params(node_group, node_adaptor):
    """Bind the spectro-temporal getters onto ``node_adaptor``.

    For each of ``sigma_f``, ``b_t``, ``order_t``, ``t_mod_freq``,
    ``sp_mod_freq``, ``psi``, ``delay``, ``amplitude`` and ``direction`` a
    method of that name is attached to the adaptor. If the column is listed in
    ``node_group.all_columns`` the method returns ``node[<column>]``,
    otherwise it is ``return_none``.

    ``psi`` gets extra handling: when the stored value is a string, the
    substring ``'pi'`` is replaced with ``'np.pi'`` and the result is
    evaluated as a Python expression. The original performed this check on
    the bound method right after assigning it, so the conversion could never
    run. It is now applied to the per-node value when the getter is called.

    :param node_group: object exposing ``all_columns``.
    :param node_adaptor: adaptor object that receives the bound methods.
    """
    available = node_group.all_columns

    def _read_column(column):
        # Factory so each getter captures its own column name.
        return lambda self, node: node[column]

    def _read_psi(self, node):
        value = node['psi']
        if isinstance(value, string_types):
            # NOTE(security): eval() executes whatever expression the node
            # file contains. Kept to honour the original intent; only use
            # with trusted network files.
            value = eval(value.replace('pi', 'np.pi'))
        return value

    # Same attribute order as the original chain of if/else statements.
    param_names = (
        'sigma_f',
        'b_t',
        'order_t',
        't_mod_freq',
        'sp_mod_freq',
        'psi',
        'delay',
        'amplitude',
        'direction',
    )
    for name in param_names:
        if name not in available:
            getter = return_none
        elif name == 'psi':
            getter = _read_psi
        else:
            getter = _read_column(name)
        setattr(node_adaptor, name, types.MethodType(getter, node_adaptor))