import matplotlib as mpl
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
[docs]
def find_l_r_in_t_range(t_range, t):
    """Locate ``t`` on the grid ``t_range``.

    Consecutive pairs ``(t_range[i], t_range[i + 1])`` are scanned in index
    order and the first pair that brackets ``t`` decides the result.

    Parameters
    ----------
    t_range : sequence of float
        Grid of time points.
    t : float
        Time to locate.

    Returns
    -------
    tuple or None
        ``(i,)`` when ``t`` coincides with grid point ``i`` (the product
        ``(t_range[i] - t) * (t_range[i + 1] - t)`` is below ``1e-16`` in
        magnitude, and ``t_range[i]`` is the closer endpoint; ``(i + 1,)``
        when ``t_range[i + 1]`` is closer), ``(i, i + 1)`` when ``t`` lies
        strictly between the two points, or ``None`` when no pair brackets
        ``t`` (``t`` outside the grid, or fewer than two grid points).
    """
    tol = 1e-16
    for tl in range(len(t_range) - 1):
        tr = tl + 1
        dl = t_range[tl] - t
        dr = t_range[tr] - t
        test_val = dl * dr
        if np.abs(test_val) < tol:
            # t is (numerically) on one of the endpoints: return the closer
            # one.  Testing only ``|dl| < tol`` returned ``tr`` whenever t was
            # a hair off t_range[tl] (e.g. t = t_range[0] + 1e-15).
            return (tl,) if np.abs(dl) <= np.abs(dr) else (tr,)
        if test_val < 0:
            # Opposite signs: t lies strictly inside (t_range[tl], t_range[tr]).
            return tl, tr
    return None
[docs]
def get_contour(X, Y, Z, c):
    """Return the x/y coordinates of the first traced contour of ``Z`` at level ``c``.

    A contour set is built with ``plt.contour(X, Y, Z)`` (as a pyplot call it
    draws onto the current axes as a side effect) and its ``trace(c)`` result
    is split in half; the first half is taken as the list of segments
    (presumably the second half holds path codes, as in the legacy
    ``_cntr`` trace layout -- verify).

    NOTE(review): ``ContourSet.trace`` is not part of the public API of recent
    matplotlib releases; confirm it exists in the pinned version.

    Returns
    -------
    tuple
        ``(xs, ys)``: columns 0 and 1 of the first segment, or ``([], [])``
        when nothing was traced at level ``c``.
    """
    traced = plt.contour(X, Y, Z).trace(c)
    n_segments = len(traced) // 2
    if not n_segments:
        return [], []
    vertices = traced[0]  # only the first segment is used
    return vertices[:, 0], vertices[:, 1]
[docs]
def plot_single_contour(ax, x_contour, y_contour, t, color):
    """Draw one contour line at the constant value ``t`` on the second axis.

    ``t`` is broadcast to one value per vertex and the call is
    ``ax.plot(x_contour, [t, t, ...], y_contour, zdir='z', color=color)``.

    Parameters
    ----------
    ax : Axes
        Axes whose ``plot`` accepts a third coordinate array and ``zdir``
        (presumably an mplot3d ``Axes3D`` -- confirm at call sites).
    x_contour, y_contour : array_like
        Contour vertex coordinates.
    t : float
        Constant value placed on the second (y) coordinate.
    color : color spec
        Passed straight through to ``ax.plot``.
    """
    ax.plot(x_contour, np.zeros_like(x_contour) + t, y_contour,
            zdir='z', color=color)
[docs]
class Kernel1D(object):
    """Sparse 1D (temporal) kernel: retained samples of a kernel on a time grid.

    Attributes
    ----------
    t_range : numpy.ndarray
        Time grid.  With ``reverse=True`` this is the negated, reversed input grid.
    t_inds : numpy.ndarray
        Indices into ``t_range`` of the retained samples (negative indices
        when built with ``reverse=True``).
    kernel : numpy.ndarray
        Retained sample values, aligned with ``t_inds``.
    """

    def __init__(self, t_range, kernel_array, threshold=0., reverse=False):
        """Keep the samples of ``kernel_array`` whose magnitude exceeds ``threshold``.

        Parameters
        ----------
        t_range : sequence of float
            Time grid; must have the same length as ``kernel_array``.
        kernel_array : sequence of float
            Kernel values sampled on ``t_range``.
        threshold : float, optional
            Samples with ``|value| <= threshold`` are dropped.
        reverse : bool, optional
            If True, store the grid as ``-t_range[::-1]`` and express each kept
            sample index ``i`` as the negative index ``-(i + 1)`` into that
            grid, so ``self.t_range[self.t_inds] == -t_range[kept]``.

        Raises
        ------
        AssertionError
            If ``t_range`` and ``kernel_array`` differ in length.
        """
        assert len(t_range) == len(kernel_array)
        kernel_array = np.array(kernel_array)
        inds_to_keep = np.where(np.abs(kernel_array) > threshold)
        if reverse:
            self.t_range = -np.array(t_range)[::-1]
            # Input sample i lives at -t_range[i], i.e. position -(i + 1) of the
            # reversed grid.  (Earlier off-by-one fixed here, nhc 14 Apr '17; the
            # same change was made in the cursor evaluation code.)  This equals
            # the former ``(max - i) - max - 1`` but no longer calls ``.max()``,
            # which raised on an all-below-threshold kernel.
            self.t_inds = -inds_to_keep[0] - 1
        else:
            self.t_range = np.array(t_range)
            self.t_inds = inds_to_keep[0]
        self.kernel = kernel_array[inds_to_keep]
        assert len(self.t_inds) == len(self.kernel)

    def rescale(self):
        """Divide the kernel in place by ``|sum|``; no-op when the sum is 0."""
        total = np.abs(self.kernel.sum())
        if total != 0:
            self.kernel /= total

    def normalize(self):
        """Divide the kernel in place by ``|sum|`` (no zero-sum guard, unlike ``rescale``)."""
        self.kernel /= np.abs(self.kernel.sum())

    def __len__(self):
        """Number of retained samples."""
        return len(self.kernel)

    def imshow(self, ax=None, show=True, save_file_name=None, ylim=None, xlim=None, color='b', reverse=True):
        """Plot the kernel values against their time points.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Target axes; a new figure/axes is created when None.
        show : bool, optional
            Call ``plt.show()`` at the end.
        save_file_name : str, optional
            If given, the axes' figure is saved there with a transparent background.
        ylim, xlim : tuple, optional
            Axis limits; without ``xlim`` the x-range spans the first and last time point.
        color : str, optional
            Format string passed to ``ax.plot``.
        reverse : bool, optional
            If True (default), the kernel values are plotted in reverse order
            against the unchanged time points.

        Returns
        -------
        tuple
            ``(ax, (t_vals, self.kernel))`` -- the stored (un-reversed) values.
        """
        if ax is None:
            _, ax = plt.subplots(1, 1)
        t_vals = self.t_range[self.t_inds]
        kernel_data = self.kernel[::-1] if reverse else self.kernel
        ax.plot(t_vals, kernel_data, color)
        ax.set_xlabel('Time (Seconds)')
        if ylim is not None:
            ax.set_ylim(ylim)
        if xlim is not None:
            ax.set_xlim(xlim)
        else:
            a, b = t_vals[0], t_vals[-1]
            ax.set_xlim(min(a, b), max(a, b))
        if save_file_name is not None:
            # Axes objects have no ``savefig``; save through the parent figure.
            ax.figure.savefig(save_file_name, transparent=True)
        if show:
            plt.show()
        return ax, (t_vals, self.kernel)

    def full(self, truncate_t=True):
        """Return the kernel as a dense array over ``t_range``.

        Parameters
        ----------
        truncate_t : bool, optional
            If True (default), drop the entries before the first non-zero
            value.  An all-zero kernel then yields an empty array (this used to
            raise ``ValueError`` from ``.min()`` on an empty selection).
        """
        data = np.zeros(len(self.t_range))
        data[self.t_inds] = self.kernel
        if not truncate_t:
            return data
        nonzero = np.flatnonzero(np.abs(data) > 0)
        if nonzero.size == 0:
            return data[:0]
        return data[nonzero[0]:]
[docs]
class Kernel2D(object):
    """Sparse 2D kernel stored as parallel arrays of (row index, column index, value).

    Attributes
    ----------
    row_range, col_range : numpy.ndarray
        Coordinate grids; ``full()`` yields shape ``(len(row_range), len(col_range))``.
    row_inds, col_inds : numpy.ndarray
        Per-entry indices into ``row_range`` and ``col_range``.
    kernel : numpy.ndarray
        Per-entry values, aligned with ``row_inds`` / ``col_inds``.
    """

    def __init__(self, row_range, col_range, row_inds, col_inds, kernel):
        """Store the grids and sparse entries (each converted with ``np.array``).

        Raises
        ------
        AssertionError
            If ``row_inds``, ``col_inds`` and ``kernel`` differ in length.
        """
        self.col_range = np.array(col_range)
        self.row_range = np.array(row_range)
        self.row_inds = np.array(row_inds)
        self.col_inds = np.array(col_inds)
        self.kernel = np.array(kernel)
        assert len(self.row_inds) == len(self.col_inds)
        assert len(self.row_inds) == len(self.kernel)

    def __len__(self):
        """Number of stored entries."""
        return len(self.kernel)

    def rescale(self):
        """Divide the values in place by ``|sum|``; no-op when the sum is 0."""
        total = np.abs(self.kernel.sum())
        if total != 0:
            self.kernel /= total

    def normalize(self):
        """Divide the values in place by ``|sum|`` (no zero-sum guard, unlike ``rescale``)."""
        self.kernel /= np.abs(self.kernel.sum())

    def normalize2(self, remove_offset=True):
        """Scale the values in place so their absolute values sum to 1.

        Better suited than ``normalize`` to kernels that are not all positive.

        Parameters
        ----------
        remove_offset : bool, optional
            If True (default), first subtract the mean so the values sum to 0.
        """
        if remove_offset:
            self.kernel -= self.kernel.mean()  # set amplitude offset to 0
        size = np.sum(np.abs(self.kernel))
        self.kernel /= size  # normalize overall size and maximum output

    @classmethod
    def from_dense(cls, row_range, col_range, kernel_array, threshold=0.):
        """Build a kernel from a dense array, cropped to its above-threshold bounding box.

        ``kernel_array`` is indexed ``kernel_array[col, row]``: axis 0 indexes
        ``col_range`` and axis 1 indexes ``row_range`` (so ``full()`` places
        ``kernel_array[i, j]`` at ``[j, i]``).

        Every value inside the smallest axis-aligned rectangle containing all
        entries with ``|value| > threshold`` is kept -- including
        below-threshold values inside that rectangle.  When no entry exceeds
        the threshold, an empty kernel is returned.
        """
        col_range = np.array(col_range).copy()
        row_range = np.array(row_range).copy()
        kernel_array = np.array(kernel_array).copy()
        above_thresh = np.abs(kernel_array) > threshold
        if not above_thresh.any():
            # The former ``len(np.where(above_thresh)) == 1`` test is never true
            # for a 2D array, so this case used to keep the whole grid.
            empty = np.array([], dtype=int)
            return cls(row_range, col_range, empty, empty,
                       np.array([], dtype=kernel_array.dtype))
        # First/last axis-0 and axis-1 positions holding an above-threshold value.
        hit0 = np.flatnonzero(above_thresh.any(axis=1))
        hit1 = np.flatnonzero(above_thresh.any(axis=0))
        col_inds, row_inds = [v.ravel() for v in np.meshgrid(
            np.arange(hit0[0], hit0[-1] + 1), np.arange(hit1[0], hit1[-1] + 1), indexing='ij')]
        kernel = kernel_array[col_inds, row_inds]
        return cls(row_range, col_range, row_inds, col_inds, kernel)

    @classmethod
    def copy(cls, instance):
        """Return an independent copy of ``instance`` (all arrays copied)."""
        return cls(instance.row_range.copy(),
                   instance.col_range.copy(),
                   instance.row_inds.copy(),
                   instance.col_inds.copy(),
                   instance.kernel.copy())

    def __mul__(self, constant):
        """Return a copy whose values are multiplied by ``constant``."""
        new_copy = Kernel2D.copy(self)
        new_copy.kernel *= constant
        return new_copy

    def __add__(self, other):
        """Sum two kernels on the same grid; values at equal (row, col) indices add.

        If ``other`` is empty, ``self`` itself (not a copy) is returned.

        Raises
        ------
        Exception
            If ``row_range`` or ``col_range`` differ between the operands.
        """
        if len(other) == 0:
            return self
        try:
            np.testing.assert_almost_equal(self.row_range, other.row_range)
            np.testing.assert_almost_equal(self.col_range, other.col_range)
        except AssertionError as err:
            # np.testing reports value/shape mismatches as AssertionError; other
            # errors (e.g. ``other`` lacking the attributes) now surface as-is.
            raise Exception('Kernels must exist on same grid to be added') from err
        combined = {}
        for source in (self, other):
            for key, value in zip(zip(source.row_inds, source.col_inds), source.kernel):
                combined[key] = combined.get(key, 0) + value
        keys, values = zip(*combined.items())
        row_inds, col_inds = zip(*keys)
        return Kernel2D(self.row_range.copy(), self.col_range.copy(),
                        np.array(row_inds), np.array(col_inds), np.array(values))

    def apply_threshold(self, threshold):
        """Drop, in place, every entry with ``|value| <= threshold``."""
        inds_to_keep = np.where(np.abs(self.kernel) > threshold)
        self.row_inds = self.row_inds[inds_to_keep]
        self.col_inds = self.col_inds[inds_to_keep]
        self.kernel = self.kernel[inds_to_keep]

    def full(self, truncate_col=False):
        """Return the dense array of shape ``(len(row_range), len(col_range))``.

        Parameters
        ----------
        truncate_col : bool, optional
            If True (for spectrotemporal receptive fields where col is the time
            dimension), keep columns up to and including the last stored
            column.  The former ``[:, :ind_max]`` slice dropped that column.
        """
        data = np.zeros((len(self.row_range), len(self.col_range)))
        data[self.row_inds, self.col_inds] = self.kernel
        if truncate_col:
            return data[:, :np.max(self.col_inds) + 1]
        return data

    def imshow(self, ax=None, show=True, save_file_name=None, clim=None, colorbar=True, truncate_col=False, xlabel=None,
               ylabel=None):
        """Show the dense kernel as an image (``origin='lower'``, no interpolation).

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Target axes; a new figure/axes is created when None.
        show : bool, optional
            Call ``plt.show()`` at the end.
        save_file_name : str, optional
            If given, the axes' figure is saved there with a transparent background.
        clim : tuple, optional
            Color limits; when given, the image keeps matplotlib's default
            aspect instead of ``aspect='auto'``.
        colorbar : bool, optional
            Attach a colorbar in a narrow axes to the right.
        truncate_col : bool, optional
            Passed to ``full``; the x extent then ends at the last stored column.
        xlabel, ylabel : str, optional
            Axis labels.

        Returns
        -------
        tuple
            ``(ax, data)`` where ``data`` is the plotted dense array.
        """
        from mpl_toolkits.axes_grid1 import make_axes_locatable
        if ax is None:
            _, ax = plt.subplots(1, 1)
        if colorbar:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", size="5%", pad=0.05)
        data = self.full(truncate_col=truncate_col)
        if truncate_col:
            col_max = self.col_range[np.max(self.col_inds)]
        else:
            col_max = self.col_range[-1]
        extent = (self.col_range[0], col_max,
                  np.squeeze(self.row_range[0]), np.squeeze(self.row_range[-1]))
        if clim is not None:
            im = ax.imshow(data, extent=extent, origin='lower', clim=clim, interpolation='none')
        else:
            im = ax.imshow(data, extent=extent, origin='lower', interpolation='none',
                           aspect='auto')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if colorbar:
            plt.colorbar(im, cax=cax)
        if save_file_name is not None:
            # Save the figure that owns ``ax``, not whichever figure is current.
            ax.figure.savefig(save_file_name, transparent=True)
        if show:
            plt.show()
        return ax, data
[docs]
class Kernel3D(object):
    """Sparse spatiotemporal kernel stored as parallel arrays of (row, col, t) indices and values.

    Attributes
    ----------
    row_range, col_range, t_range : numpy.ndarray
        Coordinate grids; ``full()`` yields shape
        ``(len(t_range), len(row_range), len(col_range))``.
    row_inds, col_inds, t_inds : numpy.ndarray
        Per-entry indices into the three grids.
    kernel : numpy.ndarray
        Per-entry values.
    """

    def __init__(self, row_range, col_range, t_range, row_inds, col_inds, t_inds, kernel):
        """Store the grids and sparse entries (each converted with ``np.array``).

        Raises
        ------
        AssertionError
            If the index arrays and ``kernel`` differ in length.
        """
        self.col_range = np.array(col_range)
        self.row_range = np.array(row_range)
        self.t_range = np.array(t_range)
        self.col_inds = np.array(col_inds)
        self.row_inds = np.array(row_inds)
        self.t_inds = np.array(t_inds)
        self.kernel = np.array(kernel)
        assert len(self.row_inds) == len(self.col_inds)
        assert len(self.row_inds) == len(self.t_inds)
        assert len(self.row_inds) == len(self.kernel)

    def __len__(self):
        """Number of stored entries."""
        return len(self.kernel)

    @classmethod
    def copy(cls, instance):
        """Return an independent copy of ``instance`` (all arrays copied)."""
        return cls(instance.row_range.copy(),
                   instance.col_range.copy(),
                   instance.t_range.copy(),
                   instance.row_inds.copy(),
                   instance.col_inds.copy(),
                   instance.t_inds.copy(),
                   instance.kernel.copy())

    def rescale(self):
        """Divide the values in place by ``|sum|``; no-op when the sum is 0."""
        total = np.abs(self.kernel.sum())
        if total != 0:
            self.kernel /= total

    def normalize(self):
        """Divide the values in place by ``|sum|`` (no zero-sum guard, unlike ``rescale``)."""
        self.kernel /= np.abs(self.kernel.sum())

    def apply_threshold(self, threshold):
        """Drop, in place, every entry with ``|value| <= threshold``."""
        inds_to_keep = np.where(np.abs(self.kernel) > threshold)
        self.row_inds = self.row_inds[inds_to_keep]
        self.col_inds = self.col_inds[inds_to_keep]
        self.t_inds = self.t_inds[inds_to_keep]
        self.kernel = self.kernel[inds_to_keep]

    def __add__(self, other):
        """Sum two kernels; values at equal (row, col, t) indices add.

        A grid that is empty on either operand is treated as unspecified: it
        is not compared, and the result takes the other operand's grid.  If
        ``other`` is empty, ``self`` itself (not a copy) is returned.

        Raises
        ------
        Exception
            If a grid specified on both operands differs.
        """
        if len(other) == 0:
            return self
        try:
            if not (len(self.row_range) == 0 or len(other.row_range) == 0):
                np.testing.assert_almost_equal(self.row_range, other.row_range)
            if not (len(self.col_range) == 0 or len(other.col_range) == 0):
                np.testing.assert_almost_equal(self.col_range, other.col_range)
            if not (len(self.t_range) == 0 or len(other.t_range) == 0):
                np.testing.assert_almost_equal(self.t_range, other.t_range)
        except AssertionError as err:
            # np.testing reports value/shape mismatches as AssertionError; other
            # errors (e.g. ``other`` lacking the attributes) now surface as-is.
            raise Exception('Kernels must exist on same grid to be added') from err
        row_range = (other.row_range if len(self.row_range) == 0 else self.row_range).copy()
        col_range = (other.col_range if len(self.col_range) == 0 else self.col_range).copy()
        t_range = (other.t_range if len(self.t_range) == 0 else self.t_range).copy()
        combined = {}
        for source in (self, other):
            for key, value in zip(zip(source.row_inds, source.col_inds, source.t_inds), source.kernel):
                combined[key] = combined.get(key, 0) + value
        keys, values = zip(*combined.items())
        row_inds, col_inds, t_inds = zip(*keys)
        return Kernel3D(row_range, col_range, t_range,
                        np.array(row_inds), np.array(col_inds), np.array(t_inds), np.array(values))

    def __mul__(self, constant):
        """Return a copy whose values are multiplied by ``constant``."""
        new_copy = Kernel3D.copy(self)
        new_copy.kernel *= constant
        return new_copy

    def _slice_at(self, t_ind):
        """Return the entries stored at grid time ``t_range[t_ind]`` as a Kernel2D."""
        # Match on time values rather than raw indices so entries whose t_inds
        # are negative (e.g. from a reverse=True Kernel1D) are found too.
        hits = np.where(self.t_range[self.t_inds] == self.t_range[t_ind])
        return Kernel2D(self.row_range, self.col_range,
                        self.row_inds[hits], self.col_inds[hits], self.kernel[hits])

    def t_slice(self, t):
        """Return the spatial kernel at time ``t`` as a Kernel2D.

        At a grid time the stored slice is returned; between two grid times
        the neighbouring slices are linearly interpolated.  Returns None when
        ``find_l_r_in_t_range`` cannot bracket ``t`` (e.g. ``t`` outside
        ``t_range``).
        """
        ind_list = find_l_r_in_t_range(self.t_range, t)
        if ind_list is None:
            return None
        if len(ind_list) == 1:
            return self._slice_at(ind_list[0])
        t_ind_l, t_ind_r = ind_list
        t_l, t_r = self.t_range[t_ind_l], self.t_range[t_ind_r]
        kl = self._slice_at(t_ind_l)
        kr = self._slice_at(t_ind_r)
        wa, wb = (1-(t-t_l)/(t_r-t_l)), (1-(t_r-t)/(t_r-t_l))
        return kl*wa + kr*wb

    def full(self, truncate_t=True):
        """Return the dense array of shape ``(len(t_range), len(row_range), len(col_range))``.

        Parameters
        ----------
        truncate_t : bool, optional
            If True (default), drop the leading time steps before the first one
            holding a non-zero value.  An all-zero kernel then yields an array
            with zero time steps (this used to raise ``ValueError``).
        """
        data = np.zeros((len(self.t_range), len(self.row_range), len(self.col_range)))
        data[self.t_inds, self.row_inds, self.col_inds] = self.kernel
        if not truncate_t:
            return data
        t_hits = np.nonzero(np.abs(data) > 0)[0]
        if t_hits.size == 0:
            return data[:0]
        first_t = t_hits.min()
        return data[first_t:, :, :]

    def imshow(self, ax=None, t_range=None, cmap=cm.bwr, N=10, show=True, save_file_name=None, kvals=None):
        """Draw contour lines of the time slices in a 3D axes.

        Parameters
        ----------
        ax : matplotlib 3D Axes, optional
            Target axes; a new figure with a 3D axes is created when None.
        t_range : sequence of float, optional
            Times to slice at (default: ``self.t_range``); times for which
            ``t_slice`` returns None are skipped.
        cmap : Colormap, optional
            Colormap over the symmetric range ``[-m, m]``, where ``m`` is the
            largest slice magnitude.
        N : int, optional
            Number of contour levels when ``kvals`` is not given.
        show : bool, optional
            Call ``plt.show()`` at the end.
        save_file_name : str, optional
            If given, the axes' figure is saved there with a transparent background.
        kvals : sequence of float, optional
            Explicit contour levels.

        Returns
        -------
        tuple
            ``(ax, contour_dict)`` mapping ``(kval, t)`` to the ``(x, y)``
            contour coordinates returned by ``get_contour``.
        """
        if ax is None:
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers '3d' on older matplotlib
            fig = plt.figure()
            # Passing keyword arguments to ``Figure.gca`` (the former
            # ``fig.gca(projection='3d')``) was deprecated in matplotlib 3.4 and
            # removed in 3.6; ``add_subplot`` works on old and new releases.
            ax = fig.add_subplot(111, projection='3d')
        if t_range is None:
            t_range = self.t_range
        slice_list = []
        slice_t_list = []
        for curr_t in t_range:
            curr_slice = self.t_slice(curr_t)
            if curr_slice is not None:
                slice_list.append(curr_slice.full())
                slice_t_list.append(curr_t)
        all_slice_max = max(map(np.max, slice_list))
        all_slice_min = min(map(np.min, slice_list))
        upper_bound = max(np.abs(all_slice_max), np.abs(all_slice_min))
        lower_bound = -upper_bound
        norm = mpl.colors.Normalize(vmin=lower_bound, vmax=upper_bound)
        color_mapper = cm.ScalarMappable(norm=norm, cmap=cmap).to_rgba
        if kvals is None:
            kvals = np.linspace(lower_bound, upper_bound, N)
        X, Y = np.meshgrid(self.row_range, self.col_range)
        contour_dict = {}
        for kval in kvals:
            for t_val, curr_slice in zip(slice_t_list, slice_list):
                x_contour, y_contour = get_contour(Y, X, curr_slice.T, kval)
                contour_dict[kval, t_val] = x_contour, y_contour
                color = color_mapper(kval)
                # Alpha is |kval| relative to the largest magnitude, fading near-zero levels.
                color = color[0], color[1], color[2], np.abs(kval)/upper_bound
                plot_single_contour(ax, x_contour, y_contour, t_val, color)
        ax.set_zlim(self.row_range[0], self.row_range[-1])
        ax.set_ylim(self.t_range[0], self.t_range[-1])
        ax.set_xlim(self.col_range[0], self.col_range[-1])
        if save_file_name is not None:
            # Save the figure that owns ``ax``, not whichever figure is current.
            ax.figure.savefig(save_file_name, transparent=True)
        if show:
            plt.show()
        return ax, contour_dict
[docs]
def merge_spatial_temporal(spatial_kernel, temporal_kernel, threshold=0):
    """Combine a spatial and a temporal kernel into a separable ``Kernel3D``.

    Every (temporal sample, spatial sample) pair becomes one entry with value
    ``temporal_value * spatial_value``; entries with ``|value| <= threshold``
    are then removed with ``Kernel3D.apply_threshold``.

    Parameters
    ----------
    spatial_kernel : Kernel2D-like
        Provides ``kernel``, ``row_inds``, ``col_inds``, ``row_range``,
        ``col_range`` and ``__len__``.
    temporal_kernel : Kernel1D-like
        Provides ``kernel``, ``t_inds``, ``t_range`` and ``__len__``.
    threshold : float, optional
        Magnitude cut-off applied to the merged values.

    Returns
    -------
    Kernel3D
        Entries are ordered time-major: all spatial samples for the first
        temporal sample, then all for the second, and so on (before thresholding).
    """
    n_t = len(temporal_kernel)
    n_s = len(spatial_kernel)
    # Outer product in float64; row-major flattening puts temporal sample i with
    # spatial sample j at position i * n_s + j.
    values = np.outer(np.asarray(temporal_kernel.kernel, dtype=float),
                      np.asarray(spatial_kernel.kernel, dtype=float)).ravel()
    # Spatial indices repeat once per temporal sample; each temporal index repeats
    # once per spatial sample (the same layout the former np.kron calls built).
    # Plain ``int`` replaces the ``np.int`` alias, which NumPy 1.24 removed.
    col_inds = np.tile(np.asarray(spatial_kernel.col_inds), n_t).astype(int)
    row_inds = np.tile(np.asarray(spatial_kernel.row_inds), n_t).astype(int)
    t_inds = np.repeat(np.asarray(temporal_kernel.t_inds), n_s).astype(int)
    kernel = Kernel3D(spatial_kernel.row_range, spatial_kernel.col_range, temporal_kernel.t_range,
                      row_inds, col_inds, t_inds, values)
    kernel.apply_threshold(threshold)
    return kernel