import numpy as np
from collections import Counter
import numbers
import nest
import types
import pandas as pd
from bmtk.simulator.core.sonata_reader import NodeAdaptor, SonataBaseNode, EdgeAdaptor, SonataBaseEdge
from bmtk.simulator.pointnet.io_tools import io
from bmtk.simulator.pointnet.pyfunction_cache import py_modules
from bmtk.simulator.pointnet.glif_utils import convert_aibs2nest
from bmtk.simulator.pointnet.nest_utils import nest_version
# Key under which the synapse model name is stored in the NEST synapse-parameter dicts built in this module:
# 'model' when running NEST 2.x, 'synapse_model' for any other major version.
if nest_version[0] == 2:
    NEST_SYNAPSE_MODEL_PROP = 'model'
else:
    NEST_SYNAPSE_MODEL_PROP = 'synapse_model'
def all_null(node_group, column_name):
    """Return True when no node type used by ``node_group`` has a non-NULL (non-None) value in ``column_name``.

    Only the node types that actually occur in ``node_group.node_type_ids`` are inspected; their values are read from
    the parent population's types table.

    :param node_group: node group exposing ``parent.types_table`` (indexable by node_type_id, each row indexable by
        column name) and ``node_type_ids`` (the node type id of every node in the group).
    :param column_name: name of the node-type column to check.
    :return: True if every inspected value is None (also True for an empty group), False otherwise.
    """
    types_table = node_group.parent.types_table
    # all() stops at the first non-None value instead of materialising a list of every non-null entry.
    return all(types_table[ntid][column_name] is None for ntid in np.unique(node_group.node_type_ids))
class PointNodeBatched(object):
    """A batch of nodes sharing a single node type, created in NEST with one ``nest.Create`` call.

    :param node_ids: node ids of the nodes in the batch.
    :param gids: gids of the nodes in the batch, in the same order as ``node_ids``.
    :param node_types_table: node-types table, indexable by node_type_id.
    :param node_type_id: node type shared by every node of the batch.
    """

    def __init__(self, node_ids, gids, node_types_table, node_type_id):
        self._n_nodes = len(node_ids)
        self._node_ids, self._gids = node_ids, gids
        self._nt_table, self._nt_id = node_types_table, node_type_id
        # Populated by build()
        self._nest_objs, self._nest_ids = [], []

    def _type_attr(self, attr_name):
        # Read one column of this batch's row in the node-types table.
        return self._nt_table[self._nt_id][attr_name]

    @property
    def n_nodes(self):
        """Number of nodes in the batch."""
        return self._n_nodes

    @property
    def node_ids(self):
        """Node ids of the batch."""
        return self._node_ids

    @property
    def gids(self):
        """Gids of the batch."""
        return self._gids

    @property
    def nest_ids(self):
        """NEST ids of the created nodes (empty until build() is called)."""
        return self._nest_ids

    @property
    def nest_model(self):
        """NEST model name: the part of ``model_template`` after the first ':' (e.g. 'nest:X' -> 'X')."""
        template_parts = self._type_attr('model_template').split(':')
        return template_parts[1]

    @property
    def nest_params(self):
        """The node type's ``dynamics_params``, passed as-is to ``nest.Create``."""
        return self._type_attr('dynamics_params')

    @property
    def model_type(self):
        """The node type's ``model_type`` column."""
        return self._type_attr('model_type')

    def build(self):
        """Create every node of the batch in NEST with a single call and record the resulting NEST ids."""
        created = nest.Create(self.nest_model, self.n_nodes, self.nest_params)
        self._nest_objs = created
        if nest_version[0] >= 3:
            # NEST 3 returns a NodeCollection; keep the ids as a plain list.
            self._nest_ids = created.tolist()
        else:
            self._nest_ids = created
class PointNode(SonataBaseNode):
    """A single SONATA node that is created in NEST on its own rather than as part of a batch."""

    def __init__(self, node, prop_adaptor):
        super(PointNode, self).__init__(node, prop_adaptor)
        # Both are set by build()
        self._nest_objs = []
        self._nest_ids = []

    @property
    def n_nodes(self):
        """Number of nodes represented; always 1."""
        return 1

    @property
    def node_ids(self):
        """Single-element list holding this node's node_id."""
        node_id = self._prop_adaptor.node_id(self._node)
        return [node_id]

    @property
    def gids(self):
        """Single-element list holding this node's gid."""
        gid = self._prop_adaptor.gid(self._node)
        return [gid]

    @property
    def nest_ids(self):
        """NEST ids of the created node (empty until build() is called)."""
        return self._nest_ids

    @property
    def nest_model(self):
        """NEST model name: element [1] of the adaptor's model_template for this node."""
        template = self._prop_adaptor.model_template(self._node)
        return template[1]

    @property
    def nest_params(self):
        """This node's dynamics params."""
        return self.dynamics_params

    def build(self):
        """Create the node in NEST.

        If the node has a ``model_processing`` value, the registered cell-processor function of that name is called
        with (model name, node, dynamics params) and its return value is used. Otherwise ``nest.Create`` creates one
        node directly.
        """
        model_name = self.nest_model
        params = self.dynamics_params
        processor_name = self._node['model_processing']
        if processor_name is not None:
            processor = py_modules.cell_processor(processor_name)
            self._nest_objs = processor(model_name, self._node, params)
        else:
            self._nest_objs = nest.Create(model_name, 1, params)

        if nest_version[0] >= 3:
            # NEST 3 returns a NodeCollection; keep the ids as a plain list.
            self._nest_ids = self._nest_objs.tolist()
        else:
            self._nest_ids = self._nest_objs
class PointNodeAdaptor(NodeAdaptor):
    """PointNet node adaptor that also decides whether a node group may be created in NEST in batches.

    ``get_batches`` groups a node group's nodes by node type (one ``PointNodeBatched`` per type), while ``get_node``
    wraps a single node as a ``PointNode`` for individual creation.
    """

    def __init__(self, network):
        super(PointNodeAdaptor, self).__init__(network)
        # Flag for determining if we can build multiple NEST nodes at once. If each individual node has unique
        # NEST params or a model_processing function is being called then we must call nest.Create for each
        # individual cell. Otherwise we can try to call nest.Create for a batch of nodes that share the same properties.
        self._can_batch = True

    @property
    def batch_process(self):
        """Whether nodes handled by this adaptor may be created in batches."""
        return self._can_batch

    @batch_process.setter
    def batch_process(self, flag):
        self._can_batch = flag

    def get_node(self, sonata_node):
        """Wrap a single SONATA node as a ``PointNode``."""
        return PointNode(sonata_node, self)

    def get_batches(self, node_group):
        """Split ``node_group`` into one ``PointNodeBatched`` per node type.

        Node ids and gids are copied into ``uint32`` arrays, keeping their relative order within each node type.
        Batches are returned in the order in which their node type first appears in the group. If the group has no
        gids, the node ids are used as gids.
        """
        node_ids = node_group.node_ids
        node_type_ids = node_group.node_type_ids
        node_gids = node_group.gids
        if node_gids is None:
            node_gids = node_ids

        ntids_counter = Counter(node_type_ids)
        nid_groups = {nt_id: np.zeros(ntids_counter[nt_id], dtype=np.uint32) for nt_id in ntids_counter}
        gid_groups = {nt_id: np.zeros(ntids_counter[nt_id], dtype=np.uint32) for nt_id in ntids_counter}
        # Next free slot in each node type's arrays
        node_groups_counter = {nt_id: 0 for nt_id in ntids_counter}
        for node_id, gid, node_type_id in zip(node_ids, node_gids, node_type_ids):
            grp_indx = node_groups_counter[node_type_id]
            nid_groups[node_type_id][grp_indx] = node_id
            gid_groups[node_type_id][grp_indx] = gid
            node_groups_counter[node_type_id] += 1

        return [PointNodeBatched(nid_groups[nt_id], gid_groups[nt_id], node_group.parent.node_types_table, nt_id)
                for nt_id in ntids_counter]

    @staticmethod
    def preprocess_node_types(network, node_population):
        """Run the generic node-type preprocessing, then convert AIBS-style GLIF node types to NEST parameters.

        A node type is converted in place (``model_template`` and ``dynamics_params`` are replaced with the result of
        ``convert_aibs2nest``) when its ``model_template`` is a string starting with ``nest:glif`` and its
        ``dynamics_params`` has ``type == 'GLIF'``. Node types whose ``model_template`` is not a string, or whose
        ``dynamics_params`` is None, are left untouched rather than raising.
        """
        NodeAdaptor.preprocess_node_types(network, node_population)

        node_types_table = node_population.types_table
        if 'model_template' not in node_types_table.columns or 'dynamics_params' not in node_types_table.columns:
            return

        for nt_id in np.unique(node_population.type_ids):
            node_type_attrs = node_types_table[nt_id]
            mtemplate = node_type_attrs['model_template']
            dyn_params = node_type_attrs['dynamics_params']
            # Node types with no usable template/params cannot be GLIF models; skip them instead of failing on
            # str.startswith / dict.get.
            if not isinstance(mtemplate, str) or dyn_params is None:
                continue

            if mtemplate.startswith('nest:glif') and dyn_params.get('type', None) == 'GLIF':
                model_template, dynamics_params = convert_aibs2nest(mtemplate, dyn_params)
                node_type_attrs['model_template'] = model_template
                node_type_attrs['dynamics_params'] = dynamics_params

    @staticmethod
    def patch_adaptor(adaptor, node_group, network):
        """Apply the generic NodeAdaptor patching, then disable batching when nodes cannot share NEST parameters.

        Batching is turned off if the group stores per-node dynamics params, has a per-node ``model_processing``
        column, or any of its node types has a non-null ``model_processing`` value.
        """
        node_adaptor = NodeAdaptor.patch_adaptor(adaptor, node_group, network)

        # If dynamics params are stored in the nodes.h5 then we have to build each node separately
        if node_group.has_dynamics_params:
            node_adaptor.batch_process = False

        # If there is a non-null value in the model_processing column then it potentially means that every cell is
        # uniquely built (currently model_processing is applied to each individual cell) and nodes can't be batched
        if 'model_processing' in node_group.columns:
            node_adaptor.batch_process = False
        elif 'model_processing' in node_group.all_columns and not all_null(node_group, 'model_processing'):
            node_adaptor.batch_process = False

        if node_adaptor.batch_process:
            io.log_info('Batch processing nodes for {}/{}.'.format(node_group.parent.name, node_group.group_id))

        return node_adaptor
class PointEdge(SonataBaseEdge):
    """A single SONATA edge for the PointNet simulator."""

    @property
    def source_node_ids(self):
        """Single-element list holding the edge's source node id."""
        return [self._edge.source_node_id]

    @property
    def target_node_ids(self):
        """Single-element list holding the edge's target node id."""
        return [self._edge.target_node_id]

    @property
    def nest_params(self):
        """NEST synapse parameters for this edge, produced by a registered synapse-model function.

        If the edge's ``model_template`` names a registered synapse-model function, that function is called with
        (edge, source node, target node). Otherwise the function registered as ``'default'`` is called with None
        for both nodes.
        """
        if self.model_template in py_modules.synapse_models:
            network = self._prop_adaptor._network
            src_node = network.get_node_id(self.source_population, self.source_node_id)
            trg_node = network.get_node_id(self.target_population, self.target_node_id)
            syn_model_fnc = py_modules.synapse_model(self.model_template)
        else:
            src_node = None
            trg_node = None
            # Use the same getter as above; `synapse_models` is the collection of registered names (used with `in`),
            # not a lookup function.
            syn_model_fnc = py_modules.synapse_model('default')
        return syn_model_fnc(self, src_node, trg_node)
class PointEdgeBatched(object):
    """A batch of edges connected in NEST with a single shared parameter dict.

    :param source_nids: source node ids, one per edge.
    :param target_nids: target node ids, one per edge, in the same order as ``source_nids``.
    :param nest_params: NEST synapse parameters applied to the whole batch.
    """

    def __init__(self, source_nids, target_nids, nest_params):
        self._src_nids, self._trg_nids, self._nest_params = source_nids, target_nids, nest_params

    @property
    def source_node_ids(self):
        """Source node ids of the batch."""
        return self._src_nids

    @property
    def target_node_ids(self):
        """Target node ids of the batch."""
        return self._trg_nids

    @property
    def nest_params(self):
        """NEST synapse parameters of the batch."""
        return self._nest_params
class PointEdgeAdaptor(EdgeAdaptor):
    """PointNet edge adaptor: builds NEST synapse parameters for single edges or for per-edge-type batches."""

    def __init__(self, network):
        super(PointEdgeAdaptor, self).__init__(network)
        # Whether edges may be connected in NEST in batches (one batch per edge type, see get_batches).
        self._can_batch = True

    @property
    def batch_process(self):
        """Whether edges handled by this adaptor may be connected in batches."""
        return self._can_batch

    @batch_process.setter
    def batch_process(self, flag):
        self._can_batch = flag

    def synaptic_params(self, edge):
        """Return a NEST parameter dict for one edge: its weight and delay, updated with its dynamics params."""
        # TODO: THIS NEEDS to be replaced with call to synapse_models
        params_dict = {'weight': self.syn_weight(edge, None, None), 'delay': edge.delay}
        params_dict.update(edge.dynamics_params)
        return params_dict

    def get_edge(self, sonata_node):
        """Wrap a single SONATA edge as a ``PointEdge``."""
        # NOTE: the argument is an edge despite its name; the name is kept for keyword-argument compatibility.
        return PointEdge(sonata_node, self)

    @staticmethod
    def preprocess_edge_types(network, edge_population):
        """Run the generic edge-type preprocessing, then strip the directive from NEST ``model_template`` values.

        Templates starting with 'nest' lose their first 5 characters (e.g. 'nest:static_synapse' becomes
        'static_synapse'). Non-string templates are left untouched.
        """
        # Fix for sonata/300_pointneurons
        EdgeAdaptor.preprocess_edge_types(network, edge_population)

        edge_types_table = edge_population.types_table
        if 'model_template' not in edge_types_table.columns:
            return

        for et_id in np.unique(edge_population.type_ids):
            edge_type = edge_types_table[et_id]
            model_template = edge_type['model_template']
            if isinstance(model_template, str) and model_template.startswith('nest'):
                # Assumes a 'nest:' prefix (5 characters).
                edge_type['model_template'] = model_template[5:]

    def get_batches(self, edge_group):
        """Yield one ``PointEdgeBatched`` per edge type in ``edge_group``.

        Each batch's ``nest_params`` holds:

        * the synapse model (the edge type's ``model_template``) under ``NEST_SYNAPSE_MODEL_PROP``;
        * the edge type's ``dynamics_params``, if not None;
        * ``delay``: per-edge values (ndarray) if stored in the edges file. Otherwise it is the edge type's value,
          expanded to one value per edge on NEST >= 3. It is omitted if neither exists;
        * ``weight``: computed by the edge type's ``weight_function`` if set. Otherwise it is ``nsyns * syn_weight``,
          with each factor taken per-edge from the edges file if present, else from the edge type (``nsyns``
          defaults to 1).

        :raises Exception: if an edge type has neither a weight_function nor a syn_weight value.
        """
        edge_types_table = edge_group.parent.edge_types_table

        # Per-edge values to be grouped by edge type; optional columns are added only if stored in the edges file.
        edges_df = pd.DataFrame({'etid': edge_group.edge_type_ids, 'src_nids': edge_group.src_node_ids,
                                 'trg_nids': edge_group.trg_node_ids})
        for col_name in ('nsyns', 'syn_weight', 'delay'):
            if col_name in edge_group.columns:
                edges_df[col_name] = edge_group.get_dataset(col_name)

        # Full edge/node tables needed only by weight_function edge types; loaded on first use, then reused.
        weight_fnc_cache = {}

        for edge_id, grp_vals in edges_df.groupby('etid'):
            edge_props = edge_types_table[edge_id]

            nest_params = {NEST_SYNAPSE_MODEL_PROP: edge_props['model_template']}

            # TODO: Add to dataframe and if a part of hdf5 we can return any dynamics params as a list
            dynamics_params = edge_props['dynamics_params']
            if dynamics_params is not None:
                nest_params.update(dynamics_params)

            if 'delay' in grp_vals.columns:
                # Pass a plain ndarray (as is done for weights and node ids), not a pandas Series carrying an index.
                nest_params['delay'] = grp_vals['delay'].values
            elif 'delay' in edge_props.keys():
                delay = edge_props['delay']
                # For NEST 2.* 'delay' can be a single value, but for 3.* it requires a full array for each edge
                nest_params['delay'] = np.full(len(grp_vals), delay) if nest_version[0] >= 3 else delay

            weight_function = edge_props.get('weight_function', None)
            if weight_function is not None:
                nest_params['weight'] = self._weights_from_function(edge_group, edge_id, weight_function,
                                                                    weight_fnc_cache)
            else:
                nest_params['weight'] = self._weights_from_columns(edge_props, grp_vals)

            yield PointEdgeBatched(source_nids=grp_vals['src_nids'].values, target_nids=grp_vals['trg_nids'].values,
                                   nest_params=nest_params)

    def _weights_from_function(self, edge_group, edge_id, weight_function, cache):
        # Weights of one edge type, computed by its registered weight_function from the matching edge rows and their
        # source/target node rows. `cache` keeps the group's edge table and both node tables between calls.
        if not cache:
            cache['edges_df'] = edge_group.to_dataframe()
            cache['src_nodes_df'] = self._network.get_nodes_df(edge_group.parent.source_population)
            cache['trg_nodes_df'] = self._network.get_nodes_df(edge_group.parent.target_population)

        grp_df = cache['edges_df']
        edges = grp_df[grp_df['edge_type_id'] == edge_id]
        target_nodes = cache['trg_nodes_df'].loc[edges['target_node_id'].values]
        source_nodes = cache['src_nodes_df'].loc[edges['source_node_id'].values]

        if not py_modules.has_synaptic_weight(weight_function):
            err_msg = 'Unable to calculate synaptic weight for "{}" edges, missing "weight_function" ' \
                      'attribute value {} function.'.format(edge_group.parent.name, weight_function)
            io.log_exception(err_msg)

        weight_fnc = py_modules.synaptic_weight(weight_function)
        return weight_fnc(edges, source_nodes, target_nodes)

    @staticmethod
    def _weights_from_columns(edge_props, grp_vals):
        # weight = nsyns * syn_weight, each taken per-edge from the edges file if present, else from the edge type.
        # Get nsyns as either an array or a constant. If not explicitly specified assume nsyns = 1
        if 'nsyns' in grp_vals.columns:
            nsyns = grp_vals['nsyns'].values
        else:
            nsyns = edge_props.get('nsyns', 1)

        # Get syn_weight as either an array or constant. If not explicitly stated throw an error
        if 'syn_weight' in grp_vals.columns:
            syn_weight = grp_vals['syn_weight'].values
        elif 'syn_weight' in edge_props.keys():
            syn_weight = edge_props['syn_weight']
        else:
            # TODO: Make more explicit. Or default to syn_weight of 0
            raise Exception('Could not find syn_weight value')

        return nsyns * syn_weight

    @staticmethod
    def patch_adaptor(adaptor, edge_group):
        """Apply the generic EdgeAdaptor patching; use ``point_syn_weight`` when weights come from ``syn_weight``."""
        edge_adaptor = EdgeAdaptor.patch_adaptor(adaptor, edge_group)
        if 'weight_function' not in edge_group.all_columns and 'syn_weight' in edge_group.all_columns:
            # Patch the adaptor that is returned to (and used by) the caller.
            edge_adaptor.syn_weight = types.MethodType(point_syn_weight, edge_adaptor)
        return edge_adaptor
def point_syn_weight(self, edge, src_node, trg_node):
    """Return an edge's synaptic weight: its ``syn_weight`` value multiplied by its ``nsyns``.

    Bound as the ``syn_weight`` method of an edge adaptor by ``PointEdgeAdaptor.patch_adaptor``; ``self``,
    ``src_node`` and ``trg_node`` are not used.
    """
    base_weight = edge['syn_weight']
    return base_weight * edge.nsyns