# Source code for bmtk.simulator.pointnet.gids

from collections import namedtuple
import pandas as pd
import nest

from .nest_utils import nest_version

# Immutable (node_id, population) pair naming one node inside a named
# population; GidPool stores these as the values of its nest_id lookup.
PopulationID = namedtuple('PopulationID', ['node_id', 'population'])


def ids2list_nest2(nest_ids):
    """Return ``nest_ids`` unchanged.

    Before NEST 3.0, ``nest.Create()`` already hands back a list of ids, so
    no conversion is needed and this is simply the identity function.
    """
    return nest_ids
def ids2list_nest3(nest_ids):
    """Normalise NEST 3 ids to a plain list.

    A ``nest.NodeCollection`` is converted with its ``tolist()`` method;
    any other value is passed through untouched.
    """
    if not isinstance(nest_ids, nest.NodeCollection):
        return nest_ids
    return nest_ids.tolist()
# NEST 3.0 changed nest.Create() to return a NodeCollection, so bind the
# converter that matches the installed NEST major version.
if nest_version[0] >= 3:
    ids2list = ids2list_nest3
else:
    ids2list = ids2list_nest2
class GidPool(object):
    """Two-way lookup between (population, node_id) pairs and NEST ids.

    Internally two tables are kept in sync by :meth:`add_nestids`:

    * ``_gid2pop_id`` -- dict mapping a NEST id to a ``PopulationID``
      (``node_id``, ``population``) namedtuple.
    * ``_nestid_lu`` -- dict mapping a population name to a pandas DataFrame
      indexed by ``node_ids`` with a single ``nest_ids`` column.
    """

    def __init__(self):
        self._gid2pop_id = {}  # nest_id --> PopulationID(node_id, population)
        self._nestid_lu = {}  # population name --> DataFrame(index=node_ids, columns=['nest_ids'])

    @property
    def gids(self):
        """List of every NEST id registered in the pool."""
        return list(self._gid2pop_id.keys())

    @property
    def populations(self):
        """List of population names registered through :meth:`add_nestids`."""
        return list(self._nestid_lu.keys())

    def add(self, name, node_id, gid):
        """Not supported; register ids in bulk with :meth:`add_nestids`.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

    def get_gid(self, name, node_id):
        """Return the NEST id of a single node ``node_id`` of population ``name``."""
        return self.get_nestids(name=name, node_ids=[node_id])[0]

    def get_pool_id(self, gid):
        """Return the ``PopulationID`` registered for NEST id ``gid``.

        :raises KeyError: if ``gid`` was never registered.
        """
        return self._gid2pop_id[gid]

    def create_pool(self, name):
        """No-op; population tables are created lazily by :meth:`add_nestids`."""
        pass

    def add_nestids(self, name, node_ids, nest_ids):
        """Register the pairing ``node_ids[i] <-> nest_ids[i]`` for population ``name``.

        Both lookup tables are updated. Repeated calls for the same population
        append to its existing table.

        :param name: population name.
        :param node_ids: sequence of node ids within the population.
        :param nest_ids: matching NEST ids (a list or, in NEST 3+, a NodeCollection).
        """
        from collections.abc import Iterator

        # in NEST 3.0+ nest.Create() returns a NodeCollection instead of a list of ids, need to convert
        nest_ids = ids2list(nest_ids)

        # pandas consumes one-shot iterators while building the DataFrame, which
        # would leave nothing for the zip() below; materialise them first.
        if isinstance(node_ids, Iterator):
            node_ids = list(node_ids)
        if isinstance(nest_ids, Iterator):
            nest_ids = list(nest_ids)

        lu_table = pd.DataFrame({'nest_ids': nest_ids, 'node_ids': node_ids}).set_index('node_ids')
        if name in self._nestid_lu:
            lu_table = pd.concat((self._nestid_lu[name], lu_table))
        self._nestid_lu[name] = lu_table

        for node_id, nest_id in zip(node_ids, nest_ids):
            self._gid2pop_id[nest_id] = PopulationID(population=name, node_id=node_id)

    def add_gids(self, name, node_ids, gids):
        """Alias of :meth:`add_nestids` taking ``gids`` in place of ``nest_ids``."""
        self.add_nestids(name=name, node_ids=node_ids, nest_ids=gids)

    def get_nestids(self, name, node_ids):
        """Return an array of the NEST ids for ``node_ids`` of population ``name``.

        :raises KeyError: if the population or any of the node ids is unknown.
        """
        # Single .loc call selects rows and column together (no chained indexing).
        return self._nestid_lu[name].loc[node_ids, 'nest_ids'].values

    def get_gids(self, name, node_ids):
        """Alias of :meth:`get_nestids`."""
        return self.get_nestids(name=name, node_ids=node_ids)

    def __len__(self):
        # Count the registered NEST ids directly instead of materialising a list.
        return len(self._gid2pop_id)