# Source code for bmtk.simulator.bionet.bionetwork

# Copyright 2017. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import numpy as np
from neuron import h

from bmtk.simulator.core.simulator_network import SimNetwork
from bmtk.simulator.core import sonata_reader
from bmtk.simulator.bionet.biocell import BioCell, BioCellSpontSyn
from bmtk.simulator.bionet.pointprocesscell import PointProcessCell, PointProcessCellSpontSyns
from bmtk.simulator.bionet.pointsomacell import PointSomaCell
from bmtk.simulator.bionet.virtualcell import VirtualCell
from bmtk.simulator.bionet.morphology import Morphology
from bmtk.simulator.bionet.io_tools import io
from bmtk.simulator.bionet import nrn
from bmtk.simulator.bionet.sonata_adaptors import BioNodeAdaptor, BioEdgeAdaptor
from .gids import GidPool

# TODO: leave this import, it will initialize some of the default functions for building neurons/synapses/weights.
import bmtk.simulator.bionet.default_setters


# NEURON ParallelContext handle; used below to query the MPI world size and this process's rank.
pc = h.ParallelContext()
MPI_size, MPI_rank = int(pc.nhost()), int(pc.id())


class BioNetwork(SimNetwork):
    """BioNet network: builds this MPI rank's share of cells and wires their connections.

    Keeps lookup tables from gid and from (population, node_id) to the cells built on this rank, plus caches of
    remote nodes and of virtual (spike-train driven) source cells.
    """
    model_type_col = 'model_type'

    def __init__(self):
        super(BioNetwork, self).__init__()
        self._io = io

        # TODO: Find a better way that will allow users to register their own class
        self._model_type_map = {
            'biophysical': BioCell,
            'point_process': PointProcessCell,
            'point_neuron': PointProcessCell,
            'point_soma': PointSomaCell,
            'virtual': VirtualCell
        }

        self._rank_node_gids = {}  # gid -> cell built on this rank
        self._rank_node_ids = {}  # population name -> {node_id -> cell built on this rank}
        self._rank_nodes_by_model = {model_type: {} for model_type in self._model_type_map}  # model_type -> {gid -> cell}
        self._remote_node_cache = {}  # population name -> {node_id -> node looked up but not built here}
        self._virtual_nodes = {}  # population name -> {node_id -> VirtualCell} (see get_virtual_cells)
        self._disconnected_source_cells = {}  # population name -> {node_id -> VirtualCell} (see get_disconnected_cell)
        self._cells_built = False
        self._connections_initialized = False
        self._gid_pool = GidPool()

        self.has_spont_syns = False
        self.spont_syns_filter = None
        self.spont_syns_times = None

    @property
    def gid_pool(self):
        """GidPool with which every node population is registered in add_nodes()."""
        return self._gid_pool

    @property
    def py_function_caches(self):
        """The ``bmtk.simulator.bionet.nrn`` module."""
        return nrn

    def set_spont_syn_activity(self, precell_filter, timestamps):
        """Use the spontaneous-synapse cell classes and remember the spontaneous activity settings.

        Only cells built (via _build_cell) after this call are affected, since the class map is read at build time.

        :param precell_filter: stored as ``self.spont_syns_filter``.
        :param timestamps: stored as ``self.spont_syns_times``.
        """
        spont_cell_classes = {
            'biophysical': BioCellSpontSyn,
            'point_process': PointProcessCellSpontSyns,
            'point_neuron': PointProcessCellSpontSyns,
            'point_soma': PointSomaCell,
            'virtual': VirtualCell
        }
        self._model_type_map = spont_cell_classes
        self.has_spont_syns, self.spont_syns_filter, self.spont_syns_times = True, precell_filter, timestamps

    def get_node_id(self, population, node_id):
        """Return the node object for (population, node_id).

        Locally built cells supply their ``.node``; any other node is fetched from its node population once and
        then served from ``_remote_node_cache``.
        """
        local_cells = self._rank_node_ids[population]
        if node_id in local_cells:
            return local_cells[node_id].node

        remote_nodes = self._remote_node_cache[population]
        if node_id not in remote_nodes:
            remote_nodes[node_id] = self.get_node_population(population).get_node(node_id)
        return remote_nodes[node_id]

    def cell_type_maps(self, model_type):
        """Return the {gid -> cell} map of locally built cells of ``model_type``."""
        return self._rank_nodes_by_model[model_type]

    def get_cell_node_id(self, population, node_id):
        """Return the locally built cell for (population, node_id), or None when it is not on this rank."""
        return self._rank_node_ids[population].get(node_id)

    def get_cell_gid(self, gid):
        """Return the locally built cell with the given gid (KeyError if not on this rank)."""
        return self._rank_node_gids[gid]

    def get_local_cells(self):
        """Return the {gid -> cell} map of every cell built on this rank."""
        return self._rank_node_gids

    @property
    def local_gids(self):
        """List of gids of the cells built on this rank."""
        return list(self._rank_node_gids)

    def add_nodes(self, node_population):
        """Register ``node_population`` with the gid pool, then hand it to the base class."""
        pop_name = node_population.name
        pop_size = node_population.n_nodes()
        self._gid_pool.add_pool(pop_name, pop_size)
        super(BioNetwork, self).add_nodes(node_population)

    def get_virtual_cells(self, population, node_id, spike_trains):
        """Return the cached VirtualCell for (population, node_id), creating it from ``spike_trains`` on first use.

        The population must already have an entry in ``_virtual_nodes`` (set up by build_nodes).
        """
        virtual_cells = self._virtual_nodes[population]
        if node_id not in virtual_cells:
            node = self.get_node_id(population, node_id)
            virtual_cells[node_id] = VirtualCell(node, population, spike_trains)
        return virtual_cells[node_id]

    def get_disconnected_cell(self, population, node_id, spike_trains):
        """Return a cached VirtualCell used as a replay source, creating it from ``spike_trains`` on first use.

        Uses its own cache, separate from get_virtual_cells().
        """
        replay_cells = self._disconnected_source_cells.setdefault(population, {})
        if node_id not in replay_cells:
            node = self.get_node_id(population, node_id)
            replay_cells[node_id] = VirtualCell(node, population, spike_trains)
        return replay_cells[node_id]

    def _build_cell(self, bionode, population_name):
        """Instantiate the cell class registered for ``bionode.model_type`` and index it by model type.

        Unknown model types are reported through ``self.io.log_exception`` and None is returned.
        """
        model_type = bionode.model_type
        if model_type not in self._model_type_map:
            self.io.log_exception('Unrecognized model_type {}.'.format(model_type))
            return None

        cell_cls = self._model_type_map[model_type]
        cell = cell_cls(bionode, population_name=population_name, bionetwork=self)
        self._rank_nodes_by_model[model_type][cell.gid] = cell
        return cell

    def _register_adaptors(self):
        """Register the BioNet-specific SONATA node and edge adaptors on top of the base adaptors."""
        super(BioNetwork, self)._register_adaptors()
        self._node_adaptors['sonata'] = BioNodeAdaptor
        self._edge_adaptors['sonata'] = BioEdgeAdaptor

    def build_nodes(self):
        """Build this rank's cells for every node population and populate the lookup tables.

        Nodes are distributed round-robin across ranks (``[MPI_rank::MPI_size]``). Virtual nodes are never built;
        populations that contain them get an (initially empty) ``_virtual_nodes`` entry instead.
        """
        for node_pop in self.node_populations:
            pop_name = node_pop.name
            self._remote_node_cache[pop_name] = {}
            local_cells = {}

            if node_pop.internal_nodes_only:
                self._instantiate_rank_cells(node_pop, local_cells, skip_virtual=False)
            elif node_pop.mixed_nodes:
                # Internal and virtual (external) nodes share this population; only the internal ones are built.
                self._virtual_nodes[pop_name] = {}
                self._instantiate_rank_cells(node_pop, local_cells, skip_virtual=True)
            elif node_pop.virtual_nodes_only:
                self._virtual_nodes[pop_name] = {}

            self._rank_node_ids[pop_name] = local_cells

        # NOTE: the older make_morphologies() / set_seg_props() / calc_seg_coords() steps are disabled here.
        self._cells_built = True
        self.io.barrier()

    def _instantiate_rank_cells(self, node_pop, local_cells, skip_virtual):
        """Build this rank's slice of ``node_pop`` into ``local_cells`` (node_id -> cell) and the gid table."""
        for node in node_pop[MPI_rank::MPI_size]:
            if skip_virtual and node.model_type == 'virtual':
                continue
            cell = self._build_cell(bionode=node, population_name=node_pop.name)
            local_cells[node.node_id] = cell
            self._rank_node_gids[cell.gid] = cell

    def _init_connections(self):
        """Call ``init_connections()`` on every local cell, at most once per network."""
        if self._connections_initialized:
            return

        for cell in self._rank_node_gids.values():
            cell.init_connections()
        self._connections_initialized = True

    def get_gj_id(self, network, src_nid, trg_nid, source_gap):
        """Return the two gap junction ids for the (src_nid, trg_nid) gap junction on a given network.

        :param network: population name used as the key into the gap junction table.
        :param source_gap: when true the source-side id comes first, otherwise the target-side id comes first.
        :return: two-element list of gap junction ids.
        :raises Exception: for a self-connection, or when the table has zero or several matching rows.
        """
        if src_nid == trg_nid:
            raise Exception("Cells cannot have gap junctions with themselves.")

        # NOTE(review): self._gap_juncs is not initialized in this class; presumably set up by SimNetwork — confirm.
        gap_table = self._gap_juncs[network]
        row_mask = np.logical_and(gap_table["source_ids"] == src_nid, gap_table["target_ids"] == trg_nid)
        matching_rows = np.where(row_mask)[0]

        if len(matching_rows) > 1:
            raise Exception("The gap junction file has more than one gap junction with the same ids.")
        if len(matching_rows) == 0:
            raise Exception("The gap junction file does not contain a gap junction with source id " + str(src_nid)
                            + " and target id " + str(trg_nid) + " on network " + network)

        row = matching_rows[0]
        src_gap_id = gap_table["src_gap_ids"][row]
        trg_gap_id = gap_table["trg_gap_ids"][row]
        return [src_gap_id, trg_gap_id] if source_gap else [trg_gap_id, src_gap_id]

    def build_recurrent_edges(self):
        """Create all non-virtual (recurrent and mixed) connections onto the cells built on this rank."""
        recurrent_edge_pops = [ep for ep in self._edge_populations if not ep.virtual_connections]
        if not recurrent_edge_pops:
            return

        self._init_connections()
        for edge_pop in recurrent_edge_pops:
            if edge_pop.recurrent_connections:
                self._connect_recurrent_population(edge_pop)
            elif edge_pop.mixed_connections:
                self._connect_mixed_population(edge_pop)

        self.io.barrier()

    def _connect_recurrent_population(self, edge_pop):
        """Connect a purely recurrent edge population, installing gap junctions on both of their local ends."""
        src_pop_name = edge_pop.source_nodes
        trg_pop_name = edge_pop.target_nodes

        # Incoming edges of every local target cell.
        for trg_nid, trg_cell in self._rank_node_ids[trg_pop_name].items():
            for edge in edge_pop.get_target(trg_nid):
                src_node = self.get_node_id(src_pop_name, edge.source_node_id)
                if not edge.is_gap_junction:
                    trg_cell.set_syn_connection(edge, src_node)
                    continue

                if src_pop_name != trg_pop_name:
                    raise Exception("Gap junctions must be from the same network builder")
                gj_ids = self.get_gj_id(src_pop_name, edge.source_node_id, trg_nid, False)
                trg_cell.set_syn_connection(edge, src_node, gj_ids=gj_ids)

        # Gap-junction edges are also installed on the local source cell.
        for src_nid, src_cell in self._rank_node_ids[src_pop_name].items():
            for edge in edge_pop.get_source(src_nid):
                if not edge.is_gap_junction:
                    continue

                if src_pop_name != trg_pop_name:
                    raise Exception("Gap junctions must be from the same network builder")
                trg_node = self.get_node_id(trg_pop_name, edge.target_node_id)
                gj_ids = self.get_gj_id(src_pop_name, src_nid, edge.target_node_id, True)
                src_cell.set_syn_connection(edge, trg_node, gj_ids=gj_ids)

    def _connect_mixed_population(self, edge_pop):
        """Connect an edge population mixing virtual and recurrent sources, skipping the virtual sources.

        Every source node has to be checked because virtual nodes cannot be built yet; this per-edge check is
        slow, which is why it lives in its own loop.
        """
        src_pop_name = edge_pop.source_nodes
        for trg_nid, trg_cell in self._rank_node_ids[edge_pop.target_nodes].items():
            for edge in edge_pop.get_target(trg_nid):
                src_node = self.get_node_id(src_pop_name, edge.source_node_id)
                if src_node.model_type != 'virtual':
                    trg_cell.set_syn_connection(edge, src_node)

    def build_replay_inputs(self, spike_trains, edges_path, edge_types_path, source_node_set, target_node_set):
        """Connect replayed spike trains through the edges stored in the given SONATA edge files.

        Only edges whose source lies in ``source_node_set`` and whose target lies in ``target_node_set`` (and is
        built on this rank) are created. Their sources are VirtualCells from get_disconnected_cell().
        """
        self._init_connections()

        src_cells = self.get_node_set(source_node_set)
        valid_src_ids = set(src_cells.gids())
        trg_cells = self.get_node_set(target_node_set)
        valid_trg_ids = set(trg_cells.gids())

        edge_adaptor = self.get_edge_adaptor('sonata')
        for edges_pop in sonata_reader.load_edges(edges_path, edge_types_path, adaptor=edge_adaptor):
            edges_pop.initialize(self)
            trg_population = edges_pop.target_nodes
            src_population = edges_pop.source_nodes
            if not (src_population in src_cells.population_names()
                    and trg_population in trg_cells.population_names()):
                continue

            for trg_nid, trg_cell in self._rank_node_ids[trg_population].items():
                if trg_nid not in valid_trg_ids:
                    continue
                for edge in edges_pop.get_target(trg_nid):
                    src_nid = edge.source_node_id
                    if src_nid in valid_src_ids:
                        replay_cell = self.get_disconnected_cell(src_population, src_nid, spike_trains)
                        trg_cell.set_syn_connection(edge, replay_cell, replay_cell)

        self.io.barrier()

    def find_edges(self, source_nodes=None, target_nodes=None):
        """Return a new list of edge populations, optionally filtered by source and/or target population name."""
        return [
            edge_pop for edge_pop in self._edge_populations
            if (source_nodes is None or edge_pop.source_nodes == source_nodes)
            and (target_nodes is None or edge_pop.target_nodes == target_nodes)
        ]

    def add_spike_trains(self, spike_trains, node_set):
        """Drive local cells from ``spike_trains`` through the virtual edges leaving the populations of ``node_set``.

        :raises NotImplementedError: for edge populations with mixed (virtual + recurrent) connections.
        """
        self._init_connections()

        source_pop_names = [pop.name for pop in self.node_populations if pop.name in node_set.population_names()]
        for source_population in source_pop_names:
            for edge_pop in self.find_edges(source_nodes=source_population):
                if edge_pop.virtual_connections:
                    for trg_nid, trg_cell in self._rank_node_ids[edge_pop.target_nodes].items():
                        for edge in edge_pop.get_target(trg_nid):
                            virt_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains)
                            trg_cell.set_syn_connection(edge, virt_cell, virt_cell)
                elif edge_pop.mixed_connections:
                    raise NotImplementedError()

        self.io.barrier()