import os
import csv
import h5py
import numpy as np
from neuron import h
from glob import glob
from itertools import product
from .sim_module import SimulatorMod
from bmtk.simulator.bionet.biocell import BioCell
from bmtk.simulator.bionet.io_tools import io
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version
from bmtk.simulator.bionet.pointprocesscell import PointProcessCell
# NEURON's ParallelContext exposes MPI rank/size; used below to coordinate
# directory creation, per-rank temp files, and the final rank-0 merge step.
pc = h.ParallelContext()
MPI_RANK = int(pc.id())  # rank of this process (0 when running serially)
N_HOSTS = int(pc.nhost())  # total number of MPI processes
class SaveSynapses(SimulatorMod):
    """Simulation module that records every synaptic connection (NetCon) of the
    local cells into SONATA-format edge files under ``network_dir``.

    With MPI each rank writes its own temporary ``.core<rank>.*_edges.h5`` file
    through a ConnectionWriter; rank 0 then merges them into one file per
    (source, target) population pair using H5Merger.
    """
    def __init__(self, network_dir, single_file=False, **params):
        self._network_dir = network_dir
        self._virt_lookup = {}  # nc.pre() hoc object -> (pop_name, node_id), virtual sources
        self._gid_lookup = {}   # gid -> (pop_name, node_id), non-virtual nodes
        self._sec_lookup = {}   # gid -> {hoc section -> section id}, biophysical cells

        # Only rank 0 creates the output directory; all ranks wait for it.
        if MPI_RANK == 0:
            if not os.path.exists(network_dir):
                os.makedirs(network_dir)
        pc.barrier()

        self._syn_writer = ConnectionWriter(network_dir)

    def _print_nc(self, nc, src_nid, trg_nid, cell, src_pop, trg_pop, edge_type_id):
        """Write one NetCon as an edge record, dispatching on the target cell type."""
        if isinstance(cell, BioCell):
            # postloc() returns the segment position and pushes the postsynaptic
            # section onto the hoc section stack; h.cas() reads it back.
            sec_x = nc.postloc()
            sec = h.cas()
            sec_id = self._sec_lookup[cell.gid][sec]
            h.pop_section()  # balance the push done by postloc()
            self._syn_writer.add_bio_conn(edge_type_id, src_nid, src_pop, trg_nid, trg_pop,
                                          nc.weight[0], sec_id, sec_x)
        else:
            self._syn_writer.add_point_conn(edge_type_id, src_nid, src_pop, trg_nid, trg_pop,
                                            nc.weight[0])

    def initialize(self, sim):
        """Build the lookup tables, dump every local connection, and merge files.

        Despite the name this does all the work at initialization time, before
        the simulation actually runs.
        """
        io.log_info('Saving network connections. This may take a while.')

        # Need a way to look up virtual nodes from nc.pre()
        for pop_name, nodes_table in sim.net._virtual_nodes.items():
            for node_id, virt_node in nodes_table.items():
                self._virt_lookup[virt_node.hobj] = (pop_name, node_id)

        # Need to figure out node_id and pop_name from nc.srcgid()
        for node_pop in sim.net.node_populations:
            pop_name = node_pop.name
            for node in node_pop[0::1]:
                if node.model_type != 'virtual':
                    self._gid_lookup[node.gid] = (pop_name, node.node_id)

        for gid, cell in sim.net.get_local_cells().items():
            trg_pop, trg_id = self._gid_lookup[gid]
            if isinstance(cell, BioCell):
                # Map each hoc section object back to its integer section id so
                # _print_nc can record where on the morphology a synapse lands.
                self._sec_lookup[gid] = {sec_name: sec_id
                                         for sec_id, sec_name in enumerate(cell.get_sections_id())}

            for nc, edge_type_id in zip(cell.netcons, cell._edge_type_ids):
                src_gid = int(nc.srcgid())
                if src_gid == -1:
                    # source is a virtual node
                    src_pop, src_id = self._virt_lookup[nc.pre()]
                else:
                    src_pop, src_id = self._gid_lookup[src_gid]

                self._print_nc(nc, src_id, trg_id, cell, src_pop, trg_pop, edge_type_id)

        self._syn_writer.close()
        pc.barrier()

        # With MPI, rank 0 merges the per-rank temp files; H5Merger does all of
        # its work in __init__, so no binding is needed.
        if N_HOSTS > 1 and MPI_RANK == 0:
            H5Merger(self._network_dir, self._syn_writer._pop_groups.keys())
        pc.barrier()

        io.log_info('  Done saving network connections.')
class H5Merger(object):
    """Merge the per-rank temporary edge files produced by ConnectionWriter
    (``.core<rank>.<src>_<trg>_edges.h5``) into one SONATA edges file per
    (source, target) population pair.

    Intended to run on rank 0 only, after all ranks have finished writing
    (callers barrier around it). All work happens in ``__init__``; the
    temporary files are deleted after their contents are copied.
    """
    def __init__(self, network_dir, grp_keys):
        self._network_dir = network_dir
        self._grp_keys = list(grp_keys)  # list of (src_pop, trg_pop) tuples

        # Totals per population pair, summed across every rank's temp file.
        self._edge_counts = {(s, t): 0 for s, t in self._grp_keys}
        self._biophys_edge_count = {(s, t): 0 for s, t in self._grp_keys}
        self._point_edge_count = {(s, t): 0 for s, t in self._grp_keys}
        self._tmp_files = {(s, t): [] for s, t in self._grp_keys}

        # First pass: open each rank's temp file and tally the dataset sizes so
        # the merged datasets can be allocated at their final size up front.
        for (src_pop, trg_pop), r in product(self._grp_keys, range(N_HOSTS)):
            fname = '.core{}.{}_{}_edges.h5'.format(r, src_pop, trg_pop)
            fpath = os.path.join(self._network_dir, fname)
            if not os.path.exists(fpath):
                # A rank may have had no edges for this population pair; warn
                # and skip it (previously this fell through and crashed trying
                # to open the missing file).
                io.log_warning('Expected file {} is missing'.format(fpath))
                continue

            h5file = h5py.File(fpath, 'r')
            edges_grp = h5file['/edges/{}_{}'.format(src_pop, trg_pop)]
            self._tmp_files[(src_pop, trg_pop)].append(edges_grp)

            self._edge_counts[(src_pop, trg_pop)] += len(edges_grp['source_node_id'])
            self._biophys_edge_count[(src_pop, trg_pop)] += len(edges_grp['0/syn_weight'])
            self._point_edge_count[(src_pop, trg_pop)] += len(edges_grp['1/syn_weight'])

        # Second pass: create the merged file and copy each rank's data into
        # contiguous slices, tracking separate offsets for the shared datasets
        # and for each model group (0 = biophysical, 1 = point process).
        for (src_pop, trg_pop), in_grps in self._tmp_files.items():
            out_h5 = h5py.File(os.path.join(self._network_dir,
                                            '{}_{}_edges.h5'.format(src_pop, trg_pop)), 'w')
            add_hdf5_magic(out_h5)
            add_hdf5_version(out_h5)
            pop_root = out_h5.create_group('/edges/{}_{}'.format(src_pop, trg_pop))

            n_edges_total = self._edge_counts[(src_pop, trg_pop)]
            n_edges_bio = self._biophys_edge_count[(src_pop, trg_pop)]
            n_edges_point = self._point_edge_count[(src_pop, trg_pop)]

            pop_root.create_dataset('source_node_id', (n_edges_total, ), dtype=np.uint64)
            pop_root['source_node_id'].attrs['node_population'] = src_pop
            pop_root.create_dataset('target_node_id', (n_edges_total, ), dtype=np.uint64)
            pop_root['target_node_id'].attrs['node_population'] = trg_pop
            pop_root.create_dataset('edge_group_id', (n_edges_total, ), dtype=np.uint16)
            pop_root.create_dataset('edge_group_index', (n_edges_total, ), dtype=np.uint16)
            pop_root.create_dataset('edge_type_id', (n_edges_total, ), dtype=np.uint32)
            pop_root.create_dataset('0/syn_weight', (n_edges_bio, ), dtype=float)
            pop_root.create_dataset('0/sec_id', (n_edges_bio, ), dtype=np.uint64)
            pop_root.create_dataset('0/sec_x', (n_edges_bio, ), dtype=float)
            pop_root.create_dataset('1/syn_weight', (n_edges_point, ), dtype=float)

            total_offset = 0
            bio_offset = 0
            point_offset = 0
            for grp in in_grps:
                n_ds = len(grp['source_node_id'])
                pop_root['source_node_id'][total_offset:(total_offset + n_ds)] = grp['source_node_id'][()]
                pop_root['target_node_id'][total_offset:(total_offset + n_ds)] = grp['target_node_id'][()]
                pop_root['edge_group_id'][total_offset:(total_offset + n_ds)] = grp['edge_group_id'][()]
                pop_root['edge_group_index'][total_offset:(total_offset + n_ds)] = grp['edge_group_index'][()]
                pop_root['edge_type_id'][total_offset:(total_offset + n_ds)] = grp['edge_type_id'][()]
                total_offset += n_ds

                n_ds = len(grp['0/syn_weight'])
                pop_root['0/syn_weight'][bio_offset:(bio_offset + n_ds)] = grp['0/syn_weight'][()]
                pop_root['0/sec_id'][bio_offset:(bio_offset + n_ds)] = grp['0/sec_id'][()]
                pop_root['0/sec_x'][bio_offset:(bio_offset + n_ds)] = grp['0/sec_x'][()]
                bio_offset += n_ds

                n_ds = len(grp['1/syn_weight'])
                pop_root['1/syn_weight'][point_offset:(point_offset + n_ds)] = grp['1/syn_weight'][()]
                point_offset += n_ds

                # Done with this rank's temp file; close and delete it.
                fname = grp.file.filename
                grp.file.close()
                if os.path.exists(fname):
                    os.remove(fname)

            self._create_index(pop_root, index_type='target')
            self._create_index(pop_root, index_type='source')
            out_h5.close()

    def _create_index(self, pop_root, index_type='target'):
        """Build a SONATA node-id index (node_id_to_range / range_to_edge_id)
        over the merged edges, keyed by either target or source node ids.

        NOTE(review): assumes edges are grouped (sorted) by the indexed node id;
        the range scan below only coalesces consecutive runs.
        """
        if index_type == 'target':
            edge_nodes = np.array(pop_root['target_node_id'], dtype=np.int64)
            output_grp = pop_root.create_group('indices/target_to_source')
        elif index_type == 'source':
            edge_nodes = np.array(pop_root['source_node_id'], dtype=np.int64)
            output_grp = pop_root.create_group('indices/source_to_target')
        else:
            # Previously an invalid value fell through to a NameError below.
            raise ValueError('Unknown index_type: {}'.format(index_type))

        # Append a -1 sentinel so the final run of ids is flushed by the loop.
        edge_nodes = np.append(edge_nodes, [-1])
        n_targets = int(np.max(edge_nodes))
        ranges_list = [[] for _ in range(n_targets + 1)]

        n_ranges = 0
        begin_index = 0
        cur_trg = edge_nodes[begin_index]
        for end_index, trg_gid in enumerate(edge_nodes):
            if cur_trg != trg_gid:
                ranges_list[cur_trg].append((begin_index, end_index))
                cur_trg = int(trg_gid)
                begin_index = end_index
                n_ranges += 1

        node_id_to_range = np.zeros((n_targets + 1, 2))
        range_to_edge_id = np.zeros((n_ranges, 2))
        range_index = 0
        for node_index, trg_ranges in enumerate(ranges_list):
            if len(trg_ranges) > 0:
                node_id_to_range[node_index, 0] = range_index
                for r in trg_ranges:
                    range_to_edge_id[range_index, :] = r
                    range_index += 1
                node_id_to_range[node_index, 1] = range_index

        output_grp.create_dataset('range_to_edge_id', data=range_to_edge_id, dtype='uint64')
        output_grp.create_dataset('node_id_to_range', data=node_id_to_range, dtype='uint64')
class ConnectionWriter(object):
    """Writes synaptic connections to SONATA edge h5 files, one file per
    (source population, target population) pair. Under MPI each rank writes
    to a hidden ``.core<rank>.*`` temp file that H5Merger later combines.
    """

    class H5Index(object):
        """Writer for a single population-pair edges file. Datasets grow in
        chunks of ``_block_size`` and are trimmed to their exact size by
        ``clean_ends()`` before the index groups are built.
        """
        def __init__(self, file_path, src_pop, trg_pop):
            # TODO: Merge with NetworkBuilder code for building SONATA files
            self._nsyns = 0        # total edges written (all groups)
            self._n_biosyns = 0    # edges in group 0 (biophysical)
            self._n_pointsyns = 0  # edges in group 1 (point process)
            self._block_size = 5   # resize increment / chunk size for every dataset

            self._pop_name = '{}_{}'.format(src_pop, trg_pop)
            self._h5_file = h5py.File(file_path, 'w')
            add_hdf5_magic(self._h5_file)
            add_hdf5_version(self._h5_file)
            self._pop_root = self._h5_file.create_group('/edges/{}'.format(self._pop_name))

            # Shared (per-edge) datasets, all resizable.
            self._pop_root.create_dataset('edge_group_id', (self._block_size, ), dtype=np.uint16,
                                          chunks=(self._block_size, ), maxshape=(None, ))
            self._pop_root.create_dataset('source_node_id', (self._block_size, ), dtype=np.uint64,
                                          chunks=(self._block_size, ), maxshape=(None, ))
            self._pop_root['source_node_id'].attrs['node_population'] = src_pop
            self._pop_root.create_dataset('target_node_id', (self._block_size, ), dtype=np.uint64,
                                          chunks=(self._block_size, ), maxshape=(None, ))
            self._pop_root['target_node_id'].attrs['node_population'] = trg_pop
            self._pop_root.create_dataset('edge_type_id', (self._block_size, ), dtype=np.uint32,
                                          chunks=(self._block_size, ), maxshape=(None, ))

            # Group 0: biophysical synapses (have a morphology location).
            self._pop_root.create_dataset('0/syn_weight', (self._block_size, ), dtype=float,
                                          chunks=(self._block_size, ), maxshape=(None, ))
            self._pop_root.create_dataset('0/sec_id', (self._block_size, ), dtype=np.uint64,
                                          chunks=(self._block_size, ), maxshape=(None, ))
            self._pop_root.create_dataset('0/sec_x', (self._block_size, ), chunks=(self._block_size, ),
                                          maxshape=(None, ), dtype=float)

            # Group 1: point-process synapses (weight only).
            self._pop_root.create_dataset('1/syn_weight', (self._block_size, ), dtype=float,
                                          chunks=(self._block_size, ), maxshape=(None, ))

        def _add_conn(self, edge_type_id, src_id, trg_id, grp_id):
            """Append the shared per-edge columns, growing datasets as needed."""
            self._pop_root['edge_type_id'][self._nsyns] = edge_type_id
            self._pop_root['source_node_id'][self._nsyns] = src_id
            self._pop_root['target_node_id'][self._nsyns] = trg_id
            self._pop_root['edge_group_id'][self._nsyns] = grp_id

            self._nsyns += 1
            if self._nsyns % self._block_size == 0:
                # Current capacity is exhausted; grow by one block.
                self._pop_root['edge_type_id'].resize((self._nsyns + self._block_size, ))
                self._pop_root['source_node_id'].resize((self._nsyns + self._block_size, ))
                self._pop_root['target_node_id'].resize((self._nsyns + self._block_size, ))
                self._pop_root['edge_group_id'].resize((self._nsyns + self._block_size, ))

        def add_bio_conn(self, edge_type_id, src_id, trg_id, syn_weight, sec_id, sec_x):
            """Append one biophysical (group 0) edge with its section location."""
            self._add_conn(edge_type_id, src_id, trg_id, 0)
            self._pop_root['0/syn_weight'][self._n_biosyns] = syn_weight
            self._pop_root['0/sec_id'][self._n_biosyns] = sec_id
            self._pop_root['0/sec_x'][self._n_biosyns] = sec_x

            self._n_biosyns += 1
            if self._n_biosyns % self._block_size == 0:
                self._pop_root['0/syn_weight'].resize((self._n_biosyns + self._block_size, ))
                self._pop_root['0/sec_id'].resize((self._n_biosyns + self._block_size, ))
                self._pop_root['0/sec_x'].resize((self._n_biosyns + self._block_size, ))

        def add_point_conn(self, edge_type_id, src_id, trg_id, syn_weight):
            """Append one point-process (group 1) edge."""
            self._add_conn(edge_type_id, src_id, trg_id, 1)
            self._pop_root['1/syn_weight'][self._n_pointsyns] = syn_weight

            self._n_pointsyns += 1
            if self._n_pointsyns % self._block_size == 0:
                self._pop_root['1/syn_weight'].resize((self._n_pointsyns + self._block_size, ))

        def clean_ends(self):
            """Trim datasets to their final sizes, compute edge_group_index,
            and build the source/target indices.
            """
            self._pop_root['source_node_id'].resize((self._nsyns, ))
            self._pop_root['target_node_id'].resize((self._nsyns, ))
            self._pop_root['edge_group_id'].resize((self._nsyns, ))
            self._pop_root['edge_type_id'].resize((self._nsyns, ))
            self._pop_root['0/syn_weight'].resize((self._n_biosyns, ))
            self._pop_root['0/sec_id'].resize((self._n_biosyns, ))
            self._pop_root['0/sec_x'].resize((self._n_biosyns, ))
            self._pop_root['1/syn_weight'].resize((self._n_pointsyns, ))

            # edge_group_index: each edge's position within its own group.
            # Read group ids into memory once and write the index in a single
            # assignment (previously one HDF5 round-trip per edge).
            grp_ids = self._pop_root['edge_group_id'][()]
            eg_index = np.zeros(self._nsyns, dtype=np.uint64)
            bio_count, point_count = 0, 0
            for idx, grp_id in enumerate(grp_ids):
                if grp_id == 0:
                    eg_index[idx] = bio_count
                    bio_count += 1
                elif grp_id == 1:
                    eg_index[idx] = point_count
                    point_count += 1
            self._pop_root.create_dataset('edge_group_index', data=eg_index, dtype=np.uint64)

            self._create_index('target')
            self._create_index('source')

        def _create_index(self, index_type='target'):
            """Build a SONATA node-id index (node_id_to_range / range_to_edge_id).

            NOTE(review): assumes edges are grouped (sorted) by the indexed node
            id; the range scan below only coalesces consecutive runs.
            """
            if index_type == 'target':
                edge_nodes = np.array(self._pop_root['target_node_id'], dtype=np.int64)
                output_grp = self._pop_root.create_group('indices/target_to_source')
            elif index_type == 'source':
                edge_nodes = np.array(self._pop_root['source_node_id'], dtype=np.int64)
                output_grp = self._pop_root.create_group('indices/source_to_target')
            else:
                # Previously an invalid value fell through to a NameError below.
                raise ValueError('Unknown index_type: {}'.format(index_type))

            # Append a -1 sentinel so the final run of ids is flushed by the loop.
            edge_nodes = np.append(edge_nodes, [-1])
            n_targets = int(np.max(edge_nodes))
            ranges_list = [[] for _ in range(n_targets + 1)]

            n_ranges = 0
            begin_index = 0
            cur_trg = edge_nodes[begin_index]
            for end_index, trg_gid in enumerate(edge_nodes):
                if cur_trg != trg_gid:
                    ranges_list[cur_trg].append((begin_index, end_index))
                    cur_trg = int(trg_gid)
                    begin_index = end_index
                    n_ranges += 1

            node_id_to_range = np.zeros((n_targets + 1, 2))
            range_to_edge_id = np.zeros((n_ranges, 2))
            range_index = 0
            for node_index, trg_ranges in enumerate(ranges_list):
                if len(trg_ranges) > 0:
                    node_id_to_range[node_index, 0] = range_index
                    for r in trg_ranges:
                        range_to_edge_id[range_index, :] = r
                        range_index += 1
                    node_id_to_range[node_index, 1] = range_index

            output_grp.create_dataset('range_to_edge_id', data=range_to_edge_id, dtype='uint64')
            output_grp.create_dataset('node_id_to_range', data=node_id_to_range, dtype='uint64')

        def close_h5(self):
            self._h5_file.close()

    def __init__(self, network_dir):
        self._network_dir = network_dir
        self._pop_groups = {}  # (src_pop, trg_pop) -> H5Index

    def _group_key(self, src_pop, trg_pop):
        return (src_pop, trg_pop)

    def _get_edge_group(self, src_pop, trg_pop):
        """Return (creating on first use) the H5Index for a population pair."""
        grp_key = self._group_key(src_pop, trg_pop)
        if grp_key not in self._pop_groups:
            pop_name = '{}_{}'.format(src_pop, trg_pop)
            if N_HOSTS > 1:
                # Hidden per-rank temp file; merged later by H5Merger.
                pop_name = '.core{}.{}'.format(MPI_RANK, pop_name)
            file_path = os.path.join(self._network_dir, '{}_edges.h5'.format(pop_name))
            self._pop_groups[grp_key] = self.H5Index(file_path, src_pop, trg_pop)

        return self._pop_groups[grp_key]

    def add_bio_conn(self, edge_type_id, src_id, src_pop, trg_id, trg_pop, syn_weight, sec_id, sec_x):
        h5_grp = self._get_edge_group(src_pop, trg_pop)
        h5_grp.add_bio_conn(edge_type_id, src_id, trg_id, syn_weight, sec_id, sec_x)

    def add_point_conn(self, edge_type_id, src_id, src_pop, trg_id, trg_pop, syn_weight):
        h5_grp = self._get_edge_group(src_pop, trg_pop)
        h5_grp.add_point_conn(edge_type_id, src_id, trg_id, syn_weight)

    def close(self):
        """Finalize and close every open per-population edges file."""
        for _, h5index in self._pop_groups.items():
            h5index.clean_ends()
            h5index.close_h5()