Source code for bmtk.simulator.core.simulator_network

from six import string_types
import numpy as np
import os
import h5py
import pandas as pd

from bmtk.simulator.core.io_tools import io
from .simulation_config import SimulationConfig
from bmtk.simulator.core.node_sets import NodeSet, NodeSetAll
from bmtk.simulator.core import sonata_reader


[docs]class SimNetwork(object): def __init__(self): self._components = {} self._io = io self._node_adaptors = {} self._edge_adaptors = {} self._register_adaptors() self._node_populations = {} self._node_sets = {} self._edge_populations = [] self._gap_juncs = {} @property def io(self): return self._io @property def node_populations(self): return self._node_populations.values() @property def recurrent_edges(self): return [ep for ep in self._edge_populations if ep.recurrent_connections] @property def py_function_caches(self): return None def _register_adaptors(self): self._node_adaptors['sonata'] = sonata_reader.NodeAdaptor self._edge_adaptors['sonata'] = sonata_reader.EdgeAdaptor
[docs] def get_node_adaptor(self, name): return self._node_adaptors[name]
[docs] def get_edge_adaptor(self, name): return self._edge_adaptors[name]
[docs] def add_component(self, name, path): self._components[name] = path
[docs] def get_component(self, name): if name not in self._components: self.io.log_exception('No network component set with name {}'.format(name)) else: return self._components[name]
[docs] def has_component(self, name): return name in self._components
[docs] def get_node_population(self, name): return self._node_populations[name]
[docs] def get_node_populations(self): return self._node_populations.values()
[docs] def add_node_set(self, name, node_set): self._node_sets[name] = node_set
[docs] def get_node_set(self, node_set): if isinstance(node_set, string_types) and node_set in self._node_sets: return self._node_sets[node_set] elif isinstance(node_set, (dict, list)): return NodeSet(node_set, self) else: self.io.log_exception('Unable to load or find node_set "{}"'.format(node_set))
[docs] def add_nodes(self, node_population): pop_name = node_population.name if pop_name in self._node_populations: # Make sure their aren't any collisions self.io.log_exception('There are multiple node populations with name {}.'.format(pop_name)) node_population.initialize(self) self._node_populations[pop_name] = node_population if node_population.mixed_nodes: # We'll allow a population to have virtual and non-virtual nodes but it is not ideal self.io.log_warning(('Node population {} contains both virtual and non-virtual nodes which can cause ' + 'memory and build-time inefficency. Consider separating virtual nodes into their ' + 'own population').format(pop_name)) # Used in inputs/reports when needed to get all gids belonging to a node population self._node_sets[pop_name] = NodeSet({'population': pop_name}, self)
[docs] def node_properties(self, populations=None): if populations is None: selected_pops = self.node_populations elif isinstance(populations, string_types): selected_pops = [pop for pop in self.node_populations if pop.name == populations] else: selected_pops = [pop for pop in self.node_populations if pop.name in populations] all_nodes_df = None for node_pop in selected_pops: node_pop_df = node_pop.nodes_df() if 'population' not in node_pop_df: node_pop_df['population'] = node_pop.name node_pop_df = node_pop_df.set_index(['population', node_pop_df.index.astype(dtype=np.uint64)]) if all_nodes_df is None: all_nodes_df = node_pop_df else: all_nodes_df = pd.concat([all_nodes_df, node_pop_df]) return all_nodes_df
[docs] def get_node_groups(self, populations=None): if populations is None: selected_pops = self.node_populations elif isinstance(populations, string_types): selected_pops = [pop for pop in self.node_populations if pop.name == populations] else: selected_pops = [pop for pop in self.node_populations if pop.name in populations] all_nodes_df = None for node_pop in selected_pops: node_pop_df = node_pop.nodes_df(index_by_id=False) if 'population' not in node_pop_df: node_pop_df['population'] = node_pop.name if all_nodes_df is None: all_nodes_df = node_pop_df else: all_nodes_df = pd.concat([all_nodes_df, node_pop_df], sort=False) return all_nodes_df
[docs] def get_node_sets(self, populations=None, groupby=None, **filterby): selected_nodes_df = self.node_properties(populations) for k, v in filterby: if isinstance(v, (np.ndarray, list, tuple)): selected_nodes_df = selected_nodes_df[selected_nodes_df[k].isin(v)] else: selected_nodes_df = selected_nodes_df[selected_nodes_df[k].isin(v)] if groupby is not None: return {k: v.tolist() for k, v in selected_nodes_df.groupby(groupby).groups.items()} else: return selected_nodes_df.index.tolist()
[docs] def add_edges(self, edge_population): edge_population.initialize(self) pop_name = edge_population.name # Check that source_population exists src_pop_name = edge_population.source_nodes if src_pop_name not in self._node_populations: self.io.log_exception('Source node population {} not found. Please update {} edges'.format(src_pop_name, pop_name)) # Check that the target population exists and contains non-virtual nodes (we cannot synapse onto virt nodes) trg_pop_name = edge_population.target_nodes if trg_pop_name not in self._node_populations or self._node_populations[trg_pop_name].virtual_nodes_only: self.io.log_exception(('Node population {} does not exists (or consists of only virtual nodes). ' + '{} edges cannot create connections.').format(trg_pop_name, pop_name)) edge_population.set_connection_type(src_pop=self._node_populations[src_pop_name], trg_pop = self._node_populations[trg_pop_name]) self._edge_populations.append(edge_population)
[docs] def load_gap_junc_files(self, gj_dic): for p in gj_dic: path = p['gap_juncs_file'] f_name = os.path.basename(path) network = f_name[:f_name.find("_gap_juncs.h5")] self._gap_juncs[network] = {} with h5py.File(path, 'r') as f: for key in ['source_ids', 'target_ids', 'src_gap_ids', 'trg_gap_ids']: self._gap_juncs[network][key] = f[key][()]
[docs] def build(self): self.build_nodes() self.build_recurrent_edges()
[docs] def build_nodes(self): raise NotImplementedError()
[docs] def build_recurrent_edges(self, **opts): raise NotImplementedError()
[docs] def build_virtual_connections(self): raise NotImplementedError()
[docs] @classmethod def from_config(cls, conf, **properties): """Generates a graph structure from a json config file or dictionary. :param conf: name of json config file, or a dictionary with config parameters :param properties: optional properties. :return: A graph object of type cls """ network = cls(**properties) # The simulation run script should create a config-dict since it's likely to vary based on the simulator engine, # however in the case the user doesn't we will try a generic conversion from dict/json to ConfigDict if isinstance(conf, SimulationConfig): config = conf else: try: config = SimulationConfig.load(conf) except Exception as e: network.io.log_exception('Could not convert {} (type "{}") to json.'.format(conf, type(conf))) if not config.with_networks: network.io.log_exception('Could not find any network files. Unable to build network.') # TODO: These are simulator specific network.spike_threshold = config.spike_threshold network.dL = config.dL # load components for name, value in config.components.items(): network.add_component(name, value) # load nodes gid_map = config.gid_mappings node_adaptor = network.get_node_adaptor('sonata') for node_dict in config.nodes: nodes = sonata_reader.load_nodes(node_dict['nodes_file'], node_dict['node_types_file'], gid_map, adaptor=node_adaptor) for node_pop in nodes: network.add_nodes(node_pop) # TODO: Raise a warning if more than one internal population and no gids (node_id collision) # load edges edge_adaptor = network.get_edge_adaptor('sonata') for edge_dict in config.edges: if not edge_dict.get('enabled', True): continue edges = sonata_reader.load_edges(edge_dict['edges_file'], edge_dict['edge_types_file'], adaptor=edge_adaptor) for edge_pop in edges: network.add_edges(edge_pop) network.load_gap_junc_files(config.gap_juncs) # Add nodeset section network.add_node_set('all', NodeSetAll(network)) for ns_name, ns_filter in config.node_sets.items(): network.add_node_set(ns_name, NodeSet(ns_filter, network)) return network
[docs] @classmethod def from_manifest(cls, manifest_json): # TODO: Add adaptors to build a simulation network from model files downloaded celltypes.brain-map.org raise NotImplementedError()
[docs] @classmethod def from_builder(cls, network): # TODO: Add adaptors to build a simulation network from a bmtk.builder Network object raise NotImplementedError()