# Source code for bmtk.simulator.filternet.lgnmodel.spatialfilter

import ast
import numpy as np
import itertools
from six import string_types
from scipy import ndimage

from . import utilities as util
from .kernel import Kernel2D


class ArrayFilter(object):
    """Spatial filter defined directly by an explicit 2D array of weights (the mask)."""

    def __init__(self, mask):
        """
        :param mask: 2D array of filter weights; entries greater than the threshold passed to
            get_kernel become the non-zero elements of the kernel
        """
        self.mask = mask

    def imshow(self, row_range, col_range, threshold=0, **kwargs):
        """Build the kernel for the given ranges and display it with Kernel2D.imshow.

        :param row_range: row coordinates forwarded to get_kernel
        :param col_range: column coordinates forwarded to get_kernel
        :param threshold: mask values must exceed this to be kept
        :param kwargs: forwarded unchanged to Kernel2D.imshow
        :return: whatever Kernel2D.imshow returns
        """
        kernel = self.get_kernel(row_range, col_range, threshold)
        return kernel.imshow(**kwargs)

    def get_kernel(self, row_range, col_range, threshold=0, amplitude=1.):
        """Convert the mask into a sparse Kernel2D whose weights sum to `amplitude`.

        :param row_range: row coordinates handed to Kernel2D
        :param col_range: column coordinates handed to Kernel2D
        :param threshold: only mask entries strictly greater than this are kept
        :param amplitude: total weight of the returned kernel
        :return: a Kernel2D built from the (row, col, weight) triples of the kept entries
        """
        # np.where on a 2D boolean mask yields one index array per axis.
        rows, cols = np.where(self.mask > threshold)
        weights = self.mask[rows, cols]
        total = weights.sum()
        # Rescale so the kept weights sum to `amplitude` (same evaluation order as before).
        weights = amplitude * weights / total
        return Kernel2D(row_range, col_range, rows, cols, weights)
class GaussianSpatialFilter(object):
    def __init__(self, translate=(0.0, 0.0), sigma=(1.0, 1.0), rotation=0, origin='center'):
        """A 2D gaussian used for filtering a part of the receptive field.

        :param translate: (float, float), the location of the gaussian on the screen relative to the origin in pixels.
        :param sigma: (float, float), the x and y gaussian std. A string such as "(2.0, 3.0)" is also
            accepted and parsed with ast.literal_eval; if it cannot be parsed the string is stored unchanged.
        :param rotation: rotation of the gaussian in degrees
        :param origin: origin of the receptive field. Either 'center' (the default, center of the image) or a
            pair whose first element is measured along col_range and second along row_range (see get_kernel).
        """
        # When w=1 and rotation=0, half-height will be at y=1
        self.translate = translate
        self.rotation = rotation
        if isinstance(sigma, string_types):
            # TODO: Move this to calling method
            try:
                sigma = ast.literal_eval(sigma)
            except (ValueError, SyntaxError):
                # Best-effort parse: on malformed input keep the caller's raw value untouched.
                pass
        self.sigma = sigma
        self.origin = origin

    def imshow(self, row_range, col_range, threshold=0, **kwargs):
        """Build the kernel for the given ranges and display it with Kernel2D.imshow."""
        return self.get_kernel(row_range, col_range, threshold).imshow(**kwargs)

    def to_dict(self):
        """Return a description of this filter containing every constructor argument."""
        return {
            'class': (__name__, self.__class__.__name__),
            'translate': self.translate,
            'rotation': self.rotation,
            'sigma': self.sigma,
            'origin': self.origin,
        }

    def get_kernel(self, row_range, col_range, threshold=0, amplitude=1.0):
        """Create a 2D gaussian filter (kernel) for the given dimensions.

        :param row_range: evenly spaced coordinates; its length sets axis 1 of the dense filter and its
            spacing converts sigma[1] and translate[1] into pixels
        :param col_range: evenly spaced coordinates; its length sets axis 0 of the dense filter and its
            spacing converts sigma[0] and translate[0] into pixels
        :param threshold: cutoff passed to Kernel2D.apply_threshold (applied before normalization)
        :param amplitude: multiplier applied to the kernel values after normalization
        :return: A Kernel2D object
        """
        # NOTE(review): the dense array is laid out as (len(col_range), len(row_range)); this axis
        # convention is kept exactly as before.
        h, w = len(col_range), len(row_range)
        on_filter_spatial = np.zeros((h, w))

        # Seed a symmetric impulse at the center: the two middle pixels along an even axis, the single
        # middle pixel along an odd one. The seed always has total mass 1 so that `threshold` (applied
        # before normalization) acts the same regardless of image parity; previously every seed pixel
        # got 0.25, giving a total mass of 1.0, 0.5 or 0.25 depending on parity.
        center_rows = [h // 2 - 1, h // 2] if h % 2 == 0 else [h // 2]
        center_cols = [w // 2 - 1, w // 2] if w % 2 == 0 else [w // 2]
        on_filter_spatial[np.ix_(center_rows, center_cols)] = 1.0 / (len(center_rows) * len(center_cols))

        # Apply gaussian filter to create correct sigma (sigma converted from coordinate units to pixels):
        scaled_sigma_x = float(self.sigma[0]) / (col_range[1] - col_range[0])
        scaled_sigma_y = float(self.sigma[1]) / (row_range[1] - row_range[0])
        on_filter_spatial = ndimage.gaussian_filter(on_filter_spatial, (scaled_sigma_x, scaled_sigma_y),
                                                    mode='nearest', cval=0)

        # Rotate and translate gaussian at center:
        rotation_matrix = util.get_rotation_matrix(self.rotation, on_filter_spatial.shape)
        translation_x = float(self.translate[1]) / (row_range[1] - row_range[0])
        translation_y = -float(self.translate[0]) / (col_range[1] - col_range[0])
        translation_matrix = util.get_translation_matrix((translation_x, translation_y))

        # Compare against the string explicitly: `origin` may be a numpy array, for which `!=` would be
        # elementwise and its truth value ambiguous.
        origin_is_center = isinstance(self.origin, string_types) and self.origin == 'center'
        if not origin_is_center:
            # Shift from the image midpoint to the requested origin. Use float division so integer
            # ranges are not floored under Python 2.
            col_step = float(col_range[1] - col_range[0])
            row_step = float(row_range[1] - row_range[0])
            center_y = -(self.origin[0] - (col_range[-1] + col_range[0]) / 2.0) / col_step
            center_x = (self.origin[1] - (row_range[-1] + row_range[0]) / 2.0) / row_step
            translation_matrix += util.get_translation_matrix((center_x, center_y))

        # NOTE(review): the translation and rotation matrices are combined by summation, as before;
        # this relies on the conventions of the util helpers.
        kernel_data = util.apply_transformation_matrix(on_filter_spatial, translation_matrix + rotation_matrix)
        kernel = Kernel2D.from_dense(row_range, col_range, kernel_data, threshold=0)
        kernel.apply_threshold(threshold)
        kernel.normalize()
        kernel.kernel *= amplitude
        return kernel