
Visualizing Neuronal Unit Responses

Note: Some of this content is adapted from the Allen SDK Documentation.

After Neuropixels ecephys data have been processed with Kilosort, the identified neuronal units are stored in the units table, in the Units section of the NWB file. The units table contains information about putative neurons identified by Kilosort. This notebook uses that information, together with the stimulus tables in the Intervals section, to examine how these units spike in response to stimuli and to display their waveforms.

Environment Setup

⚠️Note: If running on a new environment, run this cell once and then restart the kernel⚠️

import warnings
warnings.filterwarnings('ignore')

try:
    from databook_utils.dandi_utils import dandi_download_open
except ImportError:
    !git clone https://github.com/AllenInstitute/openscope_databook.git
    %cd openscope_databook
    %pip install -e .
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

Downloading Ecephys File

Change the values below to download the file you're interested in. Set dandiset_id and dandi_filepath to the dandiset id and file path of the file you want. If you're accessing an embargoed dataset, set dandi_api_key to your DANDI API key. If you want to stream the file rather than download it, use dandi_stream_open instead. Check out Streaming an NWB File with remfile for more details.

dandiset_id = "000021"
dandi_filepath = "sub-703279277/sub-703279277_ses-719161530.nwb"
download_loc = "."
dandi_api_key = None
# This can sometimes take a while depending on the size of the file
io = dandi_download_open(dandiset_id, dandi_filepath, download_loc, dandi_api_key=dandi_api_key)
nwb = io.read()
File already exists
Opening file

Extracting Unit Data and Stimulus Data

Below, the Units table is read from the file. Individual units (putative neurons) are identified by the id column. Note that while each id is unique, the ids are not contiguous; some values are skipped. The cells below list the unit properties and show a slice of the units table. More thorough descriptions of units and their properties can be found in Visualizing Unit Quality Metrics.
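Because ids are unique but not contiguous, a unit's id cannot be used directly as a row index. The sketch below shows two ways to translate an id into a row position. The id values are hypothetical, and it assumes the ids have already been read into a NumPy array (for example with `np.asarray(units.id[:])`, which should work for an NWB Units table but is not used elsewhere in this notebook).

```python
import numpy as np

# Hypothetical unit ids: unique, but with gaps between consecutive values
unit_ids = np.array([950911932, 950911986, 950912021, 950930672])

# Option 1: build a dictionary once, then every lookup is O(1)
id_to_row = {int(uid): row for row, uid in enumerate(unit_ids)}

# Option 2: a vectorized search; flatnonzero returns every matching position
matches = np.flatnonzero(unit_ids == 950930672)
row = int(matches[0]) if matches.size else -1  # -1 signals "not found"

print(id_to_row[950930672], row)  # 3 3
```

The dictionary is the better choice when many ids will be looked up; the vectorized search is convenient for a one-off lookup.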

The stimulus data are also read from the NWB file's Intervals section. Stimulus information is stored as a series of tables, one for each type of stimulus shown in the session. One such table is displayed below.

units = nwb.units
units.colnames
('waveform_duration', 'cluster_id', 'peak_channel_id', 'cumulative_drift', 'amplitude_cutoff', 'snr', 'recovery_slope', 'isolation_distance', 'nn_miss_rate', 'silhouette_score', 'velocity_above', 'quality', 'PT_ratio', 'l_ratio', 'velocity_below', 'max_drift', 'isi_violations', 'firing_rate', 'amplitude', 'local_index', 'spread', 'waveform_halfwidth', 'd_prime', 'presence_ratio', 'repolarization_slope', 'nn_hit_rate', 'spike_times', 'spike_amplitudes', 'waveform_mean')
units[:10]
# In older Allen NWBs, the peak channel id is stored directly in the units table
def get_peak_channel_id(units, i):
    return units["peak_channel_id"][i]

# In newer Allen NWBs, peak_channel_id does not exist. Instead, compute peak channel from waveform_mean:
# def get_peak_channel_id(units, i):
#     mean_waveforms = units["waveform_mean"][i]
#     peak_channel_idx = np.argmin(np.min(mean_waveforms, axis=0))
#     detected_electrodes = units["electrodes"][i]
#     return detected_electrodes.index[peak_channel_idx]
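Computing a peak channel from waveform_mean depends on which axis holds the channels. The minimum has to be taken over the time axis so that argmin lands on a channel. The commented-out version above reduces over axis 0, which only selects a channel if the array is stored as (n_samples, n_channels). In this file the array is (n_channels, n_samples), as its shape printed further down shows. The helper below is a hypothetical sketch that makes the orientation explicit, using a small synthetic array.

```python
import numpy as np

def peak_channel_index(mean_waveforms, channel_axis=0):
    """Return the index of the channel with the deepest (most negative) trough.

    mean_waveforms: 2D array holding one mean waveform per channel.
    channel_axis: 0 for an (n_channels, n_samples) array,
                  1 for an (n_samples, n_channels) array.
    """
    time_axis = 1 - channel_axis
    # the minimum over time gives each channel's trough
    troughs = np.min(mean_waveforms, axis=time_axis)
    return int(np.argmin(troughs))

# Synthetic example: 6 channels x 5 samples, channel 4 has the deepest trough
wf = np.zeros((6, 5))
wf[4, 2] = -80.0
wf[3, 2] = -30.0
print(peak_channel_index(wf, channel_axis=0))    # 4
print(peak_channel_index(wf.T, channel_axis=1))  # 4
```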
stimulus_names = nwb.intervals.keys()
print(stimulus_names)
dict_keys(['drifting_gratings_presentations', 'flashes_presentations', 'gabors_presentations', 'invalid_times', 'natural_movie_one_presentations', 'natural_movie_three_presentations', 'natural_scenes_presentations', 'spontaneous_presentations', 'static_gratings_presentations'])
stim_table = nwb.intervals["drifting_gratings_presentations"]
stim_table[:]

Getting Stimulus Epochs

Here, epochs are extracted from the stimulus tables. In this case, an epoch is a continuous period of time during a session in which a particular type of stimulus is shown. The output is a list of epochs, where each epoch is a tuple of four values: the stimulus name, the stimulus block, the start time, and the stop time. Since stimulus information can vary significantly between experiments and NWB files, you may need to tailor the code below to extract epochs from the file you're interested in.
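The core idea is to collapse consecutive rows that share the same block value into one epoch. The sketch below expresses that idea with `itertools.groupby`, which groups *consecutive* equal keys. It runs on plain Python lists. The function name `blocks_to_epochs` and the toy table are hypothetical, and it assumes the block, start, and stop columns have already been read into sequences (for example with `stim_table[block_key][:]`).

```python
from itertools import groupby

def blocks_to_epochs(stim_name, blocks, start_times, stop_times):
    """Collapse runs of consecutive rows with the same block into epochs.

    Returns a list of (stim_name, block, start_time, stop_time) tuples.
    """
    rows = zip(blocks, start_times, stop_times)
    epochs = []
    for block, group in groupby(rows, key=lambda row: row[0]):
        group = list(group)
        # each epoch runs from the first row's start to the last row's stop
        epochs.append((stim_name, block, group[0][1], group[-1][2]))
    return epochs

# Toy table: two separate runs of block 2, with block 5 in between
blocks = [2, 2, 2, 5, 5, 2]
starts = [0.0, 1.0, 2.0, 10.0, 11.0, 20.0]
stops = [1.0, 2.0, 3.0, 11.0, 12.0, 21.0]
result = blocks_to_epochs("gratings", blocks, starts, stops)
print(result)
# [('gratings', 2, 0.0, 3.0), ('gratings', 5, 10.0, 12.0), ('gratings', 2, 20.0, 21.0)]
```

Because groupby only merges adjacent rows, a block value that reappears later in the table starts a new epoch, which matches the row-by-row loop in extract_epochs.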

# extract epoch times from stim table where stimulus rows have a different 'block' than following row
# returns list of epochs, where an epoch is of the form (stimulus name, stimulus block, start time, stop time)
def extract_epochs(stim_name, stim_table, epochs):
    block_key = "stim_block" if "stim_block" in stim_table.colnames else "stimulus_block"

    # if block_key is not present, treat the whole table as one block
    if block_key not in stim_table.colnames:
        epoch_start = stim_table.start_time[0]
        epoch_stop = stim_table.stop_time[-1]
        epochs.append((stim_name, 0, epoch_start, epoch_stop))
        return epochs

    # specify a current epoch stop and start time
    epoch_start = stim_table.start_time[0]
    epoch_stop = stim_table.stop_time[0]

    # for each row, try to extend current epoch stop_time
    for i in range(len(stim_table)):
        this_block = stim_table[block_key][i]
        # if end of table, end the current epoch
        if i+1 >= len(stim_table):
            epochs.append((stim_name, this_block, epoch_start, epoch_stop))
            break
            
        next_block = stim_table[block_key][i+1]
        # if next row is the same stim block, push back epoch_stop time
        if next_block == this_block:
            epoch_stop = stim_table.stop_time[i+1]
        # otherwise, end the current epoch, start new epoch
        else:
            epochs.append((stim_name, this_block, epoch_start, epoch_stop))
            epoch_start = stim_table.start_time[i+1]
            epoch_stop = stim_table.stop_time[i+1]
    
    return epochs
# extract epochs from all valid stimulus tables
epochs = []
for stim_name in stimulus_names:
    stim_table = nwb.intervals[stim_name]
    print(stim_name)
    try:
        epochs = extract_epochs(stim_name, stim_table, epochs)
    except:
        continue

# epochs take the form (stimulus name, stimulus block, start time, stop time)
print(len(epochs))
epochs.sort(key=lambda x: x[2])
for epoch in epochs:
    print(epoch)
drifting_gratings_presentations
flashes_presentations
gabors_presentations
invalid_times
natural_movie_one_presentations
natural_movie_three_presentations
natural_scenes_presentations
spontaneous_presentations
static_gratings_presentations
17
('spontaneous_presentations', 0, np.float64(29.83010738159049), np.float64(8611.04613738159))
('gabors_presentations', np.float64(0.0), np.float64(89.8968273815905), np.float64(1001.8917716749913))
('invalid_times', 0, np.float64(970.0), np.float64(982.0))
('flashes_presentations', np.float64(1.0), np.float64(1290.8830973815907), np.float64(1589.382401312724))
('drifting_gratings_presentations', np.float64(2.0), np.float64(1591.1338573815906), np.float64(2190.634543106125))
('natural_movie_three_presentations', np.float64(3.0), np.float64(2221.6604473815905), np.float64(2822.161967381591))
('natural_movie_one_presentations', np.float64(4.0), np.float64(2852.1870373815905), np.float64(3152.4377773815904))
('drifting_gratings_presentations', np.float64(5.0), np.float64(3182.4628573815908), np.float64(3781.963503106125))
('natural_movie_three_presentations', np.float64(6.0), np.float64(4083.215117381591), np.float64(4683.716567381592))
('drifting_gratings_presentations', np.float64(7.0), np.float64(4713.741627381592), np.float64(5397.312443106124))
('static_gratings_presentations', np.float64(8.0), np.float64(5398.31325738159), np.float64(5878.714467381591))
('natural_scenes_presentations', np.float64(9.0), np.float64(5908.739537381591), np.float64(6389.157337381589))
('natural_scenes_presentations', np.float64(10.0), np.float64(6689.408117381589), np.float64(7169.809297381591))
('static_gratings_presentations', np.float64(11.0), np.float64(7199.83431738159), np.float64(7680.268867381591))
('natural_movie_one_presentations', np.float64(12.0), np.float64(7710.293937381591), np.float64(8010.54467738159))
('natural_scenes_presentations', np.float64(13.0), np.float64(8040.56972738159), np.float64(8568.510620243858))
('static_gratings_presentations', np.float64(14.0), np.float64(8611.04613738159), np.float64(9151.49746738159))

Visualizing Unit Activity Throughout Epochs

Below is a view of a unit's spiking activity throughout the session, with epochs shown as colored sections. Set unit_num to the id of the unit to view. Set time_start and time_end to the start and end bounds, in seconds, of the part of the session you'd like to see; the epoch list printed above can help inform this choice. As mentioned above, if your file's stimulus information differs significantly, this code may need to be modified to display the epochs appropriately.
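For a histogram to show spikes *per second*, the bin edges must be exactly one second apart. The small sketch below, using hypothetical spike times, compares `np.linspace(time_start, time_end, time_end - time_start)`, which produces one edge too few and therefore bins slightly wider than one second, with `np.arange`, which gives exact one-second bins.

```python
import numpy as np

time_start, time_end = 0, 10

# (time_end - time_start) points -> 10 edges -> 9 bins, each 10/9 s wide
linspace_edges = np.linspace(time_start, time_end, time_end - time_start)
# inclusive upper bound -> 11 edges -> 10 bins, each exactly 1 s wide
arange_edges = np.arange(time_start, time_end + 1)

# Hypothetical spike times in seconds
spikes = np.array([0.2, 0.4, 1.5, 3.1, 3.2, 3.3, 9.9])
counts, _ = np.histogram(spikes, bins=arange_edges)

print(np.diff(linspace_edges)[0])  # about 1.111
print(counts)                      # [2 1 0 3 0 0 0 0 0 1]
```

Over a 10,000-second window the linspace bins are only about 0.01% too wide, so the plots are barely affected, but the arange form states the intent directly.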

unit_num = 950930672 # chosen from units table
time_start = 0
time_end = 10000
# translate unit id to row index
unit_idx = -1
for i in range(len(units.id)):
    if units.id[i] == unit_num:
        unit_idx = i
        break
print("Unit index:",unit_idx)
Unit index: 648
# make histogram of unit spikes per second over specified timeframe
spikes = units["spike_times"][unit_idx]
time_bin_edges = np.arange(time_start, time_end + 1)  # edges exactly one second apart
hist, bins = np.histogram(spikes, bins=time_bin_edges)
# generate plot of spike histogram with colored epoch intervals and legend
fig, ax = plt.subplots(figsize=(15,5))

# assign unique color to each stimulus name
stim_names = list({epoch[0] for epoch in epochs})
colors = plt.cm.rainbow(np.linspace(0,1,len(stim_names)))
stim_color_map = {stim_names[i]:colors[i] for i in range(len(stim_names))}

epoch_key = {}
height = max(hist)
# draw colored rectangles for each epoch
for epoch in epochs:
    stim_name, stim_block, epoch_start, epoch_end = epoch
    color = stim_color_map[stim_name]
    rec = ax.add_patch(mpl.patches.Rectangle((epoch_start, 0), epoch_end-epoch_start, height, alpha=0.2, facecolor=color))
    epoch_key[stim_name] = rec
    
ax.set_xlim(time_start, time_end)
ax.set_ylim(-0.1, height+0.1)
ax.set_xlabel("time (s)")
ax.set_ylabel("# spikes")
ax.set_title("Unit Spikes Per Second Throughout Epochs")

fig.legend(epoch_key.values(), epoch_key.keys(), loc="lower right", bbox_to_anchor=(1.12, 0.25))
ax.plot(bins[:-1], hist)
<Figure size 1500x500 with 1 Axes>

Probewise Activity Throughout Epochs

It can also be useful to view the activity of an entire probe throughout epochs. The code below lets you select a probe and produces a histogram of the summed spike counts of all units on that probe over time, where each unit is assigned to the probe that holds its peak channel. This uses the file's Electrodes table; all you need to know about it here is that it maps channel IDs to probe IDs. The probe names to choose from are printed below. Set probe_name to one of these, and set time_start and time_end to the start and end times, in seconds, of the part of the session to view.
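The pooling step can be sketched with plain dictionaries standing in for the Electrodes and Units tables. Everything here is hypothetical toy data, and `pool_probe_spikes` is an illustrative helper rather than part of any library. It matches probe names case-insensitively, as the code below does, and it uses `np.concatenate` instead of growing a Python list. This matters because the real result can contain tens of millions of spike times.

```python
import numpy as np

# Hypothetical electrodes table: channel id -> probe (group) name
channel_to_probe = {101: "probeA", 102: "probeA", 201: "probeC", 202: "probeC"}

# Hypothetical units: each unit's peak channel id and its spike times (s)
unit_peak_channels = [101, 201, 202]
unit_spike_times = [np.array([0.5, 1.2]), np.array([0.1]), np.array([2.0, 2.5, 3.0])]

def pool_probe_spikes(probe_name):
    """Concatenate spike times of all units whose peak channel is on probe_name."""
    selected = [
        spikes
        for channel, spikes in zip(unit_peak_channels, unit_spike_times)
        # case-insensitive match, so "ProbeC" and "probeC" are treated alike
        if channel_to_probe.get(channel, "").lower() == probe_name.lower()
    ]
    return np.concatenate(selected) if selected else np.array([])

print(pool_probe_spikes("ProbeC"))  # spike times of the two probeC units
```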

print(nwb.devices.keys())
dict_keys(['probeA', 'probeB', 'probeC', 'probeD', 'probeE', 'probeF'])
probe_name = "probeC"
time_start = 0
time_end = 10000
# get list of channels on this probe
electrodes = nwb.electrodes.to_dataframe()
channel_ids = set(electrodes.index[electrodes["group_name"].str.lower() == probe_name.lower()].tolist())

# get all spike times from units whose peak channel belongs to the selected probe
probe_spike_times = []
for i in range(len(units)):
    if get_peak_channel_id(units, i) in channel_ids:
        probe_spike_times += list(units["spike_times"][i])

len(probe_spike_times)
27741944
# make histogram of unit spikes per second over specified timeframe
time_bin_edges = np.arange(time_start, time_end + 1)  # edges exactly one second apart
hist, bins = np.histogram(probe_spike_times, bins=time_bin_edges)
# generate plot of spike histogram with colored epoch intervals and legend
fig, ax = plt.subplots(figsize=(15,5))

# assign unique color to each stimulus name
stim_names = list({epoch[0] for epoch in epochs})
colors = plt.cm.rainbow(np.linspace(0,1,len(stim_names)))
stim_color_map = {stim_names[i]:colors[i] for i in range(len(stim_names))}

epoch_key = {}
height = max(hist)
# draw colored rectangles for each epoch
for epoch in epochs:
    stim_name, stim_block, epoch_start, epoch_end = epoch
    color = stim_color_map[stim_name]
    rec = ax.add_patch(mpl.patches.Rectangle((epoch_start, 0), epoch_end-epoch_start, height, alpha=0.2, facecolor=color))
    epoch_key[stim_name] = rec
    
ax.set_xlim(time_start, time_end)
ax.set_ylim(-0.1, height+0.1)
ax.set_xlabel("time (s)")
ax.set_ylabel("# spikes")
ax.set_title(f"{probe_name} Spikes Per Second Throughout Epochs")

fig.legend(epoch_key.values(), epoch_key.keys(), loc="lower right", bbox_to_anchor=(1.12, 0.25))
ax.plot(bins[:-1], hist)
<Figure size 1500x500 with 1 Axes>

Regionwise Activity Throughout Epochs

We can also break down activity by brain region, with a bit of work. First, we need the brain region of each unit. The key is the Electrodes table, shown below: it contains the brain region for each electrode id, while the Units table (shown above) contains the peak channel id for each unit. Together, these give the brain region of each unit's peak channel. This information can then be used just like the probe selection above to get spike counts over time for all units in a selected region. Below, set brain_region, time_start, and time_end to view such a plot.
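Before choosing brain_region, it can help to see how many units each region contains. The sketch below groups units by the location of their peak channel. It uses hypothetical lookup dictionaries in place of the real tables. Channels missing from the lookup are labelled "unknown", which is an arbitrary choice. The empty string, which also appears among the real locations printed below, is kept as its own group.

```python
from collections import Counter, defaultdict

# Hypothetical lookups: electrode id -> location acronym, and each unit's peak channel
electrode_locations = {1: "VISp", 2: "VISp", 3: "CA1", 4: ""}
unit_peak_channels = {"u1": 1, "u2": 2, "u3": 3, "u4": 4, "u5": 99}

# Group unit names by the brain region of their peak channel
units_by_region = defaultdict(list)
for unit, channel in unit_peak_channels.items():
    region = electrode_locations.get(channel, "unknown")
    units_by_region[region].append(unit)

region_counts = Counter({region: len(names) for region, names in units_by_region.items()})
print(region_counts.most_common())  # regions ordered by number of units
```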

nwb.electrodes[:10]
# map electrode ids to brain region acronyms
electrode_locations = {nwb.electrodes["id"][i]: nwb.electrodes["location"][i] for i in range(len(nwb.electrodes))}
print(set(electrode_locations.values()))
{'', 'APN', 'SUB', 'grey', 'VISam', 'VISp', 'POL', 'LP', 'VISpm', 'NOT', 'VL', 'VISrl', 'TH', 'VISal', 'CA1', 'MB', 'CA3', 'DG', 'LGd', 'CA2', 'VPM', 'VISl', 'Eth', 'PO'}
brain_region = "VISpm"
time_start = 0
time_end = 10000
# get all spike times from units whose peak channel belongs to the selected brain region
region_spike_times = []
for i in range(len(units)):
    peak_channel = get_peak_channel_id(units, i)
    unit_location = electrode_locations.get(peak_channel, None)
    if unit_location == brain_region:
        region_spike_times += list(units["spike_times"][i])

len(region_spike_times)
4178172
# make histogram of unit spikes per second over specified timeframe
time_bin_edges = np.arange(time_start, time_end + 1)  # edges exactly one second apart
hist, bins = np.histogram(region_spike_times, bins=time_bin_edges)
# generate plot of spike histogram with colored epoch intervals and legend
fig, ax = plt.subplots(figsize=(15,5))

# assign unique color to each stimulus name
stim_names = list({epoch[0] for epoch in epochs})
colors = plt.cm.rainbow(np.linspace(0,1,len(stim_names)))
stim_color_map = {stim_names[i]:colors[i] for i in range(len(stim_names))}

epoch_key = {}
height = max(hist)
# draw colored rectangles for each epoch
for epoch in epochs:
    stim_name, stim_block, epoch_start, epoch_end = epoch
    color = stim_color_map[stim_name]
    rec = ax.add_patch(mpl.patches.Rectangle((epoch_start, 0), epoch_end-epoch_start, height, alpha=0.2, facecolor=color))
    epoch_key[stim_name] = rec
    
ax.set_xlim(time_start, time_end)
ax.set_ylim(-0.1, height+0.1)
ax.set_xlabel("time (s)")
ax.set_ylabel("# spikes")
ax.set_title(f"{brain_region} Spikes Per Second Throughout Epochs")

fig.legend(epoch_key.values(), epoch_key.keys(), loc="lower right", bbox_to_anchor=(1.12, 0.25))
ax.plot(bins[:-1], hist)
<Figure size 1500x500 with 1 Axes>

Showing Spike Times

Here, a histogram of spike counts over time is computed for each unit around a single stimulus presentation, and the histograms are displayed together as an image. In the first cell below, set stim_time to the time of the stimulus you're interested in. To find stimulus times of interest, look at one of the stimulus tables in the Intervals section, as described above in Extracting Unit Data and Stimulus Data. Set interval_start and interval_end to the time bounds of the window to display, in seconds relative to stim_time. Finally, start_unit and end_unit select the slice of rows of the units table to display.
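Counting each unit's spikes in bins aligned to a stimulus is a peri-stimulus time histogram (PSTH). The sketch below wraps that computation in a hypothetical helper `psth` using the same window and bin count as the cells below. With a 5.5 s window and 275 bins, each bin is 20 ms wide. It also shows that dividing counts by the bin width converts them to firing rates. The spike times are made up for illustration.

```python
import numpy as np

def psth(spike_times_per_unit, event_time, window_start=-0.5, window_end=5.0, n_bins=275):
    """Count each unit's spikes in equal-width bins around one event.

    Returns (counts, edges): counts has shape (n_units, n_bins) and edges
    has n_bins + 1 entries, in seconds relative to event_time.
    """
    edges = np.linspace(window_start, window_end, n_bins + 1)
    counts = np.array([
        np.histogram(np.asarray(spikes) - event_time, bins=edges)[0]
        for spikes in spike_times_per_unit
    ])
    return counts, edges

# Two hypothetical units and an event at t = 100 s
units_spikes = [np.array([99.8, 100.01, 100.03, 104.9]), np.array([100.5])]
counts, edges = psth(units_spikes, event_time=100.0)

bin_width = edges[1] - edges[0]  # 5.5 s / 275 bins = 0.02 s
rates = counts / bin_width       # spikes per second in each bin
print(counts.shape, round(bin_width, 3))  # (2, 275) 0.02
```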

stim_time = 1007 # arbitrarily chosen here
interval_start = -0.5
interval_end = 5

start_unit = 100
end_unit = 500
spike_times = [elem for elem in units["spike_times"][start_unit:end_unit]]
if len(spike_times) == 0:
    raise Exception("There are no spiking units in this selection")
len(spike_times)
400
# for each unit, count spikes relative to stim_time in 275 bins of 20 ms each (5.5 s / 275)
time_bin_edges = np.linspace(interval_start, interval_end, 276)
hists = []
for unit_spike_times in spike_times:
    hist, bins = np.histogram(unit_spike_times-stim_time, bins=time_bin_edges)
    hists.append(hist)
hists = np.array(hists)

hists.shape
(400, 275)
# display array of histograms as 2D image with color
fig, ax = plt.subplots(figsize=(16,16))
img = ax.imshow(hists)
cbar = plt.colorbar(img, shrink=0.5)
cbar.set_label('# spikes')

ax.yaxis.set_major_locator(plt.NullLocator())
ax.set_ylabel("units", fontsize=16)

xtick_step=25
reltime = np.array(time_bin_edges)

ax.set_xticks(np.arange(0, len(reltime), xtick_step))
ax.set_xticklabels([f'{mp:1.1f}' for mp in reltime[::xtick_step]], rotation=45)

ax.set_xlabel("time since event (s)", fontsize=16)
ax.set_title("Units Spikes Over Time", fontsize=20)
<Figure size 1600x1600 with 2 Axes>

Visualizing Waveforms

The Units table can also be used to view a unit's waveforms through the waveform_mean property, which holds the mean waveform of the unit as measured on each channel along the probe. One channel, the peak channel, shows the largest waveform. Below, the Electrodes table is used to find which probe holds the unit's peak channel. As in the sections above, all you need to know about the Electrodes table is that it maps channel IDs to probe IDs. Next, the single peak waveform is extracted and plotted. This is followed by a plot of all of the unit's mean waveforms overlaid in time and an image of the waveforms across a range of channels.

unit_num = 950952910
# translate unit id to row index
unit_idx = -1
for i in range(len(units.id)):
    if units.id[i] == unit_num:
        unit_idx = i
        break
print("Unit index:",unit_idx)
Unit index: 2513
# get sampling Hz for this unit's waveform
peak_channel = get_peak_channel_id(units, unit_idx)
electrodes = nwb.electrodes.to_dataframe()
probe_name = electrodes.loc[peak_channel].group_name

# Hz = nwb.devices[probe_name].sampling_rate
Hz = 30000 # hz is always 30,000 in Allen NWBs
Hz
30000
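The plots below build their time axis with `np.linspace(0, n_secs * 1000, n_samples)`, which spaces the points about 1% further apart than one sampling period. An alternative sketch, assuming the 30 kHz rate above and the 82 samples per waveform seen in this file, converts sample indices to milliseconds directly so that consecutive points are exactly 1/Hz apart:

```python
import numpy as np

Hz = 30000      # sampling rate assumed in this notebook
n_samples = 82  # samples per mean waveform in this file

# Sample k is taken at k / Hz seconds; multiply by 1000 for milliseconds
time_ms = np.arange(n_samples) / Hz * 1000

print(round(time_ms[1], 4), round(time_ms[-1], 3))  # 0.0333 2.7
```

For display purposes the difference is negligible, but the index-based form is preferable when waveform timing (for example, trough-to-peak duration) is measured from the axis.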

Peak waveform

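# get peak waveform for this unit
# waveform_mean has shape (n_channels, n_samples) in this file (see its shape printed below),
# so the peak channel is the row whose minimum over time (its trough) is deepest
unit_waveform = units["waveform_mean"][unit_idx]
peak_channel_idx = np.argmin(np.min(unit_waveform, axis=1))
peak_waveform = unit_waveform[peak_channel_idx, :]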
fig, ax = plt.subplots(figsize=(10,6))

n_secs = peak_waveform.shape[0] / Hz
time_axis = np.linspace(0, n_secs * 1000, peak_waveform.shape[0])

ax.plot(time_axis, peak_waveform)

ax.set_xlabel("time (ms)")
ax.set_ylabel("extracellular potential (uV)")
ax.set_title("Unit Peak Waveform", fontsize=20)

plt.show()
<Figure size 1000x600 with 1 Axes>

Waveforms

Because of the way Kilosort attributes spikes to specific units, most electrodes along the probe detect essentially nothing for a given unit. Closer to the actual source of the spikes, the waveform amplitude increases and takes shape. To get a clear picture of a unit's activity through space, the mean waveforms from all of the unit's channels can be plotted together. As the distance from the peak channel grows, the amplitude decreases until the unit is too far away to be detected.
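The falloff in amplitude can be quantified per channel as the peak-to-trough amplitude of each mean waveform. The sketch below runs on a synthetic (n_channels, n_samples) array whose amplitude decays with a Gaussian profile around channel 10, purely for illustration. It lists the channels whose amplitude is at least 10% of the peak channel's. The 10% cutoff is an arbitrary assumption, not a standard detection criterion.

```python
import numpy as np

def channel_amplitudes(waveforms):
    """Peak-to-trough amplitude of each channel's mean waveform.

    waveforms: array of shape (n_channels, n_samples).
    """
    return waveforms.max(axis=1) - waveforms.min(axis=1)

# Synthetic unit: a trough at sample 30 whose size falls off with
# distance from channel 10 (Gaussian spatial profile, illustration only)
n_channels, n_samples = 20, 82
t = np.arange(n_samples)
trough = -np.exp(-((t - 30) ** 2) / 20.0)
spatial = np.exp(-((np.arange(n_channels) - 10) ** 2) / 8.0)
waveforms = 100.0 * spatial[:, None] * trough[None, :]

amps = channel_amplitudes(waveforms)
peak = int(np.argmax(amps))
# channels whose amplitude is at least 10% of the peak channel's amplitude
detected = np.flatnonzero(amps >= 0.1 * amps[peak])
print(peak, detected)
```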

unit_waveforms = units["waveform_mean"][unit_idx]
unit_waveforms.shape
(384, 82)
fig, ax = plt.subplots(figsize=(10,6))
colors = plt.cm.viridis(np.linspace(0, 1, unit_waveforms.shape[0]))
ax.set_prop_cycle(color=colors)

n_secs = unit_waveforms.shape[1] / Hz
time_axis = np.linspace(0, n_secs * 1000, unit_waveforms.shape[1])

ax.plot(time_axis, np.transpose(unit_waveforms))

norm = mpl.colors.Normalize(vmin=0, vmax=len(colors))
cb = fig.colorbar(mpl.cm.ScalarMappable(norm=norm), ax=ax, label='channel #')

ax.set_xlabel("time (ms)")
ax.set_ylabel("extracellular potential (uV)")
ax.set_title("Unit Waveforms", fontsize=20)

plt.show()
<Figure size 1000x600 with 2 Axes>

Waveform Image

Below is an image of the mean waveforms on each channel of a unit. The further a channel is from the actual neuron, the weaker the measured waveform, so it is usually only useful to view a subset of the channels at once. Set start_channel and end_channel to the bounds of the channels you want displayed. Because the recording sites on Neuropixels probes are laid out in a staggered pattern along the shank, a unit is often strongly detected by only every other channel. The image below therefore displays every other channel to avoid the resulting striping effect. If the waveform looks too dim, try incrementing start_channel by 1.
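Instead of adjusting start_channel by hand, one option is to compare the two every-other-channel subsets and keep the one with more signal. The hypothetical helper below does this by summing each subset's peak-to-trough amplitudes. It assumes the (n_channels, n_samples) layout used in this file and is demonstrated on a synthetic array in which only odd channels carry signal.

```python
import numpy as np

def brighter_parity(waveforms, start_channel, end_channel):
    """Return start_channel or start_channel + 1, whichever begins the
    every-other-channel subset with the larger summed peak-to-trough amplitude.

    waveforms: array of shape (n_channels, n_samples).
    """
    amps = waveforms.max(axis=1) - waveforms.min(axis=1)
    even_sum = amps[start_channel:end_channel:2].sum()
    odd_sum = amps[start_channel + 1:end_channel:2].sum()
    return start_channel if even_sum >= odd_sum else start_channel + 1

# Synthetic example: only odd channels carry signal, mimicking the striping
waveforms = np.zeros((10, 5))
waveforms[1::2, 2] = -50.0
print(brighter_parity(waveforms, 0, 10))  # 1
```

The returned value could then be used as the start of the `[start:end:2]` slice in the cell below.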

start_channel = 230
end_channel = 280
waveform_2d = unit_waveforms[start_channel:end_channel:2] # step by 2 to remove striping effect
n_channels = end_channel - start_channel

fig, ax = plt.subplots(figsize=(8, n_channels // 15))

n_secs = unit_waveforms.shape[1] / Hz
time_axis = np.linspace(0, n_secs * 1000, unit_waveforms.shape[1])

norm = mpl.colors.Normalize(vmin=np.min(waveform_2d), vmax=np.max(waveform_2d))
cb = fig.colorbar(mpl.cm.ScalarMappable(norm=norm), ax=ax, label="extracellular potential (uV)")

ax.set_xlabel("time (ms)")
ax.set_ylabel("channel #")
ax.set_title("Unit Waveforms Image")

ax.imshow(waveform_2d, vmin=np.min(waveform_2d), vmax=np.max(waveform_2d), extent=[0, n_secs*1000, end_channel, start_channel], aspect="auto")
<Figure size 800x300 with 2 Axes>