# Copyright 2017. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import json
import functools
import nest
from six import string_types
import numpy as np
from bmtk.simulator.core.simulator_network import SimNetwork
from bmtk.simulator.pointnet.sonata_adaptors import PointNodeAdaptor, PointEdgeAdaptor
from bmtk.simulator.pointnet import pyfunction_cache
from bmtk.simulator.pointnet.io_tools import io
from bmtk.simulator.pointnet.nest_utils import nest_version
from .gids import GidPool


def set_spikes_nest2(node_id, nest_obj, spike_trains):
    st = spike_trains.get_times(node_id=node_id)
    if st is None or len(st) == 0:
        return

    st = np.array(st)
    if np.any(st <= 0.0):
        # NEST will fail if a spike_generator contains a negative/zero spike-time; throw an exception and log
        # info for the user.
        io.log_exception('spike train {} contains negative/zero time, unable to run virtual cell in NEST'.format(st))
    st.sort()
    nest.SetStatus([nest_obj], {'spike_times': st})


def set_spikes_nest3(node_id, nest_obj, spike_trains):
    st = spike_trains.get_times(node_id=node_id)
    if st is None or len(st) == 0:
        return

    st = np.array(st)
    if np.any(st <= 0.0):
        io.log_exception('spike train {} contains negative/zero time, unable to run virtual cell in NEST'.format(st))
    st.sort()
    nest.SetStatus(nest_obj, {'spike_times': st})


if nest_version[0] >= 3:
    set_spikes = set_spikes_nest3
else:
    set_spikes = set_spikes_nest2
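
# Note: the two implementations above differ only in how the spike_generator is passed to nest.SetStatus. In
# NEST 2.x SetStatus expects a list of gids (hence the [nest_obj] wrapping), while in NEST 3.x it takes a
# NodeCollection directly. A minimal sketch of the resulting call, assuming a spike_generator `sg` and spike
# times in ms:
#
#     nest.SetStatus([sg], {'spike_times': np.array([10.0, 25.5])})   # NEST 2.x
#     nest.SetStatus(sg, {'spike_times': np.array([10.0, 25.5])})     # NEST 3.x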


class PointNetwork(SimNetwork):
    def __init__(self, **properties):
        super(PointNetwork, self).__init__(**properties)
        self._io = io
        self.__weight_functions = {}
        self._params_cache = {}
        self._virtual_ids_map = {}
        self._batch_nodes = True
        # self._nest_id_map = {}
        self._nestid2nodeid_map = {}
        self._nestid2gid = {}
        self._nodes_table = {}
        self._gid2nestid = {}
        self._gid_map = GidPool()
        self._virtual_gids = GidPool()

    @property
    def py_function_caches(self):
        return pyfunction_cache

    @property
    def gid_map(self):
        return self._gid_map

    @property
    def gid_pool(self):
        return self._gid_map

    def get_nodes_df(self, population):
        nodes_adaptor = self.get_node_population(population)
        return nodes_adaptor.nodes_df()

    def __get_params(self, node_params):
        if node_params.with_dynamics_params:
            # TODO: use property, not name
            return node_params['dynamics_params']

        params_file = node_params[self._params_column]
        # params_file = self._MT.params_column(node_params)  # node_params['dynamics_params']
        if params_file in self._params_cache:
            return self._params_cache[params_file]
        else:
            params_dir = self.get_component('models_dir')
            params_path = os.path.join(params_dir, params_file)
            with open(params_path, 'r') as f:
                params_dict = json.load(f)
            self._params_cache[params_file] = params_dict
            return params_dict
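
    # Note: a dynamics_params file is simply a JSON dictionary of parameters for the node's NEST model. An
    # illustrative example (values made up) for an iaf_psc_alpha-type point neuron:
    #
    #     {"C_m": 250.0, "tau_m": 10.0, "t_ref": 2.0, "E_L": -70.0, "V_th": -55.0, "V_reset": -70.0}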

    def _register_adaptors(self):
        super(PointNetwork, self)._register_adaptors()
        self._node_adaptors['sonata'] = PointNodeAdaptor
        self._edge_adaptors['sonata'] = PointEdgeAdaptor

    # TODO: reimplement with py_modules like in bionet
    def add_weight_function(self, fnc, name=None, **kwargs):
        fnc_name = name if name is not None else fnc.__name__
        self.__weight_functions[fnc_name] = functools.partial(fnc)
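
    # Usage sketch (hypothetical function name; the arguments a weight function is called with are determined
    # by the edge adaptor when synaptic weights are resolved, not shown in this module):
    #
    #     def my_weight_fnc(*args, **kwargs):
    #         ...
    #
    #     net.add_weight_function(my_weight_fnc, name='my_weight_fnc')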

    def set_default_weight_function(self, fnc):
        self.add_weight_function(fnc, 'default_weight_fnc', overwrite=True)

    def get_weight_function(self, name):
        return self.__weight_functions[name]

    def get_node_id(self, population, node_id):
        pop = self.get_node_population(population)
        return pop.get_node(node_id)

    def build_nodes(self):
        for node_pop in self.node_populations:
            pop_name = node_pop.name
            gid_map = self.gid_map
            gid_map.create_pool(pop_name)
            if node_pop.internal_nodes_only:
                for node in node_pop.get_nodes():
                    node.build()
                    gid_map.add_nestids(name=pop_name, node_ids=node.node_ids, nest_ids=node.nest_ids)
            elif node_pop.mixed_nodes:
                for node in node_pop.get_nodes():
                    if node.model_type != 'virtual':
                        node.build()
                        gid_map.add_nestids(name=pop_name, node_ids=node.node_ids, nest_ids=node.nest_ids)
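
    # The GidPool keeps a mapping between SONATA (population, node_id) pairs and the ids NEST assigns when the
    # nodes are created, e.g. self.gid_map.get_nestids('cortex', [0, 1, 2]) returns the NEST ids of nodes 0-2
    # of a (hypothetical) 'cortex' population.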

    def build_recurrent_edges(self, force_resolution=False):
        recurrent_edge_pops = [ep for ep in self._edge_populations if not ep.virtual_connections]
        if not recurrent_edge_pops:
            return

        for edge_pop in recurrent_edge_pops:
            for edge in edge_pop.get_edges():
                nest_srcs = self.gid_map.get_nestids(edge_pop.source_nodes, edge.source_node_ids)
                nest_trgs = self.gid_map.get_nestids(edge_pop.target_nodes, edge.target_node_ids)
                if np.isscalar(edge.nest_params['weight']):
                    edge.nest_params['weight'] = np.full(shape=len(nest_srcs),
                                                         fill_value=edge.nest_params['weight'])

                self._nest_connect(nest_srcs, nest_trgs, conn_spec='one_to_one', syn_spec=edge.nest_params)
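
    # Note: edge.nest_params is forwarded to nest.Connect as the syn_spec; a scalar 'weight' is expanded with
    # np.full so that it lines up one value per source/target pair created by the one_to_one connection rule.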

    def find_edges(self, source_nodes=None, target_nodes=None):
        # TODO: Move to parent
        selected_edges = self._edge_populations[:]
        if source_nodes is not None:
            selected_edges = [edge_pop for edge_pop in selected_edges if edge_pop.source_nodes == source_nodes]
        if target_nodes is not None:
            selected_edges = [edge_pop for edge_pop in selected_edges if edge_pop.target_nodes == target_nodes]
        return selected_edges

    def add_spike_trains(self, spike_trains, node_set, sg_params={'precise_times': True}):
        # Build the virtual nodes
        src_nodes = [node_pop for node_pop in self.node_populations if node_pop.name in node_set.population_names()]
        virt_gid_map = self._virtual_gids
        for node_pop in src_nodes:
            if node_pop.name in self._virtual_ids_map:
                continue

            virt_node_map = {}
            if node_pop.virtual_nodes_only:
                for node in node_pop.get_nodes():
                    nest_objs = nest.Create('spike_generator', node.n_nodes, sg_params)
                    nest_ids = nest_objs.tolist() if nest_version[0] >= 3 else nest_objs
                    virt_gid_map.add_nestids(name=node_pop.name, nest_ids=nest_ids, node_ids=node.node_ids)
                    for node_id, nest_obj, nest_id in zip(node.node_ids, nest_objs, nest_ids):
                        virt_node_map[node_id] = nest_id
                        set_spikes(node_id=node_id, nest_obj=nest_obj, spike_trains=spike_trains)

            elif node_pop.mixed_nodes:
                for node in node_pop.get_nodes():
                    if node.model_type != 'virtual':
                        continue

                    nest_ids = nest.Create('spike_generator', node.n_nodes, sg_params)
                    for node_id, nest_id in zip(node.node_ids, nest_ids):
                        virt_node_map[node_id] = nest_id
                        set_spikes(node_id=node_id, nest_obj=nest_id, spike_trains=spike_trains)

            self._virtual_ids_map[node_pop.name] = virt_node_map

        # Create virtual synaptic connections
        for source_reader in src_nodes:
            for edge_pop in self.find_edges(source_nodes=source_reader.name):
                for edge in edge_pop.get_edges():
                    nest_trgs = self.gid_map.get_nestids(edge_pop.target_nodes, edge.target_node_ids)
                    nest_srcs = virt_gid_map.get_nestids(edge_pop.source_nodes, edge.source_node_ids)
                    if np.isscalar(edge.nest_params['weight']):
                        edge.nest_params['weight'] = np.full(shape=len(nest_srcs),
                                                             fill_value=edge.nest_params['weight'])

                    self._nest_connect(nest_srcs, nest_trgs, conn_spec='one_to_one', syn_spec=edge.nest_params)
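
    # Usage sketch (illustrative; `spikes` stands for any object exposing get_times(node_id=...), such as a
    # bmtk SpikeTrains instance, and `virt_node_set` for a node set resolving the virtual source population):
    #
    #     net.add_spike_trains(spike_trains=spikes, node_set=virt_node_set)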

    def _nest_connect(self, nest_srcs, nest_trgs, conn_spec='one_to_one', syn_spec=None):
        """Calls nest.Connect but with some extra error logging and exception handling."""
        try:
            nest.Connect(nest_srcs, nest_trgs, conn_spec=conn_spec, syn_spec=syn_spec)
        except nest.kernel.NESTErrors.BadDelay as bde:
            # A recurring issue is dt > delay; add some extra messaging to the log to help users fix the problem.
            res_kernel = nest.GetKernelStatus().get('resolution', 'NaN')
            delay_edges = syn_spec.get('delay', 'NaN')
            msg = 'synaptic "delay" value in edges ({}) is not compatible with simulator resolution/"dt" ({})'.format(
                delay_edges, res_kernel
            )
            self.io.log_error('{}{}'.format(bde.errorname, bde.errormessage))
            self.io.log_error(msg)
            raise
        except Exception as e:
            # Record the exception in the log file.
            self.io.log_error(str(e))
            raise
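

# Typical build order (sketch only; in practice the pointnet simulator constructs and populates the network
# from a SONATA configuration rather than calling these steps by hand):
#
#     net = PointNetwork()                      # normally created/populated by the simulator from a config
#     net.build_nodes()                         # create NEST neurons, register node_id <-> nest_id mappings
#     net.build_recurrent_edges()               # connect the internal (non-virtual) populations
#     net.add_spike_trains(spikes, node_set)    # spike_generators + virtual->internal connections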