# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2015-2016. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import h5py
import numpy as np


class NwbDataSet(object):
    """ A very simple interface for extracting electrophysiology data
    from an NWB file.
    """
    SPIKE_TIMES = "spike_times"
    DEPRECATED_SPIKE_TIMES = "aibs_spike_times"

    def __init__(self, file_name, spike_time_key=None):
        """ Initialize the NwbDataSet instance with a file name

        Parameters
        ----------
        file_name: string
            NWB file name
        spike_time_key: string
            Key under the file's 'analysis' group where spike times are
            stored (default NwbDataSet.SPIKE_TIMES)
        """
        self.file_name = file_name
        if spike_time_key is None:
            self.spike_time_key = NwbDataSet.SPIKE_TIMES
        else:
            self.spike_time_key = spike_time_key

    def get_sweep(self, sweep_number):
        """ Retrieve the stimulus, response, index_range, and sampling rate
        for a particular sweep.  This method hides the NWB file's distinction
        between a "Sweep" and an "Experiment".  An experiment is a subset
        of a sweep that excludes the initial test pulse.  It also excludes
        any erroneous response data at the end of the sweep (usually for
        ramp sweeps, where recording was terminated mid-stimulus).

        Some sweeps do not have an experiment, so full data arrays are
        returned.  Sweeps that have an experiment return full data arrays
        (including the test pulse) with any erroneous data trimmed from the
        end of the sweep.

        Parameters
        ----------
        sweep_number: int

        Returns
        -------
        dict
            A dictionary with 'stimulus', 'response', 'index_range', and
            'sampling_rate' elements.  The index range is a 2-tuple where
            the first element indicates the end of the test pulse and the
            second index is the end of valid response data.
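
        Examples
        --------
        A minimal usage sketch; 'example.nwb' and the sweep number are
        hypothetical, and index_range is inclusive on both ends.

        >>> data_set = NwbDataSet('example.nwb')
        >>> sweep = data_set.get_sweep(1)
        >>> i0, i1 = sweep['index_range']
        >>> response = sweep['response'][i0:i1 + 1]
        >>> duration = len(response) / sweep['sampling_rate']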
        """
        with h5py.File(self.file_name, 'r') as f:
            swp = f['epochs']['Sweep_%d' % sweep_number]
            # fetch data from the file and convert to the correct SI unit.
            # this operation depends on the file version. early versions of
            #   the file have incorrect conversion information embedded
            #   in the NWB file, but the data was already stored in the
            #   appropriate SI unit, so it is returned uncorrected.
            #   for newer files (1.1 and later), the conversion value is applied.
            major, minor = self.get_pipeline_version()
            if (major == 1 and minor > 0) or major > 1:
                # stimulus
                stimulus_dataset = swp['stimulus']['timeseries']['data']
                conversion = float(stimulus_dataset.attrs["conversion"])
                stimulus = stimulus_dataset[()] * conversion
                # acquisition
                response_dataset = swp['response']['timeseries']['data']
                conversion = float(response_dataset.attrs["conversion"])
                response = response_dataset[()] * conversion
            else:   # old file version
                stimulus_dataset = swp['stimulus']['timeseries']['data']
                stimulus = stimulus_dataset[()]
                response = swp['response']['timeseries']['data'][()]
            if 'unit' in stimulus_dataset.attrs:
                unit = stimulus_dataset.attrs["unit"]
                # attribute strings may come back as bytes or str,
                #   depending on the h5py version
                if isinstance(unit, bytes):
                    unit = unit.decode('UTF-8')
                unit_str = None
                if unit.startswith('A'):
                    unit_str = "Amps"
                elif unit.startswith('V'):
                    unit_str = "Volts"
                assert unit_str is not None, \
                    "Stimulus time series unit not recognized"
            else:
                unit = None
                unit_str = 'Unknown'
            swp_idx_start = swp['stimulus']['idx_start'][()]
            swp_length = swp['stimulus']['count'][()]
            swp_idx_stop = swp_idx_start + swp_length - 1
            sweep_index_range = (swp_idx_start, swp_idx_stop)
            # if the sweep has an experiment, extract the experiment's index
            # range
            try:
                exp = f['epochs']['Experiment_%d' % sweep_number]
                exp_idx_start = exp['stimulus']['idx_start'][()]
                exp_length = exp['stimulus']['count'][()]
                exp_idx_stop = exp_idx_start + exp_length - 1
                experiment_index_range = (exp_idx_start, exp_idx_stop)
            except KeyError:
                # this sweep has no experiment.  return the index range of the
                # entire sweep.
                experiment_index_range = sweep_index_range
            assert sweep_index_range[0] == 0, \
                "Index range of the full sweep does not start at 0."
            return {
                'stimulus': stimulus,
                'response': response,
                'stimulus_unit': unit_str,
                'index_range': experiment_index_range,
                'sampling_rate': float(
                    swp['stimulus']['timeseries']['starting_time'].attrs['rate'])
            }

    def set_sweep(self, sweep_number, stimulus, response):
        """ Overwrite the stimulus or response of an NWB file.
        If the supplied arrays are shorter than the stored arrays,
        they are padded with zeros to match the original data
        size.

        Parameters
        ----------
        sweep_number: int
        stimulus: np.array
            Overwrite the stimulus with this array.  If None, stimulus is unchanged.
        response: np.array
            Overwrite the response with this array.  If None, response is unchanged.
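
        Examples
        --------
        A minimal sketch that zeroes the response of sweep 1 while leaving
        the stimulus untouched ('example.nwb' and the sweep number are
        hypothetical).

        >>> import numpy as np
        >>> data_set = NwbDataSet('example.nwb')
        >>> sweep = data_set.get_sweep(1)
        >>> data_set.set_sweep(1, None, np.zeros_like(sweep['response']))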
        """
        with h5py.File(self.file_name, 'r+') as f:
            swp = f['epochs']['Sweep_%d' % sweep_number]
            # this is the length of the entire sweep data, including test pulse and
            # whatever might be in front of it
            # TODO: remove deprecated 'idx_stop'
            if 'idx_stop' in swp['stimulus']:
                sweep_length = swp['stimulus']['idx_stop'][()] + 1
            else:
                sweep_length = swp['stimulus']['count'][()]
            if stimulus is not None:
                # if the data is shorter than the sweep, pad it with zeros
                missing_data = sweep_length - len(stimulus)
                if missing_data > 0:
                    stimulus = np.append(stimulus, np.zeros(missing_data))
                swp['stimulus']['timeseries']['data'][...] = stimulus
            if response is not None:
                # if the data is shorter than the sweep, pad it with zeros
                missing_data = sweep_length - len(response)
                if missing_data > 0:
                    response = np.append(response, np.zeros(missing_data))
                swp['response']['timeseries']['data'][...] = response 

    def get_pipeline_version(self):
        """ Returns the AI pipeline version number, stored in the
        metadata field 'generated_by'. If that field is
        missing, version 0.0 is returned.

        Returns
        -------
        int tuple: (major, minor)
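
        Examples
        --------
        A sketch of version-gated logic ('example.nwb' is hypothetical).

        >>> data_set = NwbDataSet('example.nwb')
        >>> major, minor = data_set.get_pipeline_version()
        >>> apply_conversion = (major, minor) >= (1, 1)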
        """
        major, minor = 0, 0
        try:
            with h5py.File(self.file_name, 'r') as f:
                if 'generated_by' in f["general"]:
                    info = f["general/generated_by"]
                    # generated_by stores an array of alternating keys and
                    #   values: keys at even indices, the corresponding
                    #   values at the following odd indices
                    version = None
                    for i in range(0, len(info), 2):
                        val = info[i]
                        if isinstance(val, bytes):
                            val = val.decode('UTF-8')
                        if val == 'version':
                            version = info[i + 1]
                            break
                    if isinstance(version, bytes):
                        version = version.decode('UTF-8')
                    toks = version.split('.')
                    if len(toks) >= 2:
                        major = int(toks[0])
                        minor = int(toks[1])
        except Exception:
            major, minor = 0, 0
        return major, minor

    def get_spike_times(self, sweep_number, key=None):
        """ Return any spike times stored in the NWB file for a sweep.

        Parameters
        ----------
        sweep_number: int
            index to access
        key : string
            label where the spike times are stored (default NwbDataSet.SPIKE_TIMES)

        Returns
        -------
        list
           list of spike times in seconds relative to the start of the sweep
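
        Examples
        --------
        A minimal sketch ('example.nwb' and the sweep number are
        hypothetical).

        >>> data_set = NwbDataSet('example.nwb')
        >>> for t in data_set.get_spike_times(1):
        ...     print("spike at %g s" % t)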
        """
        if key is None:
            key = self.spike_time_key
        with h5py.File(self.file_name, 'r') as f:
            sweep_name = "Sweep_%d" % sweep_number
            datasets = ["analysis/%s/%s" % (key, sweep_name),
                        "analysis/%s/%s" % (self.DEPRECATED_SPIKE_TIMES,
                                            sweep_name)]
            for ds in datasets:
                if ds in f:
                    return f[ds][()]
            return []

    def set_spike_times(self, sweep_number, spike_times, key=None):
        """ Set or overwrite the spike times for a sweep.

        Parameters
        ----------
        sweep_number : int
            index to access
        spike_times: np.array
            array of spike times in seconds
        key : string
            label where the spike times are stored (default NwbDataSet.SPIKE_TIMES)
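
        Examples
        --------
        A minimal sketch that stores three spike times for sweep 1 (the
        file name, sweep number, and times are hypothetical).

        >>> import numpy as np
        >>> data_set = NwbDataSet('example.nwb')
        >>> data_set.set_spike_times(1, np.array([0.1, 0.25, 0.4]))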
        """
        if key is None:
            key = self.spike_time_key
        with h5py.File(self.file_name, 'r+') as f:
            # make sure the expected directory structure is in place
            if "analysis" not in f:
                f.create_group("analysis")
            analysis_dir = f["analysis"]
            if key not in analysis_dir:
                g = analysis_dir.create_group(key)
                # a mixup in the specification for the validator resulted in
                #   everything in 'analysis' requiring a custom label, even
                #   though it's already known to be custom. don't argue, just
                #   support the metadata redundancy
                g.attrs["neurodata_type"] = "Custom"
            spike_dir = analysis_dir[key]
            # see if the desired dataset already exists
            sweep_name = "Sweep_%d" % sweep_number
            if sweep_name in spike_dir:
                # rewriting data -- delete the old dataset
                del spike_dir[sweep_name]
            spike_dir.create_dataset(
                sweep_name, data=spike_times, dtype='f8', maxshape=(None,))

    def get_sweep_numbers(self):
        """ Get all of the sweep numbers in the file, including test sweeps. """
        with h5py.File(self.file_name, 'r') as f:
            sweeps = [int(e.split('_')[1])
                      for e in f['epochs'].keys() if e.startswith('Sweep_')]
            return sweeps

    def get_experiment_sweep_numbers(self):
        """ Get all of the sweep numbers for experiment epochs in the file,
        not including test sweeps.
        """
        with h5py.File(self.file_name, 'r') as f:
            sweeps = [int(e.split('_')[1])
                      for e in f['epochs'].keys() if e.startswith('Experiment_')]
            return sweeps

    def fill_sweep_responses(self, fill_value=0.0, sweep_numbers=None,
                             extend_experiment=False):
        """ Fill sweep response arrays with a single value.

        Parameters
        ----------
        fill_value: float
            Value used to fill the sweep response array
        sweep_numbers: list
            List of integer sweep numbers to be filled (default all sweeps)
        extend_experiment: bool
            If True, extend the experiment epoch length to the end of the
            sweep (undoing any truncation)
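
        Examples
        --------
        A sketch that blanks the responses of two sweeps ('example.nwb' and
        the sweep numbers are hypothetical).

        >>> data_set = NwbDataSet('example.nwb')
        >>> data_set.fill_sweep_responses(0.0, sweep_numbers=[1, 2])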
        """
        with h5py.File(self.file_name, 'a') as f:
            if sweep_numbers is None:
                sweep_numbers = self.get_sweep_numbers()
            for sweep_number in sweep_numbers:
                epoch = "Sweep_%d" % sweep_number
                if epoch in f['epochs']:
                    f['epochs'][epoch]['response'][
                        'timeseries']['data'][...] = fill_value
                if extend_experiment:
                    epoch = "Experiment_%d" % sweep_number
                    if epoch in f['epochs']:
                        idx_start = f['epochs'][epoch]['stimulus']['idx_start'][()]
                        count = f['epochs'][epoch]['stimulus']['timeseries']['data'].shape[0]
                        # replace the (possibly truncated) experiment count so
                        #   that the epoch extends to the end of the sweep
                        del f['epochs'][epoch]['stimulus']['count']
                        f['epochs'][epoch]['stimulus']['count'] = count - idx_start
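

if __name__ == "__main__":
    # A minimal end-to-end sketch of the API above. 'example.nwb' is a
    # hypothetical path; substitute a real NWB file produced by the Allen
    # Institute electrophysiology pipeline.
    data_set = NwbDataSet('example.nwb')
    for sweep_number in data_set.get_experiment_sweep_numbers():
        sweep = data_set.get_sweep(sweep_number)
        i0, i1 = sweep['index_range']
        spikes = data_set.get_spike_times(sweep_number)
        print("sweep %d: %d valid samples at %.0f Hz, %d spikes" % (
            sweep_number, i1 - i0 + 1, sweep['sampling_rate'], len(spikes)))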