Source code for bmtk.builder.network_adaptors.dm_network_orig

# Copyright 2017. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import numpy as np
import h5py
import six
import csv

from .network import Network
from bmtk.builder.node import Node
from bmtk.builder.edge import Edge
from bmtk.utils import sonata

try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    mpi_rank = comm.Get_rank()
    mpi_size = comm.Get_size()
    barrier = comm.barrier

except ImportError:
    mpi_rank = 0
    mpi_size = 1
    barrier = lambda: None


[docs]class DenseNetworkOrig(Network): def __init__(self, name, **network_props): super(DenseNetworkOrig, self).__init__(name, **network_props or {}) self.__edges_types = {} self.__src_mapping = {} self.__networks = {} self.__node_count = 0 self._nodes = [] self.__edges_tables = [] self._target_networks = {} def _initialize(self): self.__id_map = [] self.__lookup = [] def _add_nodes(self, nodes): self._nodes.extend(nodes) self._nnodes = len(self._nodes) """ id_label = 'node_id' if 'node_id' in nodes[0].keys() else 'id' start_idx = len(self.__id_map) # self.__id_map += [n[id_label] for n in nodes] self.__nodes += [(interal_id, nodes[node_idx]) for node_idx, interal_id in enumerate(xrange(start_idx, len(self.__id_map)))] assert(len(self.__id_map) == len(self.__nodes)) """
[docs] def edges_table(self): return self.__edges_tables
def _save_nodes(self, nodes_file_name): if not self._nodes_built: self._build_nodes() # save the node_types file # TODO: how do we add attributes to the h5 group_indx = 0 groups_lookup = {} group_indicies = {} group_props = {} for ns in self._node_sets: if ns.params_hash in groups_lookup: continue else: groups_lookup[ns.params_hash] = group_indx group_indicies[group_indx] = 0 group_props[group_indx] = {k: [] for k in ns.params_keys if k != 'node_id'} group_indx += 1 node_gid_table = np.zeros(self._nnodes) # todo: set dtypes node_type_id_table = np.zeros(self._nnodes) node_group_table = np.zeros(self._nnodes) node_group_index_tables = np.zeros(self._nnodes) for i, node in enumerate(self.nodes()): node_gid_table[i] = node.node_id node_type_id_table[i] = node.node_type_id group_id = groups_lookup[node.params_hash] node_group_table[i] = group_id node_group_index_tables[i] = group_indicies[group_id] group_indicies[group_id] += 1 group_dict = group_props[group_id] for key, prop_ds in group_dict.items(): prop_ds.append(node.params[key]) # TODO: open in append mode if mpi_rank == 0: with h5py.File(nodes_file_name, 'w') as hf: # Add magic and version attribute add_hdf5_attrs(hf) pop_grp = hf.create_group('/nodes/{}'.format(self.name)) pop_grp.create_dataset('node_id', data=node_gid_table, dtype='uint64') pop_grp.create_dataset('node_type_id', data=node_type_id_table, dtype='uint64') pop_grp.create_dataset('node_group_id', data=node_group_table, dtype='uint32') pop_grp.create_dataset('node_group_index', data=node_group_index_tables, dtype='uint64') for grp_id, props in group_props.items(): model_grp = pop_grp.create_group('{}'.format(grp_id)) for key, dataset in props.items(): # ds_path = 'nodes/{}/{}'.format(grp_id, key) try: model_grp.create_dataset(key, data=dataset) except TypeError: str_list = [str(d) for d in dataset] hf.create_dataset(key, data=str_list) barrier()
[docs] def nodes_iter(self, node_ids=None): if node_ids is not None: return [n for n in self._nodes if n.node_id in node_ids] else: return self._nodes
def _process_nodepool(self, nodepool): return nodepool
[docs] def import_nodes(self, nodes_file_name, node_types_file_name, population=None): sonata_file = sonata.File(data_files=nodes_file_name, data_type_files=node_types_file_name) if sonata_file.nodes is None: raise Exception('nodes file {} does not have any nodes.'.format(nodes_file_name)) populations = sonata_file.nodes.populations if len(populations) == 1: node_pop = populations[0] elif population is None: raise Exception('The nodes file {} contains multiple populations.'.format(nodes_file_name) + 'Please specify population parameter.') else: for pop in populations: if pop.name == population: node_pop = pop break else: raise Exception('Nodes file {} does not contain population {}.'.format(nodes_file_name, population)) for node_type_props in node_pop.node_types_table: self._add_node_type(node_type_props) for node in node_pop: self._node_id_gen.remove_id(node.node_id) self._nodes.append(Node(node.node_id, node.group_props, node.node_type_properties))
def _add_edges(self, connection_map, i): syn_table = self.EdgeTable(connection_map) connections = connection_map.connection_itr() for con in connections: if con[2] is not None: syn_table[con[0], con[1]] = con[2] target_net = connection_map.target_nodes self._target_networks[target_net.network_name] = target_net.network nsyns = np.sum(syn_table.nsyn_table) self._nedges += int(nsyns) edge_table = {'syn_table': syn_table, 'nsyns': nsyns, 'edge_types': connection_map.edge_type_properties, 'edge_type_id': connection_map.edge_type_properties['edge_type_id'], 'source_network': connection_map.source_nodes.network_name, 'target_network': connection_map.target_nodes.network_name, 'params': {}, 'params_dtypes': {}, 'source_query': connection_map.source_nodes.filter_str, 'target_query': connection_map.target_nodes.filter_str} for param in connection_map.params: rule = param.rule param_names = param.names edge_table['params_dtypes'].update(param.dtypes) if isinstance(param_names, list) or isinstance(param_names, tuple): tmp_tables = [self.PropertyTable(nsyns, param.dtypes[n]) for n in param_names] # tmp_tables = [self.PropertyTable(nsyns) for _ in range(len(param_names))] for source in connection_map.source_nodes: src_node_id = source.node_id for target in connection_map.target_nodes: trg_node_id = target.node_id # TODO: pull this out and put in it's own list for _ in range(syn_table[src_node_id, trg_node_id]): pvals = rule(source, target) for i in range(len(param_names)): tmp_tables[i][src_node_id, trg_node_id] = pvals[i] for i, name in enumerate(param_names): # TODO: I think a copy constructor might get called, move this out. edge_table['params'][name] = tmp_tables[i] else: pt = self.PropertyTable(np.sum(nsyns), dtype=param.dtypes[param_names]) for source in connection_map.source_nodes: src_node_id = source.node_id for target in connection_map.target_nodes: trg_node_id = target.node_id # TODO: pull this out and put in it's own list #print('{}, {}: {}'.format(src_node_id, trg_node_id, edge_table[src_node_id, trg_node_id])) for _ in range(syn_table[src_node_id, trg_node_id]): pt[src_node_id, trg_node_id] = rule(source, target) edge_table['params'][param_names] = pt self.__edges_tables.append(edge_table) def _save_gap_junctions(self, gj_file_name): source_ids = [] target_ids = [] src_gap_ids = [] trg_gap_ids = [] for et in self.__edges_tables: try: is_gap = et['edge_types']['is_gap_junction'] except: continue if is_gap: if et['source_network'] != et['target_network']: raise Exception("All gap junctions must be two cells in the same network builder.") table = et['syn_table'] junc_table = table.nsyn_table locs = np.where(junc_table > 0) for i in range(len(locs[0])): source_ids.append(table.source_ids[locs[0][i]]) target_ids.append(table.target_ids[locs[1][i]]) src_gap_ids.append(self._gj_id_gen.next()) trg_gap_ids.append(self._gj_id_gen.next()) else: continue if len(source_ids) > 0: with h5py.File(gj_file_name, 'w') as f: add_hdf5_attrs(f) f.create_dataset('source_ids', data=np.array(source_ids)) f.create_dataset('target_ids', data=np.array(target_ids)) f.create_dataset('src_gap_ids', data=np.array(src_gap_ids)) f.create_dataset('trg_gap_ids', data=np.array(trg_gap_ids)) def _save_edges(self, edges_file_name, src_network, trg_network, name=None): groups = {} group_dtypes = {} # TODO: this should be stored in PropertyTable grp_id_itr = 0 groups_lookup = {} total_syns = 0 matching_edge_tables = [et for et in self.__edges_tables if et['source_network'] == src_network and et['target_network'] == trg_network] for ets in matching_edge_tables: params_hash = str(ets['params'].keys()) group_id = groups_lookup.get(params_hash, None) if group_id is None: group_id = grp_id_itr groups_lookup[params_hash] = group_id grp_id_itr += 1 ets['group_id'] = group_id groups[group_id] = {} group_dtypes[group_id] = ets['params_dtypes'] for param_name in ets['params'].keys(): groups[group_id][param_name] = [] total_syns += int(ets['nsyns']) group_index_itrs = [0 for _ in range(grp_id_itr)] trg_gids = np.zeros(total_syns) # set dtype to uint64 src_gids = np.zeros(total_syns) edge_groups = np.zeros(total_syns) # dtype uint16 or uint8 edge_group_index = np.zeros(total_syns) # uint32 edge_type_ids = np.zeros(total_syns) # uint32 # TODO: Another potential issue if node-ids don't start with 0 index_ptrs = np.zeros(len(self._target_networks[trg_network].nodes()) + 1) #index_ptrs = np.zeros(len(self._nodes)+1) # TODO: issue when target nodes come from another network index_ptr_itr = 0 gid_indx = 0 for trg_node in self._target_networks[trg_network].nodes(): index_ptrs[index_ptr_itr] = gid_indx index_ptr_itr += 1 for ets in matching_edge_tables: edge_group_id = ets['group_id'] group_table = groups[edge_group_id] syn_table = ets['syn_table'] if syn_table.has_target(trg_node.node_id): if ets['params']: for src_id, nsyns in syn_table.trg_itr(trg_node.node_id): # Add on to the edges index indx_end = gid_indx+nsyns while gid_indx < indx_end: trg_gids[gid_indx] = trg_node.node_id src_gids[gid_indx] = src_id edge_type_ids[gid_indx] = ets['edge_type_id'] edge_groups[gid_indx] = edge_group_id edge_group_index[gid_indx] = group_index_itrs[edge_group_id] group_index_itrs[edge_group_id] += 1 gid_indx += 1 for param_name, param_table in ets['params'].items(): param_vals = group_table[param_name] for val in param_table.itr_vals(src_id, trg_node.node_id): param_vals.append(val) else: # If no properties just print nsyns table. if 'nsyns' not in group_table: group_table['nsyns'] = [] group_dtypes[edge_group_id]['nsyns'] = 'uint16' for src_id, nsyns in syn_table.trg_itr(trg_node.node_id): trg_gids[gid_indx] = trg_node.node_id src_gids[gid_indx] = src_id edge_type_ids[gid_indx] = ets['edge_type_id'] edge_groups[gid_indx] = edge_group_id edge_group_index[gid_indx] = group_index_itrs[edge_group_id] group_index_itrs[edge_group_id] += 1 gid_indx += 1 group_table['nsyns'].append(nsyns) trg_gids = trg_gids[:gid_indx] src_gids = src_gids[:gid_indx] edge_groups = edge_groups[:gid_indx] edge_group_index = edge_group_index[:gid_indx] edge_type_ids = edge_type_ids[:gid_indx] pop_name = '{}_to_{}'.format(src_network, trg_network) if name is None else name index_ptrs[index_ptr_itr] = gid_indx with h5py.File(edges_file_name, 'w') as hf: add_hdf5_attrs(hf) pop_grp = hf.create_group('/edges/{}'.format(pop_name)) pop_grp.create_dataset('target_node_id', data=trg_gids, dtype='uint64') pop_grp['target_node_id'].attrs['node_population'] = trg_network pop_grp.create_dataset('source_node_id', data=src_gids, dtype='uint64') pop_grp['source_node_id'].attrs['node_population'] = src_network pop_grp.create_dataset('edge_group_id', data=edge_groups, dtype='uint16') pop_grp.create_dataset('edge_group_index', data=edge_group_index, dtype='uint32') pop_grp.create_dataset('edge_type_id', data=edge_type_ids, dtype='uint32') # pop_grp.create_dataset('edges/index_pointer', data=index_ptrs, dtype='uint32') for group_id, params_dict in groups.items(): model_grp = pop_grp.create_group(str(group_id)) for params_key, params_vals in params_dict.items(): #group_path = 'edges/{}/{}'.format(group_id, params_key) dtype = group_dtypes[group_id][params_key] if dtype is not None: model_grp.create_dataset(params_key, data=list(params_vals), dtype=dtype) else: model_grp.create_dataset(params_key, data=list(params_vals)) self._create_index(pop_grp['target_node_id'], pop_grp, index_type='target') self._create_index(pop_grp['source_node_id'], pop_grp, index_type='source') def _create_index(self, node_ids_ds, output_grp, index_type='target'): if index_type == 'target': edge_nodes = np.array(node_ids_ds, dtype=np.int64) output_grp = output_grp.create_group('indices/target_to_source') elif index_type == 'source': edge_nodes = np.array(node_ids_ds, dtype=np.int64) output_grp = output_grp.create_group('indices/source_to_target') edge_nodes = np.append(edge_nodes, [-1]) n_targets = np.max(edge_nodes) ranges_list = [[] for _ in six.moves.range(n_targets + 1)] n_ranges = 0 begin_index = 0 cur_trg = edge_nodes[begin_index] for end_index, trg_gid in enumerate(edge_nodes): if cur_trg != trg_gid: ranges_list[cur_trg].append((begin_index, end_index)) cur_trg = int(trg_gid) begin_index = end_index n_ranges += 1 node_id_to_range = np.zeros((n_targets + 1, 2)) range_to_edge_id = np.zeros((n_ranges, 2)) range_index = 0 for node_index, trg_ranges in enumerate(ranges_list): if len(trg_ranges) > 0: node_id_to_range[node_index, 0] = range_index for r in trg_ranges: range_to_edge_id[range_index, :] = r range_index += 1 node_id_to_range[node_index, 1] = range_index output_grp.create_dataset('range_to_edge_id', data=range_to_edge_id, dtype='uint64') output_grp.create_dataset('node_id_to_range', data=node_id_to_range, dtype='uint64') def _clear(self): self._nedges = 0 self._nnodes = 0
[docs] def edges_iter(self, trg_gids, src_network=None, trg_network=None): matching_edge_tables = self.__edges_tables if trg_network is not None: matching_edge_tables = [et for et in self.__edges_tables if et['target_network'] == trg_network] if src_network is not None: matching_edge_tables = [et for et in matching_edge_tables if et['source_network'] == src_network] for trg_gid in trg_gids: for ets in matching_edge_tables: syn_table = ets['syn_table'] if syn_table.has_target(trg_gid): for src_id, nsyns in syn_table.trg_itr(trg_gid): if ets['params']: synapses = [{} for _ in range(nsyns)] for param_name, param_table in ets['params'].items(): for i, val in enumerate(param_table[src_id, trg_gid]): synapses[i][param_name] = val for syn_prop in synapses: yield Edge(src_gid=src_id, trg_gid=trg_gid, edge_type_props=ets['edge_types'], syn_props=syn_prop) else: yield Edge(src_gid=src_id, trg_gid=trg_gid, edge_type_props=ets['edge_types'], syn_props={'nsyns': nsyns})
@property def nnodes(self): if not self.nodes_built: return 0 return self._nnodes @property def nedges(self): return self._nedges
[docs] class EdgeTable(object): def __init__(self, connection_map): # TODO: save column and row lengths # Create maps between source_node gids and their row in the matrix. self.__idx2src = [n.node_id for n in connection_map.source_nodes] self.__src2idx = {node_id: i for i, node_id in enumerate(self.__idx2src)} # Create maps betwee target_node gids and their column in the matrix self.__idx2trg = [n.node_id for n in connection_map.target_nodes] self.__trg2idx = {node_id: i for i, node_id in enumerate(self.__idx2trg)} self._nsyn_table = np.zeros((len(self.__idx2src), len(self.__idx2trg)), dtype=np.uint8) def __getitem__(self, item): # TODO: make sure matrix is column oriented, or swithc trg and srcs. indexed_pair = (self.__src2idx[item[0]], self.__trg2idx[item[1]]) return self._nsyn_table[indexed_pair] def __setitem__(self, key, value): assert(len(key) == 2) indexed_pair = (self.__src2idx[key[0]], self.__trg2idx[key[1]]) self._nsyn_table[indexed_pair] = value
[docs] def has_target(self, node_id): return node_id in self.__trg2idx
@property def nsyn_table(self): return self._nsyn_table @property def target_ids(self): return self.__idx2trg @property def source_ids(self): return self.__idx2src
[docs] def trg_itr(self, trg_id): trg_i = self.__trg2idx[trg_id] for src_j, src_id in enumerate(self.__idx2src): nsyns = self._nsyn_table[src_j, trg_i] if nsyns: yield src_id, nsyns
[docs] class PropertyTable(object): # TODO: add support for strings def __init__(self, nvalues, dtype): self._prop_array = np.zeros(nvalues, dtype=dtype) # self._prop_table = np.zeros((nvalues, 1)) # TODO: set dtype self._index = np.zeros((nvalues, 2), dtype=np.uint32) self._itr_index = 0
[docs] def itr_vals(self, src_id, trg_id): indicies = np.where((self._index[:, 0] == src_id) & (self._index[:, 1] == trg_id)) for val in self._prop_array[indicies]: yield val
def __setitem__(self, key, value): self._index[self._itr_index, 0] = key[0] # src_node_id self._index[self._itr_index, 1] = key[1] # trg_node_id self._prop_array[self._itr_index] = value self._itr_index += 1 def __getitem__(self, item): indicies = np.where((self._index[:, 0] == item[0]) & (self._index[:, 1] == item[1])) return self._prop_array[indicies]
[docs]def add_hdf5_attrs(hdf5_handle): # TODO: move this as a utility function hdf5_handle['/'].attrs['magic'] = np.uint32(0x0A7A) hdf5_handle['/'].attrs['version'] = [np.uint32(0), np.uint32(1)]