import os
import h5py
import numpy as np
from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version
from .compartment_reader import CompartmentReaderVer01 as CompartmentReader
from .core import CompartmentWriterABC
# Optional MPI support: when mpi4py can be imported and the communicator queried, use the real rank/size/barrier;
# any failure along the way falls back to a single-process (serial) configuration.
try:
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank, nhosts = comm.Get_rank(), comm.Get_size()
    barrier = comm.Barrier
except Exception:
    # Serial run: this process is the only rank, and synchronizing is a no-op.
    rank, nhosts = 0, 1

    def barrier():
        """Serial stand-in for ``comm.Barrier``; does nothing."""
        return None
class PopulationWriterv01(CompartmentWriterABC, CompartmentReader):
    """Used to save cell membrane variables (V, Ca2+, etc) of one node population to the described hdf5 format.

    The hdf5 file itself is owned by the parent writer (see h5_base); this class builds the /mapping tables and the
    /data matrix inside its population group. For parallel simulations each rank writes to a separate tmp file and
    the parent's merge method combines the results. This is less efficient, but doesn't require the user to install
    mpi4py and build h5py in parallel mode. For better performance use the CellVarRecorderParrallel class instead.
    """

    class DataTable(object):
        """A small struct to keep track of different data (and buffer) tables."""

        def __init__(self, var_name):
            self.var_name = var_name
            # If buffering data, buffer_block will be an in-memory array that is written to data_block when
            # filled. If not buffering, buffer_block is the hdf5 dataset itself and data_block is ignored.
            self.data_block = None
            self.buffer_block = None
            # [beg, end) range of time steps currently held by buffer_block.
            self.block_window = (0.0, 10e20)

    def __init__(self, parent, population, variable=None, units=None, tstart=0.0, tstop=1.0, dt=0.01, n_steps=None,
                 buffer_size=0, **kwargs):
        """
        :param parent: writer that owns the hdf5 file; must expose a ``report_group`` attribute.
        :param population: name of the node population (used as the hdf5 group name).
        :param variable: name of the recorded variable, saved as the 'variable' attribute of /data.
        :param units: units of the recorded variable, saved as the 'units' attribute of /data.
        :param tstart: start time of the recording.
        :param tstop: stop time of the recording.
        :param dt: time step; used with tstart/tstop to derive n_steps when it is not given.
        :param n_steps: number of recorded time steps (optional).
        :param buffer_size: number of time steps kept in memory before writing to disk; 0 disables buffering.
        """
        self._h5_base = None
        self._parent = parent
        self._population = population
        self._variable = variable
        self._units = units
        self._tstart = tstart
        self._tstop = tstop
        self._dt = dt
        self._n_steps = n_steps  # computed lazily by n_steps() when None

        self._tmp_files = []
        self._mapping_gids = []  # list of gids in the order they appear in the data
        self._gid_map = {}  # table for looking up the gid offsets: node_id -> (column-beg, column-end)
        self._element_data = {}  # Used for additional attributes in /mapping
        self._mapping_element_ids = []  # sections
        self._mapping_element_pos = []  # segments
        self._mapping_index = [0]  # index_pointer

        self._buffer_size = buffer_size
        self._buffer_data = buffer_size > 0
        self._data_block = self.DataTable(self._variable)
        self._buffer_block_size = 0
        self._total_steps = 0

        # Keep track of gids across the different ranks
        self._n_gids_all = 0
        self._n_gids_local = 0
        self._gids_beg = 0
        self._gids_end = 0

        # Keep track of segment counts across the different ranks
        self._n_segments_all = 0
        self._n_segments_local = 0
        self._seg_offset_beg = 0
        self._seg_offset_end = 0

        self._is_initialized = False

    @property
    def h5_base(self):
        """hdf5 group of this population, created lazily under the parent's report group."""
        if self._h5_base is None:
            self._h5_base = self._parent.report_group.create_group(self._population)
        return self._h5_base

    def _calc_offset(self):
        # Each rank writes its own file, so the local tables are the whole table: offsets start at zero.
        self._n_segments_all = self._n_segments_local
        self._seg_offset_beg = 0
        self._seg_offset_end = self._n_segments_local

        self._n_gids_all = self._n_gids_local
        self._gids_beg = 0
        self._gids_end = self._n_gids_local

    # Simple accessors. The ``population`` argument is accepted for interface compatibility and ignored.
    def set_units(self, val, population=None):
        self._units = val

    def units(self, population=None):
        return self._units

    def set_variable(self, val, population=None):
        self._variable = val

    def variable(self, population=None):
        return self._variable

    def set_tstart(self, val, population=None):
        self._tstart = val

    def tstart(self, population=None):
        return self._tstart

    def set_tstop(self, val, population=None):
        self._tstop = val

    def tstop(self, population=None):
        return self._tstop

    def set_dt(self, val, population=None):
        self._dt = val

    def dt(self, population=None):
        return self._dt

    def n_steps(self, population=None):
        """Number of time steps; derived (and cached) from tstart/tstop/dt when not given explicitly."""
        if self._n_steps is None:
            self._n_steps = int((self._tstop - self._tstart) / self._dt)
        return self._n_steps

    def set_time_trace(self, val, population=None):
        raise NotImplementedError()

    def time_trace(self, population=None):
        raise NotImplementedError()

    def add_cell(self, node_id, element_ids, element_pos, **map_attrs):
        """Register a cell and its recorded segments in the /mapping tables.

        :param node_id: id of the cell/node.
        :param element_ids: section ids of the recorded segments (one entry per segment).
        :param element_pos: positions of the recorded segments; must be the same length as element_ids.
        :param map_attrs: extra per-segment columns saved under /mapping (column name -> sequence of values).
        """
        assert(len(element_ids) == len(element_pos))

        # TODO: Check the same gid isn't added twice
        n_segs = len(element_pos)
        # Columns [beg, end) of this cell inside the local data table.
        self._gid_map[node_id] = (self._n_segments_local, self._n_segments_local + n_segs)
        self._mapping_gids.append(node_id)
        self._mapping_element_ids.extend(element_ids)
        self._mapping_element_pos.extend(element_pos)
        self._mapping_index.append(self._mapping_index[-1] + n_segs)
        self._n_segments_local += n_segs
        self._n_gids_local += 1

        for k, v in map_attrs.items():
            # Accumulate into a list owned by this writer. Storing the caller's object directly would make later
            # extend() calls mutate the caller's data (and would fail for tuples/arrays).
            self._element_data.setdefault(k, []).extend(v)

    def _create_data_dataset(self, base_grp):
        """Create the (n_steps x n_segments_all) 'data' dataset and attach the variable/units attributes."""
        dset = base_grp.create_dataset('data', shape=(self.n_steps(), self._n_segments_all), dtype=float,
                                       chunks=True)
        if self._variable is not None:
            dset.attrs['variable'] = self._variable
        if self._units is not None:
            dset.attrs['units'] = self._units
        return dset

    def initialize(self, **kwargs):
        """Create and fill the /mapping tables and create the /data dataset. Does nothing after the first call.

        It is invoked by record_cell/record_cell_block, so every cell should be added with add_cell before the
        first value is recorded.

        :raises Exception: if the number of time steps is not a positive integer.
        """
        if self._is_initialized:
            return

        n_steps = self.n_steps()
        if n_steps <= 0:
            raise Exception('A non-zero positive integer num-of-steps is required to initialize the compartment '
                            'report. Please specify report length using the n_steps parameters (or using '
                            'appropriate tstop, tstart, and dt).')

        self._calc_offset()
        base_grp = self.h5_base
        var_grp = base_grp.create_group('mapping')
        var_grp.create_dataset('node_ids', shape=(self._n_gids_all,), dtype=np.uint)
        var_grp.create_dataset('element_ids', shape=(self._n_segments_all,), dtype=np.uint)
        var_grp.create_dataset('element_pos', shape=(self._n_segments_all,), dtype=float)
        var_grp.create_dataset('index_pointer', shape=(self._n_gids_all + 1,), dtype=np.uint64)
        var_grp.create_dataset('time', data=[self.tstart(), self.tstop(), self.dt()])
        for k, v in self._element_data.items():
            var_grp.create_dataset(k, shape=(self._n_segments_all,), dtype=type(v[0]))

        var_grp['node_ids'][self._gids_beg:self._gids_end] = self._mapping_gids
        var_grp['element_ids'][self._seg_offset_beg:self._seg_offset_end] = self._mapping_element_ids
        var_grp['element_pos'][self._seg_offset_beg:self._seg_offset_end] = self._mapping_element_pos
        var_grp['index_pointer'][self._gids_beg:(self._gids_end + 1)] = self._mapping_index
        for k, v in self._element_data.items():
            var_grp[k][self._seg_offset_beg:self._seg_offset_end] = v

        self._total_steps = n_steps
        # The buffer never needs to be longer than the whole recording.
        self._buffer_size = np.min((self._total_steps, self._buffer_size))
        self._buffer_block_size = self._buffer_size

        if self._buffer_data:
            # Set up in-memory block to buffer recorded variables before writing to the dataset
            self._data_block.buffer_block = np.zeros((self._buffer_size, self._n_segments_local), dtype=float)
            self._data_block.data_block = self._create_data_dataset(base_grp)
            self._data_block.block_window = (0, self._buffer_block_size)
        else:
            # If data is not being buffered and instead written to the main block, we have to add a rank offset
            # to the gid offset
            for gid, (seg_beg, seg_end) in self._gid_map.items():
                self._gid_map[gid] = (seg_beg + self._seg_offset_beg, seg_end + self._seg_offset_beg)
            # Since we are not buffering data, we just write directly to the on-disk dataset
            self._data_block.buffer_block = self._create_data_dataset(base_grp)

        self._is_initialized = True

    def _reset_buffer_window(self, tstep):
        # Align the in-memory window to the buffer-sized block that contains tstep.
        blk_beg = int(tstep / self._buffer_size) * self._buffer_size
        blk_end = blk_beg + self._buffer_size
        self._data_block.block_window = (blk_beg, blk_end)

    def record_cell(self, node_id, vals, tstep, population=None):
        """Record the values of all segments of one cell at a single time step.

        :param node_id: id of the cell (must have been registered with add_cell).
        :param vals: list/array of values, one per recorded segment of the cell.
        :param tstep: time step index.
        :param population: ignored; present for interface compatibility.
        """
        self.initialize()
        gid_beg, gid_end = self._gid_map[node_id]
        buffer_block = self._data_block.buffer_block
        if not self._buffer_data:
            # buffer_block is the on-disk dataset; write the time step directly.
            buffer_block[tstep, gid_beg:gid_end] = vals
            return

        win_beg, win_end = self._data_block.block_window
        if not (win_beg <= tstep < win_end):
            # tstep is outside the in-memory block: save the current block, then move the window onto tstep.
            self.flush()
            self._reset_buffer_window(tstep)
        update_index = tstep - self._data_block.block_window[0]
        buffer_block[update_index, gid_beg:gid_end] = vals

    def record_cell_block(self, node_id, vals, beg_step, end_step, population=None):
        """Save cell parameters one block at a time.

        :param node_id: id of the cell (must have been registered with add_cell).
        :param vals: a vector/matrix of values being recorded. A list or 1D array is written to the cell's first
            column; a 2D array to all of the cell's columns. Every row of the target table is written.
        :param beg_step: first time step of the block (currently not used).
        :param end_step: end time step of the block (currently not used).
        :param population: ignored; present for interface compatibility.
        """
        self.initialize()
        gid_beg, gid_end = self._gid_map[node_id]
        buffer_block = self._data_block.buffer_block
        if isinstance(vals, list) or vals.ndim == 1:
            buffer_block[:, gid_beg] = vals
        else:
            buffer_block[:, gid_beg:gid_end] = vals

    def flush(self):
        """Move buffered data from memory to the dataset. Does nothing when not buffering."""
        if not self._buffer_data:
            return

        blk_beg, blk_end = self._data_block.block_window
        # Need to handle the case that simulation doesn't end on a block step
        blk_end = min(blk_end, self._total_steps)
        block_size = blk_end - blk_beg
        if block_size > 0:
            self._data_block.data_block[blk_beg:blk_end, :] = self._data_block.buffer_block[:block_size, :]
        # Move to the block starting right at blk_end. (blk_end + 1 would skip a time step when buffer_size == 1.)
        self._reset_buffer_window(blk_end)

    def close(self):
        # Let the parent take care of this
        pass

    def merge(self):
        # Let the parent take care of this
        pass
class CompartmentWriterv01(CompartmentWriterABC):
    """Writes compartment (per-segment) reports for one or more node populations into an hdf5 file.

    Each population is handled by a PopulationWriterv01 kept in ``self._pop_tables``. With more than one MPI rank,
    every rank writes its own temporary file (see ``temp_files``) and, on close(), rank 0 merges them into the final
    file at ``file_path``.
    """

    def __init__(self, file_path, mode='w', default_population=None, cache_dir=None, variable=None, units=None,
                 buffer_size=0, tstart=0.0, tstop=0.0, dt=0.0, n_steps=None, **kwargs):
        """
        :param file_path: path of the final report file.
        :param mode: h5py file mode used when opening the (per-rank) output file.
        :param default_population: population used when a method is called with population=None.
        :param cache_dir: directory for the per-rank temporary files (defaults to the directory of file_path).
        :param variable: name of the recorded variable.
        :param units: units of the recorded variable.
        :param buffer_size: number of time steps buffered in memory (0 writes straight to disk).
        :param tstart: start time of the recording.
        :param tstop: stop time of the recording.
        :param dt: time step.
        :param n_steps: number of time steps (optional).
        :param kwargs: may contain 'mpi_rank' / 'mpi_size' to override the values detected from mpi4py.
        """
        self._mode = mode
        self._variable = variable
        self._units = units
        self._pop_tables = {}
        self._default_pop = default_population
        self._buffer_size = buffer_size
        self._tstart = tstart
        self._tstop = tstop
        self._dt = dt
        self._n_steps = n_steps
        self._kwargs = kwargs

        self._h5_handle = None
        self._h5report_grp = None

        self._mpi_rank = kwargs.get('mpi_rank', rank)
        self._mpi_size = kwargs.get('mpi_size', nhosts)
        self._final_fpath = file_path  # name of file being written to
        self._cache_dir = cache_dir or os.path.dirname(os.path.abspath(file_path))  # used for multiple ranks
        self._base_name = os.path.basename(file_path)  # make sure file names don't clash if there are multiple reports
        # In certain cases (parallelized simulation) split the final file by rank.
        self._interm_fpath = self._get_iterm_fpath()

    def _get_iterm_fpath(self):
        # Path this rank actually writes to: its own temp file when running on several ranks.
        if self._mpi_size > 1:
            return self.temp_files[self._mpi_rank]
        else:
            return self._final_fpath

    def _create_h5file(self):
        """Open this rank's output file and add the hdf5 version/magic attributes."""
        fdir = os.path.dirname(os.path.abspath(self._interm_fpath))
        if not os.path.isdir(fdir):
            try:
                # makedirs also creates missing parent directories.
                os.makedirs(fdir)
            except OSError:
                # Another rank may have created the directory concurrently; only fail if it still doesn't exist.
                if not os.path.isdir(fdir):
                    raise
        self._h5_handle = h5py.File(self._interm_fpath, self._mode)
        add_hdf5_version(self._h5_handle)
        add_hdf5_magic(self._h5_handle)

    @property
    def report_group(self):
        """The '/report' group of this rank's output file (the file is opened on first access)."""
        if self._h5report_grp is None:
            self._create_h5file()
            if 'report' in self._h5_handle.keys():
                self._h5report_grp = self._h5_handle['report']
            else:
                self._h5report_grp = self._h5_handle.create_group('report')

        return self._h5report_grp

    @property
    def temp_files(self):
        """Paths of the per-rank temporary files, indexed by rank."""
        return [os.path.join(self._cache_dir, '.bmtk_tmp_cellvars_{}_{}'.format(r, self._base_name))
                for r in range(self._mpi_size)]

    def set_units(self, val, population=None):
        self[population].set_units(val)

    def set_variable(self, val, population=None):
        self[population].set_variable(val)

    def set_tstart(self, val, population=None):
        self[population].set_tstart(val)

    def set_tstop(self, val, population=None):
        self[population].set_tstop(val)

    def set_dt(self, val, population=None):
        self[population].set_dt(val)

    def n_steps(self, population=None):
        """Number of time steps of the given (or default) population."""
        return self[population].n_steps()

    def set_time_trace(self, val, population=None):
        self[population].set_time_trace(val)

    def add_cell(self, node_id, element_ids, element_pos, population=None, **element_data):
        """Register a cell (and its segments) with the given or default population."""
        pop_str = population or self._default_pop
        pop_grp = self._build_or_fetch_pop(pop_str)
        pop_grp.add_cell(node_id=node_id, element_ids=element_ids, element_pos=element_pos, **element_data)

    def initialize(self):
        for pop_grp in self._pop_tables.values():
            pop_grp.initialize()

    def record_cell(self, node_id, vals, tstep, population=None):
        pop_str = population or self._default_pop
        pop_grp = self._build_or_fetch_pop(pop_str)
        pop_grp.record_cell(node_id=node_id, vals=vals, tstep=tstep)

    def record_cell_block(self, node_id, vals, beg_step, end_step, population=None):
        self[population].record_cell_block(node_id=node_id, vals=vals, beg_step=beg_step, end_step=end_step)

    def flush(self):
        for pop_grp in self._pop_tables.values():
            pop_grp.flush()

    def close(self):
        """Close this rank's file; on multi-rank runs also merge the temp files (collective: call on every rank)."""
        for pop_grp in self._pop_tables.values():
            pop_grp.close()

        if self._h5_handle is not None:
            self._h5_handle.close()

        if self._mpi_size > 1:
            self.merge()

    def merge(self):
        """Combine the per-rank temporary files into the final report file.

        All ranks synchronize with a barrier before and after. Only rank 0 of a multi-rank run does the merging,
        after which it deletes the temporary files.
        """
        barrier()
        try:
            if self._mpi_size > 1 and self._mpi_rank == 0:
                self._merge_tmp_files()
        finally:
            # Always reach the closing barrier, even if rank 0 fails, so the other ranks are not left waiting.
            barrier()

    def _merge_tmp_files(self):
        """Rank 0 only: write every population found in the temp files to the final file, then remove them."""
        # The context manager guarantees the final file is closed (and its data flushed) before returning.
        with h5py.File(self._final_fpath, 'w') as h5final:
            # Same file-level attributes as files written directly by _create_h5file.
            add_hdf5_version(h5final)
            add_hdf5_magic(h5final)

            tmp_reports = [CompartmentReader(name) for name in self.temp_files if os.path.exists(name)]
            populations = set()
            for r in tmp_reports:
                populations.update(r.populations)

            for pop in populations:
                self._merge_population(h5final, pop, tmp_reports)

        for tmp_file in self.temp_files:
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    @staticmethod
    def _merge_population(h5final, pop, tmp_reports):
        """Concatenate one population's /mapping tables and /data columns from the temp reports into h5final."""
        reports = [rpt[pop] for rpt in tmp_reports if pop in rpt.populations]

        # Find the segment and gid ranges each temp file occupies in the combined tables
        seg_ranges = []  # list of (seg-beg, seg-end)
        gid_ranges = []  # list of (gid-beg, gid-end)
        total_seg_count = 0  # total number of segments across all ranks
        total_gid_count = 0  # total number of gids across all ranks
        times = None
        n_steps = 0
        variable = None
        units = None
        for report in reports:
            seg_count = len(report.element_pos())
            seg_ranges.append((total_seg_count, total_seg_count + seg_count))
            total_seg_count += seg_count

            gid_count = len(report.node_ids())
            gid_ranges.append((total_gid_count, total_gid_count + gid_count))
            total_gid_count += gid_count

            # time/n_steps/variable/units are taken from the last temp file that contains this population.
            times = report.time()
            n_steps = report.n_steps()
            variable = report.variable()
            units = report.units()

        mapping_grp = h5final.create_group('/report/{}/mapping'.format(pop))
        if times is not None and len(times) > 0:
            mapping_grp.create_dataset('time', data=times)

        element_id_ds = mapping_grp.create_dataset('element_ids', shape=(total_seg_count,), dtype=np.uint)
        el_pos_ds = mapping_grp.create_dataset('element_pos', shape=(total_seg_count,), dtype=float)
        gids_ds = mapping_grp.create_dataset('node_ids', shape=(total_gid_count,), dtype=np.uint)
        # uint64 matches the index_pointer dtype used for single-rank reports.
        index_pointer_ds = mapping_grp.create_dataset('index_pointer', shape=(total_gid_count + 1,),
                                                      dtype=np.uint64)

        # A custom column may be present in only some temp files; create each one once.
        for report in reports:
            for k, v in report.custom_columns().items():
                if k not in mapping_grp:
                    mapping_grp.create_dataset(k, shape=(total_seg_count,), dtype=type(v[0]))

        # combine the /mapping datasets
        for report, (seg_beg, seg_end), (gid_beg, gid_end) in zip(reports, seg_ranges, gid_ranges):
            element_id_ds[seg_beg:seg_end] = report.element_ids()
            el_pos_ds[seg_beg:seg_end] = report.element_pos()
            for k, v in report.custom_columns().items():
                mapping_grp[k][seg_beg:seg_end] = v

            gids_ds[gid_beg:gid_end] = report.node_ids()
            # Shift the local index pointers by this file's segment offset. The last entry of one file equals the
            # first entry of the next, so overlapping writes at the boundary agree.
            index_pointer_ds[gid_beg:(gid_end + 1)] = seg_beg + np.array(report.index_pointer())

        # combine the /data datasets
        var_data = h5final.create_dataset('/report/{}/data'.format(pop), shape=(n_steps, total_seg_count),
                                          dtype=float)
        for report, (seg_beg, seg_end) in zip(reports, seg_ranges):
            var_data[:, seg_beg:seg_end] = report.data()

        if variable is not None:
            var_data.attrs['variable'] = variable

        if units is not None:
            var_data.attrs['units'] = units

    def _build_or_fetch_pop(self, population):
        """Return the PopulationWriterv01 for ``population``, creating it with this writer's settings if needed.

        :raises Exception: if population is None.
        """
        if population is None:
            raise Exception('Please specify a valid node population (or use default_population parameter in constructor).')

        if population in self._pop_tables:
            pop_grp = self._pop_tables[population]
        else:
            pop_grp = PopulationWriterv01(self, population, variable=self._variable, units=self._units,
                                          tstart=self._tstart, tstop=self._tstop, dt=self._dt,
                                          buffer_size=self._buffer_size, n_steps=self._n_steps)
            self._pop_tables[population] = pop_grp

        return pop_grp

    def __getitem__(self, population):
        """writer[pop] -> population writer; a falsy pop selects the default population."""
        pop_str = population or self._default_pop
        return self._build_or_fetch_pop(pop_str)