Source code for allensdk.brain_observatory.observatory_plots

# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2017. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.collections import PatchCollection
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import scipy.interpolate as si
from scipy.stats import gaussian_kde
import matplotlib.colorbar as cbar
from mpl_toolkits.axes_grid1 import ImageGrid

import allensdk.brain_observatory.circle_plots as cplots
from contextlib import contextmanager

import numpy as np

SI_RANGE = [ 0, 1.5 ]
P_VALUE_MAX = 0.05
PEAK_DFF_MIN = 3
N_HIST_BINS = 50
STIM_COLOR = "#ccccdd"
STIMULUS_COLOR_MAP = LinearSegmentedColormap.from_list('default',[ [1.0,1.0,1.0,0.0], [.6,.6,.85,1.0] ])
PUPIL_COLOR_MAP = LinearSegmentedColormap.from_list(
    'custom_plasma', [[0.050383, 0.029803, 0.527975],
                      [0.417642, 0.000564, 0.658390],
                      [0.692840, 0.165141, 0.564522],
                      [0.881443, 0.392529, 0.383229],
                      [0.988260, 0.652325, 0.211364],
                      [0.940015, 0.975158, 0.131326]])
EVOKED_COLOR = "#b30000"
SPONTANEOUS_COLOR = "#0000b3"

[docs]def plot_cell_correlation(sig_corrs, labels, colors, scale=15): if len(sig_corrs) > 1: alpha = 1.0 / (len(sig_corrs) + 1) else: alpha = 1.0 ax = plt.gca() ps = [] for sig_corr, color, label in zip(sig_corrs, colors, labels): ax.hist(sig_corr, bins=30, range=[-1,1], histtype='stepfilled', facecolor=(.6,.6,.6,alpha), edgecolor=color, linewidth=1.5, label=label) ax.set_xlabel("signal correlation") ax.set_ylabel("cell count") ax.xaxis.grid(True) leg = ax.legend(loc='upper left', frameon=False) for i, t in enumerate(leg.get_texts()): t.set_color(colors[i]) plt.text(.125, .5, u'\u2014', transform=ax.transAxes, horizontalalignment='center', verticalalignment='center', weight='bold', size='xx-large') plt.text(.875, .5, '+', transform=ax.transAxes, horizontalalignment='center', verticalalignment='center', weight='bold', size='xx-large')
[docs]def population_correlation_scatter(sig_corrs, noise_corrs, labels, colors, scale=15): alpha = max(0.85 - 0.15 * (len(sig_corrs)-1), 0.2) ax = plt.gca() for sig_corr, noise_corr, color, label in zip(sig_corrs, noise_corrs, colors, labels): inds = np.tril_indices(len(sig_corr)) ax.scatter(sig_corr[inds], noise_corr[inds], s=scale, color=color, linewidth=0.5, edgecolor='#333333', label=label, alpha=alpha) ax.set_xlabel("signal correlation") ax.set_ylabel("noise correlation") ax.set_xlim([-1,1]) ax.set_ylim([-1,1]) leg = ax.legend(loc='upper left', frameon=False) for i, t in enumerate(leg.get_texts()): t.set_color(colors[i])
[docs]def plot_mask_outline(mask, ax, color='k'): pim = np.pad(mask, 1, 'constant', constant_values=(0,0)) hedges = np.argwhere(np.diff(pim, axis=0)) vedges = np.argwhere(np.diff(pim, axis=1)) hlines = [ [ [r-.5, c-1.5], [r-.5, c-.5] ] for r,c in hedges ] vlines = [ [ [r-1.5, c-.5], [r-.5, c-.5] ] for r,c in vedges ] for p1,p2 in hlines + vlines: ax.add_line(mlines.Line2D([ p1[1], p2[1] ], [ p1[0], p2[0] ], linewidth=3, color=color, clip_on=False))
[docs]class DimensionPatchHandler(object): def __init__(self, vals, start_color, end_color, *args, **kwargs): super(DimensionPatchHandler, self).__init__(*args, **kwargs) self.vals = vals self.start_color = start_color self.end_color = end_color
[docs] def legend_artist(self, legend, orig_handle, fontsize, handlebox): x0, y0 = handlebox.xdescent, handlebox.ydescent width, height = handlebox.width, handlebox.height num_vals = len(self.vals) sub_width = float(width) / num_vals x = x0 for i in range(len(self.vals)): rgb = self.dim_color(i) r = mpatches.Rectangle((x+i*sub_width, y0), sub_width, y0+height, facecolor=rgb, linewidth=0) r.set_clip_on(False) handlebox.add_artist(r) return r
[docs] def dim_color(self, index): rgb1 = np.array(mcolors.colorConverter.to_rgb(self.start_color)) rgb2 = np.array(mcolors.colorConverter.to_rgb(self.end_color)) t = float(index) / (len(self.vals)+1) rgb = t * rgb2 + (1.0 - t) * rgb1 return rgb
[docs]def float_label(n): if isinstance(n, int): return str(n) if n.is_integer(): return str(int(n)) else: return "%.2f" % n
[docs]def plot_representational_similarity(rs, dims=None, dim_labels=None, colors=None, dim_order=None, labels=True): if dim_order is not None: rsr = np.arange(len(rs)).reshape(*map(len,dims)) rsrt = rsr.transpose(dim_order) ri = rsrt.flatten() rs = rs[ri,:][:,ri] dims = np.array(dims)[dim_order] colors = np.array(colors)[dim_order] dim_labels = np.array(dim_labels)[dim_order] # force the color map to be centered at zero clim = np.nanpercentile(rs, [5.0,95.0], axis=None) vrange = max(abs(clim[0]), abs(clim[1])) rs = rs.copy() np.fill_diagonal(rs, np.nan) if labels: grid = ImageGrid(plt.gcf(), 111, nrows_ncols=(1,1), cbar_location="right", cbar_mode="single", cbar_size="7%", cbar_pad=0.05) for ax in grid: pass else: ax = plt.gca() im = ax.imshow(rs, interpolation='nearest', cmap='RdBu_r', vmin=-vrange, vmax=vrange) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_xticks([]) ax.set_yticks([]) if labels: cbar = ax.cax.colorbar(im) cbar.set_label_text('stimulus correlation') if dims is not None: dim_labels = ["%s(%s)" % (dim_labels[i],', '.join(map(float_label, dims[i].tolist()))) for i in range(len(dims)) ] dim_handlers = [ DimensionPatchHandler(dims[i], colors[i], 'w') for i in range(len(dims)) ] n = len(rs) for cell_i in range(n): idx = np.unravel_index(cell_i, map(len, dims)) start = -(len(dims))*2 width = 1.8 for dim_i, color in enumerate(colors): v_i = idx[dim_i] rgb = dim_handlers[dim_i].dim_color(v_i) r = mpatches.Rectangle((start + dim_i * width, cell_i-.5), width, 1.2, facecolor=rgb, linewidth=0) r.set_clip_on(False) ax.add_patch(r) r = mpatches.Rectangle((cell_i-.5, start + dim_i * width), 1.2, width, facecolor=rgb, linewidth=0) r.set_clip_on(False) ax.add_patch(r) if labels: patches = [ mpatches.Patch(label=dim_labels[i]) for i in range(len(dims)) ] ax.legend(handles=patches, handler_map=dict(zip(patches,dim_handlers)), loc='upper left', bbox_to_anchor=(0,0), ncol=2, fontsize=9, frameon=False) if labels: plt.subplots_adjust(left=0.07, right=.88, wspace=0.0, hspace=0.0)
[docs]def plot_condition_histogram(vals, bins, color=STIM_COLOR): plt.grid() if len(vals) > 1: vals = [np.array(vals).flatten()] # matplotlib >= 2.1 needs this if len(vals) > 0: n, hbins, patches = plt.hist(vals, bins=np.arange(len(bins)+1)+1, align='left', normed=False, rwidth=.8, color=color, zorder=3) else: hbins = np.arange(len(bins)+1)+1 plt.xticks(hbins[:-1], np.round(bins, 2))
[docs]def plot_selectivity_cumulative_histogram(sis, xlabel, si_range=SI_RANGE, n_hist_bins=N_HIST_BINS, color=STIM_COLOR): if len(sis) > 1: sis = [np.array(sis).flatten()] # matplotlib >= 2.1 needs this bins = np.linspace(si_range[0], si_range[1], n_hist_bins) yticks = np.linspace(0,1,5) xticks = np.linspace(si_range[0], si_range[1], 4) yscale = 1.0 # this is for normalizing to total # cells, not just significant cells # yscale = float(num_cells) / len(osis) # orientation selectivity cumulative histogram if len(sis) > 0: n, bins, patches = plt.hist(sis, normed=True, bins=bins, cumulative=True, histtype='stepfilled', color=color) plt.xlim(si_range) plt.ylim([0,yscale]) plt.yticks(yticks*yscale, yticks) plt.xticks(xticks) plt.xlabel(xlabel) plt.ylabel("fraction of cells") plt.grid()
[docs]def plot_radial_histogram(angles, counts, all_angles=None, include_labels=False, offset=180.0, direction=-1, closed=False, color=STIM_COLOR): if all_angles is None: if len(angles) < 2: all_angles = np.linspace(0, 315, 8) else: all_angles = angles dth = (all_angles[1] - all_angles[0]) * 0.5 if len(counts) == 0: max_count = 1 else: max_count = max(counts) wedges = [] for count, angle in zip(counts, angles): angle = angle*direction + offset wedge = mpatches.Wedge((0,0), count, angle-dth, angle+dth) wedges.append(wedge) wedge_coll = PatchCollection(wedges) wedge_coll.set_facecolor(color) wedge_coll.set_zorder(2) angles_rad = (all_angles*direction + offset)*np.pi/180.0 if closed: border_coll = cplots.radial_circles([max_count]) else: border_coll = cplots.radial_arcs([max_count], min(angles_rad), max(angles_rad)) border_coll.set_facecolor((0,0,0,0)) border_coll.set_zorder(1) line_coll = cplots.angle_lines(angles_rad, 0, max_count) line_coll.set_edgecolor((0,0,0,1)) line_coll.set_linestyle(":") line_coll.set_zorder(1) ax = plt.gca() ax.add_collection(wedge_coll) ax.add_collection(border_coll) ax.add_collection(line_coll) if include_labels: cplots.add_angle_labels(ax, angles_rad, all_angles.astype(int), max_count, (0,0,0,1), offset=max_count*0.1) ax.set(xlim=(-max_count*1.2, max_count*1.2), ylim=(-max_count*1.2, max_count*1.2), aspect=1.0) else: ax.set(xlim=(-max_count*1.05, max_count*1.05), ylim=(-max_count*1.05, max_count*1.05), aspect=1.0)
[docs]def plot_time_to_peak(msrs, ttps, t_start, t_end, stim_start, stim_end, cmap): plt.plot(ttps, np.arange(msrs.shape[0],0,-1)-0.5, color='black') if msrs.shape[0] > 0: plt.imshow(msrs, cmap=cmap, clim=[0,3], aspect=float((t_end-t_start) / msrs.shape[0]), # float to get rid of MPL error extent=[t_start, t_end, 0, msrs.shape[0]], interpolation='nearest') plt.ylim([0,msrs.shape[0]]) else: plt.ylim([0, 1]) plt.xlim([t_start, t_end]) plt.axvline(stim_start, linestyle=':', color='black') plt.axvline(stim_end, linestyle=':', color='black') xticks = np.array([ t_start, stim_start, stim_end, t_end ]) plt.xticks(xticks, np.round(xticks - stim_start, 2)) plt.xlabel("time from stimulus start (s)") yticks, _ = plt.yticks() plt.ylabel("cell number")
[docs]@contextmanager def figure_in_px(w, h, file_name, dpi=96.0, transparent=False): fig = plt.figure(figsize=(w/dpi, h/dpi), dpi=dpi) yield fig plt.savefig(file_name, dpi=dpi, transparent=transparent) plt.close()
[docs]def finalize_no_axes(pad=0.0): plt.axis('off') plt.subplots_adjust(left=pad, right=1.0-pad, bottom=pad, top=1.0-pad, wspace=0.0, hspace=0.0)
[docs]def finalize_with_axes(pad=.3): plt.tight_layout(pad=pad)
[docs]def finalize_no_labels(pad=.3, legend=False): ax = plt.gca() ax.set_xlabel("") ax.set_ylabel("") ax.set_xticklabels([]) ax.set_yticklabels([]) if not legend and ax.legend_ is not None: ax.legend_.remove() plt.tight_layout(pad=pad)
[docs]def plot_combined_speed(binned_resp_vis, binned_dx_vis, binned_resp_sp, binned_dx_sp, evoked_color, spont_color): ax = plt.gca() num_bins = max(binned_dx_vis.shape[0], binned_dx_sp.shape[0]) plot_speed(binned_resp_vis, binned_dx_vis, num_bins, evoked_color) plot_speed(binned_resp_sp, binned_dx_sp, num_bins, spont_color) xmin = min(binned_dx_vis[:,0].min(), binned_dx_sp[:,0].min()) xmax = max(binned_dx_vis[:,0].max(), binned_dx_sp[:,0].max()) ymin = min(binned_resp_vis[:,0].min(), binned_resp_sp[:,0].min()) ymax = max(binned_resp_vis[:,0].max(), binned_resp_sp[:,0].max()) xpadding = (xmax-xmin)*.05 ypadding = (ymax-ymin)*.20 ax.set_xlim([xmin - xpadding, xmax + xpadding]) ax.set_ylim([ymin - ypadding, ymax + ypadding])
[docs]def plot_speed(binned_resp, binned_dx, num_bins, color): ax = plt.gca() # plot the zero bin as a dot with whiskers ax.errorbar([ binned_dx[0,0] ], [ binned_resp[0,0] ], yerr=[ binned_resp[0,1] ], fmt='o', color=color) # if there's only one bin, drop out if len(binned_dx[:,0]) <= 1: return f = si.interp1d(binned_dx[:,0], binned_resp[:,0]) x = np.linspace(min(binned_dx[:,0]), max(binned_dx[:,0]), num=num_bins, endpoint=True) y = f(x) f_up = si.interp1d(binned_dx[:,0], binned_resp[:,0] + binned_resp[:,1]) y_up = f_up(x) f_down = si.interp1d(binned_dx[:,0], binned_resp[:,0] - binned_resp[:,1]) y_down = f_down(x) ax.plot(x, y, color=color) ax.fill_between(x, y_down, y_up, facecolor=color, alpha=0.1)
[docs]def plot_receptive_field(rf, color_map=None, clim=None, mask=None, outline_color='#cccccc', scalebar=True): if mask is not None: rf = np.ma.array(rf, mask=~mask) if clim is None: clim = np.nanpercentile(rf, [1.0,99.0], axis=None) plt.imshow(rf, interpolation='nearest', cmap=color_map, clim=clim, origin='bottom') if mask is not None: plot_mask_outline(mask, plt.gca(), outline_color) if scalebar: scale_dims = np.array([ 28.0, 16.0 ]) scale_p = [ 26.8, 14.8 ] text_p = [ scale_p[0]+0.5, scale_p[1]-0.5 ] ax = plt.gca() ax.add_patch(mpatches.Rectangle(scale_p / scale_dims, 1.0/scale_dims[0], 1.0/scale_dims[1], facecolor='w', transform=ax.transAxes, linewidth=1.0, edgecolor=outline_color)) plt.text(text_p[0] / scale_dims[0], text_p[1] / scale_dims[1], "4deg", horizontalalignment='center', verticalalignment='center', transform=ax.transAxes)
[docs]def plot_pupil_location(xy_deg, s=1, c=None, cmap=PUPIL_COLOR_MAP, edgecolor='', include_labels=True): if c is None: xy_deg = xy_deg[~np.isnan(xy_deg).any(axis=1)] c = gaussian_kde(xy_deg.T)(xy_deg.T) plt.scatter(xy_deg[:,0], xy_deg[:,1], s=s, c=c, cmap=cmap, edgecolor=edgecolor) plt.xlim(-70, 70) plt.ylim(-70, 70) if include_labels: plt.xlabel("azimuth (degrees)") plt.ylabel("altitude (degrees)")