# Source code for bmtk.simulator.pointnet.modules.ecephys_module

import nest
import numpy as np
import pandas as pd

from bmtk.simulator.core.modules.ecephys_module import ECEphysUnitsModule
from bmtk.simulator.pointnet.nest_utils import nest_version
from bmtk.simulator.pointnet.io_tools import io
from bmtk.utils.reports.spike_trains.spike_trains import SpikeTrains


def set_spikes_nest2(node_id, nest_obj, spike_trains):
    """Set the spike times of a NEST 2.x spike generator.

    Does nothing when the resolved spike train is ``None`` or empty.

    :param node_id: id of the node; only used to look up the times when ``spike_trains`` is a ``SpikeTrains`` object.
    :param nest_obj: the spike generator to configure; it is wrapped in a list before being handed to
        ``nest.SetStatus``.
    :param spike_trains: either a ``SpikeTrains`` object (queried with ``get_times(node_id)``) or a list, numpy array
        or pandas Series of spike times.
    :raises TypeError: if ``spike_trains`` is none of the supported types.
    """
    if isinstance(spike_trains, SpikeTrains):
        st = spike_trains.get_times(node_id)
    elif isinstance(spike_trains, (list, np.ndarray, pd.Series)):
        st = spike_trains
    else:
        # Without this branch ``st`` was never bound and the check below failed with an opaque UnboundLocalError.
        raise TypeError(
            'spike_trains for node {} must be a SpikeTrains object, list, numpy array or pandas Series, '
            'got {}'.format(node_id, type(spike_trains).__name__)
        )

    if st is None or len(st) == 0:
        return

    # np.array() copies, so the in-place sort below never reorders the caller's data.
    st = np.array(st)
    if np.any(st <= 0.0):
        # Zero or negative spike times are reported before they are handed to NEST.
        # NOTE(review): io.log_exception presumably raises after logging - confirm.
        io.log_exception('spike train {} contains negative/zero time, unable to run virtual cell in NEST'.format(st))
    st.sort()
    nest.SetStatus([nest_obj], {'spike_times': st})
def set_spikes_nest3(node_id, nest_obj, spike_trains):
    """Set the spike times of a NEST 3.x spike generator.

    Does nothing when the resolved spike train is ``None`` or empty.

    :param node_id: id of the node; only used to look up the times when ``spike_trains`` is a ``SpikeTrains`` object.
    :param nest_obj: the spike generator to configure; passed to ``nest.SetStatus`` as-is.
    :param spike_trains: either a ``SpikeTrains`` object (queried with ``get_times(node_id)``) or a list, numpy array
        or pandas Series of spike times.
    :raises TypeError: if ``spike_trains`` is none of the supported types.
    """
    if isinstance(spike_trains, SpikeTrains):
        st = spike_trains.get_times(node_id)
    elif isinstance(spike_trains, (list, np.ndarray, pd.Series)):
        st = spike_trains
    else:
        # Without this branch ``st`` was never bound and the check below failed with an opaque UnboundLocalError.
        raise TypeError(
            'spike_trains for node {} must be a SpikeTrains object, list, numpy array or pandas Series, '
            'got {}'.format(node_id, type(spike_trains).__name__)
        )

    if st is None or len(st) == 0:
        return

    # np.array() copies, so the in-place sort below never reorders the caller's data.
    st = np.array(st)
    if np.any(st <= 0.0):
        # Zero or negative spike times are reported before they are handed to NEST.
        # NOTE(review): io.log_exception presumably raises after logging - confirm.
        io.log_exception('spike train {} contains negative/zero time, unable to run virtual cell in NEST'.format(st))
    st.sort()
    nest.SetStatus(nest_obj, {'spike_times': st})
# Bind the implementation matching the installed NEST major version once, at import time.
set_spikes = set_spikes_nest3 if nest_version[0] >= 3 else set_spikes_nest2
class PointECEphysUnitsModule(ECEphysUnitsModule):
    """PointNet module that backs virtual nodes with NEST spike generators.

    Each generator is loaded with the spike train the module's mapping strategy returns for that node.
    """

    def initialize(self, sim):
        """Create the spike generators for this module's node set and connect them to their targets.

        For every node population named by the node set that is not yet in ``net._virtual_ids_map``, one
        ``spike_generator`` is created per virtual node and loaded with the spike train returned by the mapping
        strategy. Afterwards every edge whose source is one of those populations is connected one-to-one.

        :param sim: simulator object; only ``sim.net`` is used here.
        """
        net = sim.net
        sg_params = {'precise_times': True}

        node_set = net.get_node_set(self._node_set)
        self._mapping_strategy.build_map(node_set=node_set)

        src_nodes = [node_pop for node_pop in net.node_populations if node_pop.name in node_set.population_names()]
        virt_gid_map = net._virtual_gids
        # NOTE(review): these tallies are accumulated but never read inside this method.
        total_spikes = 0
        total_firing_neurons = 0
        for node_pop in src_nodes:
            if node_pop.name in net._virtual_ids_map:
                # Generators for this population were already created; don't duplicate them.
                continue

            virt_node_map = {}
            if node_pop.virtual_nodes_only:
                for node in node_pop.get_nodes():
                    nest_objs = nest.Create('spike_generator', node.n_nodes, sg_params)
                    # From NEST 3 on, the created handles are converted to a plain list via tolist() for bookkeeping.
                    nest_ids = nest_objs.tolist() if nest_version[0] >= 3 else nest_objs

                    virt_gid_map.add_nestids(name=node_pop.name, nest_ids=nest_ids, node_ids=node.node_ids)
                    for node_id, nest_obj, nest_id in zip(node.node_ids, nest_objs, nest_ids):
                        spike_trains = self._mapping_strategy.get_spike_trains(node_id, '')
                        total_spikes += len(spike_trains)
                        total_firing_neurons += 1 if len(spike_trains) > 0 else 0
                        virt_node_map[node_id] = nest_id
                        set_spikes(node_id=node_id, nest_obj=nest_obj, spike_trains=spike_trains)

            elif node_pop.mixed_nodes:
                for node in node_pop.get_nodes():
                    if node.model_type != 'virtual':
                        continue

                    nest_objs = nest.Create('spike_generator', node.n_nodes, sg_params)
                    nest_ids = nest_objs.tolist() if nest_version[0] >= 3 else nest_objs
                    # NOTE(review): unlike the virtual-only branch, these ids are not registered with
                    # virt_gid_map (unchanged from the original) - confirm whether that is intended.
                    for node_id, nest_obj, nest_id in zip(node.node_ids, nest_objs, nest_ids):
                        # This branch used to pass an unsupported ``nest_id`` keyword and a ``spike_trains``
                        # variable it never assigned; fetch the node's own train and pass the generator handle.
                        spike_trains = self._mapping_strategy.get_spike_trains(node_id, '')
                        total_spikes += len(spike_trains)
                        total_firing_neurons += 1 if len(spike_trains) > 0 else 0
                        virt_node_map[node_id] = nest_id
                        set_spikes(node_id=node_id, nest_obj=nest_obj, spike_trains=spike_trains)

            net._virtual_ids_map[node_pop.name] = virt_node_map

        # Create virtual synaptic connections
        for source_reader in src_nodes:
            for edge_pop in net.find_edges(source_nodes=source_reader.name):
                for edge in edge_pop.get_edges():
                    nest_trgs = net.gid_map.get_nestids(edge_pop.target_nodes, edge.target_node_ids)
                    nest_srcs = virt_gid_map.get_nestids(edge_pop.source_nodes, edge.source_node_ids)
                    if np.isscalar(edge.nest_params['weight']):
                        # Expand a scalar weight to one entry per source before the one_to_one connect.
                        edge.nest_params['weight'] = np.full(shape=len(nest_srcs),
                                                             fill_value=edge.nest_params['weight'])
                    net._nest_connect(nest_srcs, nest_trgs, conn_spec='one_to_one', syn_spec=edge.nest_params)