Source code for bmtk.simulator.popnet.ssn.modules.external_inputs

import numpy as np
import pandas as pd
import h5py

from ..pyfunction_cache import py_modules
from .sim_module import SimulatorMod


[docs] class ExternalRatesMod(SimulatorMod): def __init__(self, name, input_type, module, **kwargs): self._name = name self._input_type = input_type self._module = module self._params = kwargs
[docs] def initialize(self, sim): if self._module == 'npy': npy_path = self._params['file'] inputs_arr = np.load(npy_path) if sim.tstop is None: # TODO: WARNING THAT TSTOP IS BEING SET BY INPUT sim.tstop = len(inputs_arr)*sim.dt if sim.nsteps < len(inputs_arr): # TODO: WARN THAT input is being cut inputs_arr = inputs_arr[:sim.nsteps] elif sim.nsteps > len(inputs_arr): # GIVE WARNING inputs_arr = np.append(inputs_arr, np.zeros(sim.nsteps - len(inputs_arr))) node_set = sim.network.get_node_set(self._params.get('node_set', 'all')) for node in node_set.fetch_nodes(): ssn_node = sim.network.get_node(node.population_name, node.node_id) ssn_node.external_inputs = inputs_arr.flatten() elif self._module == 'function': fnc_name = self._params['inputs_generator'] generator_fnc = py_modules.user_function(fnc_name) node_set = sim.network.get_node_set(self._params.get('node_set', 'all')) for node in node_set.fetch_nodes(): ssn_node = sim.network.get_node(node.population_name, node.node_id) external_inputs = generator_fnc(ssn_node, sim) ssn_node.external_inputs = external_inputs if sim.tstop is None: # TODO: WARNING THAT TSTOP IS BEING SET BY INPUT sim.tstop = len(external_inputs)*sim.dt elif self._module == 'csv': # TODO: Check Timestamps match, line up if needed rates_df = pd.read_csv(self._params['file'], sep=self._params.get('sep', ' ')) node_set = sim.network.get_node_set(self._params.get('node_set', 'all')) for node in node_set.fetch_nodes(): ssn_node = sim.network.get_node(node.population_name, node.node_id) # TODO: Check that node exists in file inputs = rates_df[(rates_df['node_id'] == ssn_node.node_id) & (rates_df['population'] == ssn_node.population)]['firing_rates'] if len(inputs) > 0: ssn_node.external_inputs = inputs if sim.tstop is None: # TODO: WARNING THAT TSTOP IS BEING SET BY INPUT sim.tstop = len(inputs)*sim.dt elif self._module in ['h5', 'sonata']: with h5py.File(self._params['file'], 'r') as h5: ratesgrp = h5['/rates'] node_set = sim.network.get_node_set(self._params.get('node_set', 'all')) for node in node_set.fetch_nodes(): ssn_node = sim.network.get_node(node.population_name, node.node_id) node_id_map = ratesgrp[f'{ssn_node.population}/mapping/node_ids'][()] idx = np.argwhere(node_id_map == ssn_node.node_id)[0][0] inputs = ratesgrp[f'{ssn_node.population}/data'][:, idx] ssn_node.external_inputs = inputs if sim.tstop is None: # TODO: WARNING THAT TSTOP IS BEING SET BY INPUT sim.tstop = len(inputs)*sim.dt else: raise ValueError('Unknown module')