Source code for bmtk.simulator.popnet.popsimulator

# Copyright 2017. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import os
import logging
from six import string_types

from dipde.internals.internalpopulation import InternalPopulation
from dipde.internals.externalpopulation import ExternalPopulation
from dipde.internals.connection import Connection
import dipde

from bmtk.simulator.core.simulator import Simulator
from . import config as cfg
from . import utils as poputils
import bmtk.simulator.utils.simulation_inputs as inputs
from bmtk.utils.reports.spike_trains import SpikeTrains
from bmtk.utils.io import firing_rates


[docs]class PopSimulator(Simulator): def __init__(self, graph, dt=0.0001, tstop=0.0, overwrite=True): self._graph = graph self._tstop = tstop self._dt = dt self._rates_file = None # name of file where the output is saved self.__population_list = [] # list of all populations, internal and external self.__connection_list = [] # list of all connections self._dipde_network = None # reference to dipde.Network object self.io = self._graph.io @property def tstop(self): return self._tstop @tstop.setter def tstop(self, value): self._tstop = value @property def dt(self): return self._dt @dt.setter def dt(self, value): self._dt = value @property def rates_file(self): return self._rates_file @rates_file.setter def rates_file(self, value): self._rates_file = value @property def populations(self): return self.__population_list @property def connections(self): return self.__connection_list
[docs] def add_rates_nwb(self, network, nwb_file, trial, force=False): """Creates external population firing rates from an NWB file. Will iterate through a processing trial of an NWB file by assigning gids the population it belongs too and taking the average firing rate. This should be done before calling build_cells(). If a population has already been assigned a firing rate an error will occur unless force=True. :param network: Name of network with external populations. :param nwb_file: NWB file with spike rates. :param trial: trial id in NWB file :param force: will overwrite existing firing rates """ existing_rates = self._rates[network] # TODO: validate network exists # Get all unset, external populations in a network. network_pops = self._graph.get_populations(network) selected_pops = [] for pop in network_pops: if pop.is_internal: continue elif not force and pop.pop_id in existing_rates: self.io.log_info('Firing rate for {}/{} has already been set, skipping.'.format(network, pop.pop_id)) else: selected_pops.append(pop) if selected_pops: # assign firing rates from NWB file # TODO: rates_dict = poputils.get_firing_rate_from_nwb(selected_pops, nwb_file, trial) self._rates[network].update(rates_dict)
[docs] def add_rate_hz(self, network, pop_id, rate, force=False): """Set the firing rate of an external population. This should be done before calling build_cells(). If a population has already been assigned a firing rate an error will occur unless force=True. :param network: name of network with wanted exteranl population :param pop_id: name/id of external population :param rate: firing rate in Hz. :param force: will overwrite existing firing rates """ self.__add_rates_validator(network, pop_id, force) self._rates[network][pop_id] = rate
def __add_rates_validator(self, network, pop_id, force): if network not in self._graph.networks: raise Exception('No network {} found in PopGraph.'.format(network)) pop = self._graph.get_population(network, pop_id) if pop is None: raise Exception('No population with id {} found in {}.'.format(pop_id, network)) if pop.is_internal: raise Exception('Population {} in {} is not an external population.'.format(pop_id, network)) if not force and pop_id in self._rates[network]: raise Exception('The firing rate for {}/{} already set and force=False.'.format(network, pop_id)) def _get_rate(self, network, pop): """Gets the firing rate for a given population""" return self._rates[network][pop.pop_id]
[docs] def build_populations(self): """Build dipde Population objects from graph nodes. To calculate external populations firing rates, it first see if a population's firing rate has been manually set in the graph. Otherwise it attempts to calulate the firing rate from the call to add_rate_hz, add_rates_NWB, etc. (which should be called first). """ for network in self._graph.networks: for pop in self._graph.get_populations(network): if pop.is_internal: dipde_pop = self.__create_internal_pop(pop) else: dipde_pop = self.__create_external_pop(pop, self._get_rate(network, pop)) self.__population_list.append(dipde_pop) self.__population_table[network][pop.pop_id] = dipde_pop
[docs] def set_logging(self, log_file): # TODO: move this out of the function, put in io class if os.path.exists(log_file): os.remove(log_file) # get root logger logger = logging.getLogger() for h in list(logger.handlers): # remove existing handlers that will write to console. logger.removeHandler(h) # creates handler that write to log_file logging.basicConfig(filename=log_file, filemode='w', level=logging.DEBUG)
[docs] def set_external_connections(self, network_name): """Sets the external connections for populations in a given network. :param network_name: name of external network with External Populations to connect to internal pops. """ for edge in self._graph.get_edges(network_name): # Get source and target populations src = edge.source source_pop = self.__population_table[src.network][src.pop_id] trg = edge.target target_pop = self.__population_table[trg.network][trg.pop_id] # build a connection. self.__connection_list.append(self.__create_connection(source_pop, target_pop, edge))
[docs] def set_recurrent_connections(self): """Initialize internal connections.""" for network in self._graph.internal_networks(): for edge in self._graph.get_edges(network): src = edge.source source_pop = self.__population_table[src.network][src.pop_id] trg = edge.target target_pop = self.__population_table[trg.network][trg.pop_id] self.__connection_list.append(self.__create_connection(source_pop, target_pop, edge))
[docs] def run(self, tstop=None): # TODO: Check if cells/connections need to be rebuilt. # Create the network dipde_pops = [p.dipde_obj for p in self._graph.populations] dipde_conns = [c.dipde_obj for c in self._graph.connections] self._dipde_network = dipde.Network(population_list=dipde_pops, connection_list=dipde_conns) if tstop is None: tstop = self.tstop self.io.log_info("Running simulation.") self._dipde_network.run(t0=0.0, tf=tstop, dt=self.dt) # TODO: make record_rates optional? self.__record_rates() self.io.log_info("Finished simulation.")
def __create_internal_pop(self, params): # TODO: use getter methods directly in case arguments are not stored in dynamics params # pop = InternalPopulation(**params.dynamics_params) pop = InternalPopulation(**params.model_params) return pop def __create_external_pop(self, params, rates): pop = ExternalPopulation(rates, record=False) return pop def __create_connection(self, source, target, params): return Connection(source, target, nsyn=params.nsyns, delays=params.delay, weights=params.weight) def __record_rates(self): with open(self._rates_file, 'w') as f: for pop in self._graph.internal_populations: if pop.record: for time, rate in zip(pop.dipde_obj.t_record, pop.dipde_obj.firing_rate_record): f.write('{} {} {}\n'.format(pop.pop_id, time, rate))
[docs] @classmethod def from_config(cls, configure, graph): # load the json file or object if isinstance(configure, string_types): config = cfg.from_json(configure, validate=True) elif isinstance(configure, dict): config = configure else: raise Exception('Could not convert {} (type "{}") to json.'.format(configure, type(configure))) if 'run' not in config: raise Exception('Json file is missing "run" entry. Unable to build Bionetwork.') run_dict = config['run'] # Get network parameters # step time (dt) is set in the kernel and should be passed overwrite = run_dict['overwrite_output_dir'] if 'overwrite_output_dir' in run_dict else True print_time = run_dict['print_time'] if 'print_time' in run_dict else False dt = run_dict['dt'] # TODO: make sure dt exists tstop = float(config.tstop) / 1000.0 network = cls(graph, dt=config.dt, tstop=tstop, overwrite=overwrite) if 'output_dir' in config['output']: network.output_dir = config['output']['output_dir'] # network.spikes_file = config['output']['spikes_ascii'] if 'block_run' in run_dict and run_dict['block_run']: if 'block_size' not in run_dict: raise Exception('"block_run" is set to True but "block_size" not found.') network._block_size = run_dict['block_size'] if 'duration' in run_dict: network.duration = run_dict['duration'] graph.io.log_info('Building cells.') graph.build_nodes() graph.io.log_info('Building recurrent connections') graph.build_recurrent_edges() for sim_input in inputs.from_config(config): node_set = graph.get_node_set(sim_input.node_set) if sim_input.input_type == 'spikes': path = sim_input.params['input_file'] spikes = SpikeTrains.load(path=path, file_type=sim_input.module, **sim_input.params) graph.io.log_info('Build virtual cell stimulations for {}'.format(sim_input.name)) graph.add_spike_trains(spikes, node_set) else: graph.io.log_info('Build virtual cell stimulations for {}'.format(sim_input.name)) rates = firing_rates.RatesInput(sim_input.params) graph.add_rates(rates, node_set) # Create the output file if 'output' in config: out_dict = config['output'] rates_file = out_dict.get('rates_file', None) if rates_file is not None: rates_file = rates_file if os.path.isabs(rates_file) else os.path.join(config.output_dir, rates_file) # create directory if required network.rates_file = rates_file parent_dir = os.path.dirname(rates_file) if not os.path.exists(parent_dir): os.makedirs(parent_dir) if 'log_file' in out_dict: log_file = out_dict['log_file'] network.set_logging(log_file) graph.io.log_info('Network created.') return network