Source code for bmtk.utils.io.cell_vars

import os
import h5py
import numpy as np

from bmtk.utils import io
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version


[docs]class CellVarRecorder(object): """Used to save cell membrane variables (V, Ca2+, etc) to the described hdf5 format. For parallel simulations this class will write to a seperate tmp file on each rank, then use the merge method to combine the results. This is less efficent, but doesn't require the user to install mpi4py and build h5py in parallel mode. For better performance use the CellVarRecorderParallel class instead. """ _io = io
[docs] class DataTable(object): """A small struct to keep track of different \*/data (and buffer) tables""" def __init__(self, var_name): self.var_name = var_name # If buffering data, buffer_block will be an in-memory array and will write to data_block during when # filled. If not buffering buffer_block is an hdf5 dataset and data_block is ignored self.data_block = None self.buffer_block = None
def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1, **kwargs): self._file_name = file_name self._h5_handle = None self._tmp_dir = tmp_dir self._variables = variables if isinstance(variables, list) else [variables] self._n_vars = len(self._variables) # Used later to keep track if more than one var is saved to the same file. self._mpi_rank = mpi_rank self._mpi_size = mpi_size self._tmp_files = [] self._saved_file = file_name self._population = kwargs.get('population', None) self._units = kwargs.get('units', None) if mpi_size > 1: self._io.log_warning('Was unable to run h5py in parallel (mpi) mode.' + ' Saving of membrane variable(s) may slow down.') tmp_fname = os.path.basename(file_name) # make sure file names don't clash if there are multiple reports self._tmp_files = [os.path.join(tmp_dir, '__bmtk_tmp_cellvars_{}_{}'.format(r, tmp_fname)) for r in range(self._mpi_size)] self._file_name = self._tmp_files[self._mpi_rank] self._mapping_gids = [] # list of gids in the order they appear in the data self._gid_map = {} # table for looking up the gid offsets self._map_attrs = {} # Used for additonal attributes in /mapping self._mapping_element_ids = [] # sections self._mapping_element_pos = [] # segments self._mapping_index = [0] # index_pointer self._buffer_data = buffer_data self._data_blocks = {var_name: self.DataTable(var_name) for var_name in self._variables} self._last_save_indx = 0 # for buffering, used to keep track of last timestep data was saved to disk self._buffer_block_size = 0 self._total_steps = 0 # Keep track of gids across the different ranks self._n_gids_all = 0 self._n_gids_local = 0 self._gids_beg = 0 self._gids_end = 0 # Keep track of segment counts across the different ranks self._n_segments_all = 0 self._n_segments_local = 0 self._seg_offset_beg = 0 self._seg_offset_end = 0 self._tstart = 0.0 self._tstop = 0.0 self._dt = 0.01 self._is_initialized = False @property def tstart(self): return self._tstart @tstart.setter def tstart(self, time_ms): self._tstart = time_ms @property def tstop(self): return self._tstop @tstop.setter def tstop(self, time_ms): self._tstop = time_ms @property def dt(self): return self._dt @dt.setter def dt(self, time_ms): self._dt = time_ms @property def is_initialized(self): return self._is_initialized def _calc_offset(self): self._n_segments_all = self._n_segments_local self._seg_offset_beg = 0 self._seg_offset_end = self._n_segments_local self._n_gids_all = self._n_gids_local self._gids_beg = 0 self._gids_end = self._n_gids_local def _create_h5_file(self): self._h5_handle = h5py.File(self._file_name, 'w') add_hdf5_version(self._h5_handle) add_hdf5_magic(self._h5_handle)
[docs] def add_cell(self, gid, sec_list, seg_list, **map_attrs): assert(len(sec_list) == len(seg_list)) # TODO: Check the same gid isn't added twice n_segs = len(seg_list) self._gid_map[gid] = (self._n_segments_local, self._n_segments_local + n_segs) self._mapping_gids.append(gid) self._mapping_element_ids.extend(sec_list) self._mapping_element_pos.extend(seg_list) self._mapping_index.append(self._mapping_index[-1] + n_segs) self._n_segments_local += n_segs self._n_gids_local += 1 for k, v in map_attrs.items(): if k not in self._map_attrs: self._map_attrs[k] = v else: self._map_attrs[k].extend(v)
[docs] def initialize(self, n_steps, buffer_size=0): self._calc_offset() self._create_h5_file() #root_name = '/report/{}'.format(self._population) if self._population is not None else '/' var_grp = self._h5_handle.create_group('/mapping') # TODO: gids --> node_ids var_grp.create_dataset('gids', shape=(self._n_gids_all,), dtype=np.uint) # TODO: element_id --> element_ids var_grp.create_dataset('element_id', shape=(self._n_segments_all,), dtype=np.uint) var_grp.create_dataset('element_pos', shape=(self._n_segments_all,), dtype=np.float) var_grp.create_dataset('index_pointer', shape=(self._n_gids_all+1,), dtype=np.uint64) var_grp.create_dataset('time', data=[self.tstart, self.tstop, self.dt]) # TODO: Let user determine value var_grp['time'].attrs['units'] = 'ms' for k, v in self._map_attrs.items(): var_grp.create_dataset(k, shape=(self._n_segments_all,), dtype=type(v[0])) var_grp['gids'][self._gids_beg:self._gids_end] = self._mapping_gids var_grp['element_id'][self._seg_offset_beg:self._seg_offset_end] = self._mapping_element_ids var_grp['element_pos'][self._seg_offset_beg:self._seg_offset_end] = self._mapping_element_pos var_grp['index_pointer'][self._gids_beg:(self._gids_end+1)] = self._mapping_index for k, v in self._map_attrs.items(): var_grp[k][self._seg_offset_beg:self._seg_offset_end] = v self._total_steps = n_steps self._buffer_block_size = buffer_size if not self._buffer_data: # If data is not being buffered and instead written to the main block, we have to add a rank offset # to the gid offset for gid, gid_offset in self._gid_map.items(): self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg) for var_name, data_tables in self._data_blocks.items(): # If users are trying to save multiple variables in the same file put data table in its own /{var} group # (not sonata compliant). Otherwise the data table is located at the root data_grp = self._h5_handle if self._n_vars == 1 else self._h5_handle.create_group('/{}'.format(var_name)) if self._buffer_data: # Set up in-memory block to buffer recorded variables before writing to the dataset data_tables.buffer_block = np.zeros((buffer_size, self._n_segments_local), dtype=np.float) data_tables.data_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all), dtype=np.float, chunks=True) # TODO: Remove Variable name data_tables.data_block.attrs['variable_name'] = var_name if self._units is not None: data_tables.data_block.attrs['units'] = self._units else: # Since we are not buffering data, we just write directly to the on-disk dataset data_tables.buffer_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all), dtype=np.float, chunks=True) data_tables.buffer_block.attrs['variable_name'] = var_name if self._units is not None: data_tables.buffer_block.attrs['units'] = self._units self._is_initialized = True
[docs] def record_cell(self, gid, var_name, seg_vals, tstep): """Record cell parameters. :param gid: gid of cell. :param var_name: name of variable being recorded. :param seg_vals: list of all segment values :param tstep: time step """ gid_beg, gid_end = self._gid_map[gid] buffer_block = self._data_blocks[var_name].buffer_block update_index = (tstep - self._last_save_indx) buffer_block[update_index, gid_beg:gid_end] = seg_vals
[docs] def record_cell_block(self, gid, var_name, seg_vals): """Save cell parameters one block at a time :param gid: gid of cell. :param var_name: name of variable being recorded. :param seg_vals: A vector/matrix of values being recorded """ gid_beg, gid_end = self._gid_map[gid] buffer_block = self._data_blocks[var_name].buffer_block if gid_end - gid_beg == 1: buffer_block[:, gid_beg] = seg_vals else: buffer_block[:, gid_beg:gid_end] = seg_vals
[docs] def flush(self): """Move data from memory to dataset""" if self._buffer_data: blk_beg = self._last_save_indx blk_end = blk_beg + self._buffer_block_size if blk_end > self._total_steps: # Need to handle the case that simulation doesn't end on a block step blk_end = blk_beg + self._total_steps - blk_beg block_size = blk_end - blk_beg self._last_save_indx += block_size for _, data_table in self._data_blocks.items(): data_table.data_block[blk_beg:blk_end, :] = data_table.buffer_block[:block_size, :]
[docs] def close(self): self._h5_handle.close()
[docs] def merge(self): if self._mpi_size > 1 and self._mpi_rank == 0: h5final = h5py.File(self._saved_file, 'w') tmp_h5_handles = [h5py.File(name, 'r') for name in self._tmp_files] # Find the gid and segment offsets for each temp h5 file gid_ranges = [] # list of (gid-beg, gid-end) gid_offset = 0 total_gid_count = 0 # total number of gids across all ranks seg_ranges = [] seg_offset = 0 total_seg_count = 0 # total number of segments across all ranks time_ds = None for h5_tmp in tmp_h5_handles: seg_count = len(h5_tmp['/mapping/element_pos']) seg_ranges.append((seg_offset, seg_offset+seg_count)) seg_offset += seg_count total_seg_count += seg_count gid_count = len(h5_tmp['mapping/gids']) gid_ranges.append((gid_offset, gid_offset+gid_count)) gid_offset += gid_count total_gid_count += gid_count time_ds = h5_tmp['mapping/time'] mapping_grp = h5final.create_group('mapping') if time_ds: mapping_grp.create_dataset('time', data=time_ds) element_id_ds = mapping_grp.create_dataset('element_id', shape=(total_seg_count,), dtype=np.uint) el_pos_ds = mapping_grp.create_dataset('element_pos', shape=(total_seg_count,), dtype=np.float) gids_ds = mapping_grp.create_dataset('gids', shape=(total_gid_count,), dtype=np.uint) index_pointer_ds = mapping_grp.create_dataset('index_pointer', shape=(total_gid_count+1,), dtype=np.uint) for k, v in self._map_attrs.items(): mapping_grp.create_dataset(k, shape=(total_seg_count,), dtype=type(v[0])) # combine the /mapping datasets for i, h5_tmp in enumerate(tmp_h5_handles): tmp_mapping_grp = h5_tmp['mapping'] beg, end = seg_ranges[i] element_id_ds[beg:end] = tmp_mapping_grp['element_id'] el_pos_ds[beg:end] = tmp_mapping_grp['element_pos'] for k, v in self._map_attrs.items(): mapping_grp[k][beg:end] = v # shift the index pointer values index_pointer = np.array(tmp_mapping_grp['index_pointer']) update_index = beg + index_pointer beg, end = gid_ranges[i] gids_ds[beg:end] = tmp_mapping_grp['gids'] index_pointer_ds[beg:(end+1)] = update_index # combine the /var/data datasets for var_name in self._variables: data_name = '/data' if self._n_vars == 1 else '/{}/data'.format(var_name) # data_name = '/{}/data'.format(var_name) var_data = h5final.create_dataset(data_name, shape=(self._total_steps, total_seg_count), dtype=np.float) var_data.attrs['variable_name'] = var_name for i, h5_tmp in enumerate(tmp_h5_handles): beg, end = seg_ranges[i] var_data[:, beg:end] = h5_tmp[data_name] for tmp_file in self._tmp_files: os.remove(tmp_file)
[docs]class CellVarRecorderParallel(CellVarRecorder): """ Unlike the parent, this take advantage of parallel h5py to writting to the results file across different ranks. """ def __init__(self, file_name, tmp_dir, variables, buffer_data=True, **kwargs): super(CellVarRecorder, self).__init__(file_name, tmp_dir, variables, buffer_data=buffer_data, mpi_rank=0, mpi_size=1, **kwargs) def _calc_offset(self): from mpi4py import MPI # Just needed for UNSIGNED_INT comm = io.bmtk_world_comm.comm rank = comm.Get_rank() nhosts = comm.Get_size() # iterate through the ranks let rank r determine the offset from rank r-1 for r in range(comm.Get_size()): if rank == r: if rank < (nhosts - 1): # pass the num of segments and num of gids to the next rank offsets = np.array([self._n_segments_local, self._n_gids_local], dtype=np.uint) comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1)) if rank > 0: # get num of segments and gids from prev. rank and calculate offsets offset = np.empty(2, dtype=np.uint) comm.Recv([offsets, MPI.UNSIGNED_INT], source=(r-1)) self._seg_offset_beg = offsets[0] self._seg_offset_end = self._seg_offset_beg + self._n_segments_local self._gids_beg = offset[1] self._gids_end = self._gids_beg + self._n_gids_local comm.Barrier() # broadcast the total num of gids/segments from the final rank to all the others if rank == (nhosts - 1): total_counts = np.array([self._seg_offset_end, self._gids_end], dtype=np.uint) else: total_counts = np.empty(2, dtype=np.uint) comm.Bcast(total_counts, root=(nhosts-1)) self._n_segments_all = total_counts[0] self._n_gids_all = total_counts[1] def _create_h5_file(self): self._h5_handle = h5py.File(self._file_name, 'w', driver='mpio', comm=io.bmtk_world_comm.comm) add_hdf5_version(self._h5_handle) add_hdf5_magic(self._h5_handle)
[docs] def merge(self): pass