import numpy as np
from .kernel import Kernel3D
[docs]
class SpatioTemporalFilter(object):
def __init__(self, spatial_filter, temporal_filter, amplitude=1.):
self.spatial_filter = spatial_filter
self.temporal_filter = temporal_filter
self.amplitude = amplitude
[docs]
def get_spatiotemporal_kernel(self, row_range, col_range, t_range=None, threshold=0, reverse=False):
# TODO: Rename to get_kernel() to match with spatialfilter and temporalfilter
spatial_kernel = self.spatial_filter.get_kernel(row_range, col_range, threshold=0)
temporal_kernel = self.temporal_filter.get_kernel(t_range=t_range, threshold=0, reverse=reverse)
t_range = temporal_kernel.t_range
spatiotemporal_kernel = np.ones((len(temporal_kernel), len(spatial_kernel)))
spatiotemporal_kernel *= spatial_kernel.kernel[None, :]
spatiotemporal_kernel *= temporal_kernel.kernel[:, None]
spatiotemporal_kernel = spatiotemporal_kernel.reshape((np.prod(spatiotemporal_kernel.shape)))
spatial_coord_array = np.empty((len(spatial_kernel), 2))
spatial_coord_array[:, 0] = spatial_kernel.col_inds
spatial_coord_array[:, 1] = spatial_kernel.row_inds
spatiiotemporal_coord_array = np.zeros((len(spatial_kernel)*len(temporal_kernel), 3))
spatiiotemporal_coord_array[:, 0:2] = np.kron(np.ones((len(temporal_kernel), 1)), spatial_coord_array)
spatiiotemporal_coord_array[:, 2] = np.kron(temporal_kernel.t_inds, np.ones(len(spatial_kernel)))
col_inds, row_inds, t_inds = map(lambda x: x.astype(np.int64), spatiiotemporal_coord_array.T)
kernel = Kernel3D(spatial_kernel.row_range, spatial_kernel.col_range, t_range, row_inds, col_inds, t_inds,
spatiotemporal_kernel)
kernel.apply_threshold(threshold)
kernel.kernel *= self.amplitude
return kernel
[docs]
def t_slice(self, t, *args, **kwargs):
k = self.get_spatiotemporal_kernel(*args, **kwargs)
return k.t_slice(t)
[docs]
def show_temporal_filter(self, *args, **kwargs):
self.temporal_filter.imshow(*args, **kwargs)
[docs]
def show_spatial_filter(self, *args, **kwargs):
self.spatial_filter.imshow(*args, **kwargs)
[docs]
def to_dict(self):
return {'class': (__name__, self.__class__.__name__),
'spatial_filter': self.spatial_filter.to_dict(),
'temporal_filter': self.temporal_filter.to_dict(),
'amplitude': self.amplitude}
[docs]
class SpectroTemporalFilter(object):
def __init__(self, spectrotemporal_filter, amplitude=1.):
self.spectrotemporal_filter = spectrotemporal_filter
self.amplitude = amplitude
[docs]
def get_spectrotemporal_kernel(self, freq_range, t_range, threshold=0, reverse=False):
spectrotemporal_kernel = self.spectrotemporal_filter.get_kernel(freq_range, t_range, threshold_rel = threshold)
if reverse:
t_range = -np.array(t_range)[::-1]
t_inds = -1 * spectrotemporal_kernel.col_inds - 1
else:
t_range = np.array(t_range)
t_inds = spectrotemporal_kernel.col_inds
# Keep it 3D to keep downstream consistent
# spectrotemporal_kernel.col_range is put in t_range
# frequencies are put in row_range
kernel = Kernel3D(spectrotemporal_kernel.row_range, [0], t_range,
spectrotemporal_kernel.row_inds, np.zeros_like(spectrotemporal_kernel.row_inds),
t_inds, spectrotemporal_kernel.kernel)
#kernel.apply_threshold(threshold)
#kernel.kernel *= self.amplitude
return kernel
[docs]
def t_slice(self, t, *args, **kwargs):
k = self.get_spatiotemporal_kernel(*args, **kwargs)
return k.t_slice(t)
[docs]
def show_temporal_filter(self, *args, **kwargs):
self.temporal_filter.imshow(*args, **kwargs)
[docs]
def show_spatial_filter(self, *args, **kwargs):
self.spatial_filter.imshow(*args, **kwargs)
[docs]
def to_dict(self):
return {'class': (__name__, self.__class__.__name__),
'spatial_filter': self.spatial_filter.to_dict(),
'temporal_filter': self.temporal_filter.to_dict(),
'amplitude': self.amplitude}