
Classifying Fast-Spiking and Regular-Spiking Neurons

In this notebook, we distinguish fast-spiking (FS) neurons from regular-spiking (RS) neurons in visual cortical areas of the mouse brain. We also visualize the differences between these two cell types by comparing their peak waveforms and optotagged unit responses. The notebook uses open-source data from the Allen Institute but is designed to be adapted to other datasets.

Environment Setup

⚠️Note: If running on a new environment, run this cell once and then restart the kernel⚠️

import warnings
warnings.filterwarnings('ignore')

try:
    from databook_utils.dandi_utils import dandi_download_open
except ImportError:
    !git clone https://github.com/AllenInstitute/openscope_databook.git
    %cd openscope_databook
    %pip install -e .
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from databook_utils.dandi_utils import dandi_download_open, dandi_stream_open
from hdmf.common import DynamicTable
from pynwb import NWBHDF5IO
from pynwb.misc import Units

%matplotlib inline

Downloading Ecephys Files

dandiset_id = "000021"
dandi_filepath = "sub-707296975/sub-707296975_ses-721123822.nwb"
dandi_api_key = None
# This can sometimes take a while depending on the size of the file
io = dandi_download_open(dandiset_id, dandi_filepath, dandi_api_key=dandi_api_key)
nwb = io.read()
File already exists
Opening file

Accessing Data Tables

Here, we access data from the Units table of the NWB file. This table contains information about each unit recorded in the experiment; a unit typically corresponds to an individual neuron detected during an extracellular electrophysiology recording. We also access data from the Electrodes table, which describes each electrode (recording channel) on the Neuropixels probes used in this experiment. This is where to look up information about specific channels, or recording sites, along each probe, such as their brain location. We will not use every column of either table; the columns we do use are explained below.

# converting to a DataFrame takes a while up front, but makes the analysis much faster
units = nwb.units.to_dataframe()
units[:10]
# convert the electrodes table into a pandas df for easier analysis
electrodes = nwb.electrodes
electrodes_df = electrodes.to_dataframe()

# Reset the index to include it as a regular column
electrodes_df.reset_index(inplace=True)

electrodes_df[:10]

Selecting and Aligning Particular Data From the Tables

Now we want to extract and align particular data from the Units and Electrodes tables from above. There are several criteria that need to be met in order for us to select only specific neurons that will be valuable to our analysis:

  1. Location: The tables above contain information from every brain area the probes passed through, but we only want to analyze units recorded in the visual cortical areas of the mouse brain. We therefore create a dictionary called channel_to_probes_dict that maps each channel ID to its brain location, so we can exclude units from other areas. Because channel locations are stored in the Electrodes table, we wrote a function called get_unit_location that takes a unit's peak_channel_id from the Units table and returns the location of that channel. This function aligns data from the Electrodes table with data from the Units table. We then use the cortex list to keep only units from those cortical locations and exclude all others.

  2. Electrode Validity: The valid_data column of the Electrodes table indicates whether data from a given channel is valid. We only want to include channels marked True (valid) in this column, so we build a set, valid_electrode_ids, containing the ID of every valid channel, and exclude any unit whose peak channel is not in it. Newer Allen NWB files include only valid channels in the Electrodes table, in which case every channel ID is kept.

  3. Unit Quality: The quality column of the Units table records the quality of each unit. We only want units labeled 'good', not 'noise', so the selection function defined below returns True only when a unit's quality is 'good'.
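The three criteria can also be written as vectorized pandas operations instead of a row-by-row check. The sketch below uses small hypothetical tables (the IDs, locations, and quality labels are invented) that share the column names of the Allen tables; `Series.map` performs the channel-to-location lookup and `Series.isin` performs the membership tests.

```python
import pandas as pd

# Hypothetical toy tables standing in for the NWB Electrodes and Units tables
electrodes_df = pd.DataFrame({
    "id": [101, 102, 103, 104],
    "location": ["VISp", "VISp", "CA1", "VISl"],
    "valid_data": [True, False, True, True],
})
units = pd.DataFrame({
    "peak_channel_id": [101, 102, 103, 104, 101],
    "quality": ["good", "good", "good", "good", "noise"],
})
cortex = ["VISrl", "VISal", "VISam", "VISpm", "VISp", "VISl"]

# Criterion 1: look up each unit's location through its peak channel
channel_to_location = dict(zip(electrodes_df["id"], electrodes_df["location"]))
in_cortex = units["peak_channel_id"].map(channel_to_location).isin(cortex)

# Criterion 2: keep only units whose peak channel is flagged as valid
valid_ids = set(electrodes_df.loc[electrodes_df["valid_data"], "id"])
on_valid_channel = units["peak_channel_id"].isin(valid_ids)

# Criterion 3: keep only units labeled 'good'
is_good = units["quality"] == "good"

selected = units[in_cortex & on_valid_channel & is_good].index.tolist()
print(selected)  # [0, 3]
```

Unit 1 is dropped because its channel is invalid, unit 2 because CA1 is not a visual area, and unit 4 because it is labeled noise.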

# In older Allen NWBs the information must be manually cross-referenced between the units table and electrodes table

# create a dictionary that maps the id of each channel to its brain location
channel_to_probes_dict = dict(zip(electrodes_df['id'], electrodes_df['location']))

def get_peak_channel_idx(unit_row):
    return unit_row['peak_channel_id']

# function aligns location information from electrodes table with channel id from the units table
def get_unit_location(peak_id):
    return channel_to_probes_dict[peak_id]



# In newer Allen NWBs, the units table includes direct references to each unit's electrodes, so we can access them directly
# def get_peak_channel_idx(unit_row):
#     mean_waveforms = unit_row['waveform_mean']
#     waveform_mins = np.min(mean_waveforms, axis=0)
#     peak_channel_idx = np.argmin(waveform_mins)
#     return peak_channel_idx

# def get_unit_location(unit_row):
#     peak_channel_idx = get_peak_channel_idx(unit_row)
#     detected_electrodes = unit_row['electrodes']
#     return detected_electrodes.iloc[peak_channel_idx].location

# build a set of channel ids that are valid within the electrodes table
# older Allen NWBs use the valid_data column
if 'valid_data' in electrodes_df.columns:
    valid_electrode_ids = electrodes_df.loc[electrodes_df['valid_data'] == True, 'id'].tolist()
# newer Allen NWBs only include valid channels in the electrodes table, so we can take all of the ids
else:
    valid_electrode_ids = electrodes_df['id'].tolist()

# a set makes the membership checks in select_unit fast
valid_electrode_ids = set(valid_electrode_ids)

Now that we have defined some of the selection criteria, we can combine them with additional filtering criteria into one function called select_unit, which is used in a for loop below. When a unit's row is passed to this function, it returns True if that unit meets all of the criteria, which lets the loop exclude every unit that does not. You can modify this function to use whatever criteria are most reasonable for your dataset.

# Allen CCF acronyms for the visual cortical areas of interest
cortex = ['VISrl', 'VISal', 'VISam', 'VISpm', 'VISp', 'VISl']

# this will work for older Allen NWBs
def select_unit(unit_row):
    # returns True only if the unit passes every location, validity, and quality criterion
    return bool(unit_row['quality'] == 'good' and
                get_unit_location(unit_row['peak_channel_id']) in cortex and
                unit_row['peak_channel_id'] in valid_electrode_ids and
                unit_row['isi_violations'] < 0.5 and
                unit_row['amplitude_cutoff'] < 0.1 and
                unit_row['presence_ratio'] > 0.95)

# this will work for newer Allen NWBs
# def select_unit(unit_row):
# 	if unit_row['decoder_label'] == 'sua' and \
# 	unit_row['default_qc'] == True and \
# 	any(cortical_loc in get_unit_location(unit_row) for cortical_loc in cortex):
# 	# unit_row['isi_violations_count'] < 0.5 and \
# 	# unit_row['amplitude_cutoff'] < 0.1 and \
# 	# unit_row['presence_ratio'] > 0.95:
# 		return True
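select_unit checks one row at a time. As an alternative sketch, the same quality-metric thresholds can be applied to whole columns at once with a boolean mask, which avoids the Python-level loop. The metric values below are invented for illustration, and passes_qc is a hypothetical helper that mirrors only the metric thresholds of select_unit.

```python
import pandas as pd

# Hypothetical quality metrics for four units (values invented for illustration)
units = pd.DataFrame({
    "isi_violations": [0.1, 0.7, 0.2, 0.0],
    "amplitude_cutoff": [0.05, 0.02, 0.3, 0.01],
    "presence_ratio": [0.99, 0.98, 0.97, 0.90],
})

def passes_qc(unit_row):
    """Row-wise check using the same metric thresholds as select_unit."""
    return bool(unit_row["isi_violations"] < 0.5
                and unit_row["amplitude_cutoff"] < 0.1
                and unit_row["presence_ratio"] > 0.95)

# Vectorized version: element-wise comparisons combined with &
qc_mask = (
    (units["isi_violations"] < 0.5)
    & (units["amplitude_cutoff"] < 0.1)
    & (units["presence_ratio"] > 0.95)
)

rowwise = units.apply(passes_qc, axis=1)
print(qc_mask.tolist())           # [True, False, False, False]
print(units[qc_mask].index.tolist())  # [0]
```

Both approaches select the same units; the mask also yields the selected indices directly through `units[qc_mask].index`.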

Plot the Distribution of Waveform Durations

Now that all of our unit selection criteria are defined in one function, we can collect properties from the Units table to plot. For this first plot, we show the distribution of waveform durations for the selected units. Ideally, this distribution is bimodal, with one peak containing the waveform durations of fast-spiking (FS) neurons and the other containing those of regular-spiking (RS) neurons. FS neurons characteristically have shorter spike durations than RS neurons, so they make up the first (left) peak. To select the durations of the desired units, we first loop through the Units table and append the index of each unit that meets the criteria. This list, selected_unit_indices, is then used to index the Units table for all subsequent plots. We use it to extract each selected unit's waveform duration into a list called selected_waveform_duration, which is plotted below.

selected_unit_indices = []

# loop that collects the indices of units that meet the filtering criteria
for i, unit_row in units.iterrows():
    if select_unit(unit_row):
        selected_unit_indices.append(i)

print(selected_unit_indices)
print(f"Selected {len(selected_unit_indices)} units out of {len(units)}")
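# use the same duration column as the classification step below; column names vary across NWB files
if 'waveform_duration' in units.columns:
    duration_key = 'waveform_duration'
elif 'half_width' in units.columns:
    duration_key = 'half_width'
else:
    duration_key = 'waveform_halfwidth'

# infer the duration unit from the data (a mean below 0.01 suggests seconds); set it manually if needed
waveform_duration_unit = "ms"
if np.mean(units[duration_key]) < 0.01:
    waveform_duration_unit = "s"
selected_waveform_duration = units.loc[selected_unit_indices, duration_key].tolist()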

print('Number of durations that will be plotted: ', len(selected_waveform_duration))
Number of durations that will be plotted:  161
data = selected_waveform_duration

fig, ax = plt.subplots()
ax.hist(data, bins=50)

plt.xlabel(f"Waveform Duration ({waveform_duration_unit})")
plt.ylabel("# units")
plt.title("Distribution of Waveform Durations from Visual Cortical Neurons")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()
[Figure: Distribution of Waveform Durations from Visual Cortical Neurons]

As you can see, this distribution has two distinct peaks. To separate FS neurons from RS neurons, we choose a threshold: units to the left of it have shorter waveform durations and are classified as FS, and units to the right have longer durations and are classified as RS. We expect this threshold to be about 0.4 ms. It usually falls between the two peaks; in this plot, it corresponds to the gap between the two distributions where the number of units drops to 0.
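Picking the threshold by eye can be approximated programmatically. The sketch below is a simple heuristic, not the notebook's method: it bins the durations, treats the two tallest bins as the two modes, and returns the center of the middle least-populated bin between them. The durations are hypothetical values in ms, and trough_threshold is a hypothetical helper. The heuristic fails when the two tallest bins belong to the same mode, so the histogram should still be checked visually.

```python
import numpy as np

def trough_threshold(durations, bins=50):
    """Center of the middle least-populated bin between the two tallest bins."""
    counts, edges = np.histogram(durations, bins=bins)
    centers = (edges[:-1] + edges[1:]) / 2
    # indices of the two tallest bins, ordered left to right
    first, second = sorted(np.argsort(counts)[-2:])
    if second - first < 2:
        raise ValueError("the two tallest bins are adjacent; no gap between them")
    between = np.arange(first + 1, second)
    gap_counts = counts[between]
    # among the emptiest bins in the gap, take the middle one
    candidates = between[gap_counts == gap_counts.min()]
    return float(centers[candidates[len(candidates) // 2]])

# Hypothetical bimodal durations (ms): a short FS-like mode and a long RS-like mode
durations_ms = np.repeat([0.2, 0.25, 0.65, 0.7, 0.75], [30, 10, 15, 80, 20])
threshold_guess = trough_threshold(durations_ms)
print(f"suggested threshold: {threshold_guess:.3f} ms")
```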

# define the threshold (in ms) based on the plot above
threshold = 0.4

# convert the threshold to seconds if the durations are stored in seconds
if waveform_duration_unit == "s":
    threshold /= 1000

Plot The Waveform Profile of RS vs FS Neurons

Now that we have a threshold that separates FS neurons from RS neurons, we can plot the peak waveforms of each group and compare them. For each unit, the waveform_mean column of the Units table holds a 2D array of mean waveforms with one column per channel (time samples × channels). To pick out the peak waveform, we find the channel whose mean waveform reaches the most negative value (the deepest trough) and take that channel's column. We then append this peak waveform to selected_peak_waveform_rs if the unit's waveform duration is above the threshold, or to selected_peak_waveform_fs otherwise. Finally, we overlay the peak waveforms of each cell type in its own plot, together with their mean.
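The peak-channel selection and the FS/RS split can be checked on a small hypothetical array. The sketch assumes, as the loop below does, that rows are time samples and columns are channels; the values are invented, and classify is a hypothetical helper applying the same rule as the loop (above the threshold is RS, at or below it is FS).

```python
import numpy as np

# Hypothetical mean waveform: 5 time samples x 3 channels (arbitrary units)
unit_mean_waveform = np.array([
    [0.0,  0.0,  0.0],
    [-1.0, -4.0, -2.0],
    [-0.5, -6.0, -1.0],
    [0.5,  2.0,  0.3],
    [0.0,  0.5,  0.0],
])

# minimum over time for each channel: [-1.0, -6.0, -2.0]
channel_troughs = np.min(unit_mean_waveform, axis=0)
# the peak channel is the one with the deepest trough
peak_channel_idx = int(np.argmin(channel_troughs))
peak_waveform = unit_mean_waveform[:, peak_channel_idx]

def classify(waveform_duration, threshold=0.4):
    """Same rule as the notebook loop: above the threshold is RS, otherwise FS."""
    return "RS" if waveform_duration > threshold else "FS"

print(peak_channel_idx, classify(0.3), classify(0.7))  # 1 FS RS
```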

selected_peak_waveform_rs = []
selected_peak_waveform_fs = []

for index, unit_row in units.loc[selected_unit_indices].iterrows():
    # column names can vary across NWB files
    if 'waveform_duration' in units.columns:
        waveform_duration = unit_row['waveform_duration']
    elif 'half_width' in units.columns:
        waveform_duration = unit_row['half_width']
    else:
        waveform_duration = unit_row['waveform_halfwidth']

    # waveform_mean is (time samples x channels); the peak channel has the deepest trough
    unit_mean_waveform = unit_row['waveform_mean']
    peak_channel_idx = np.argmin(np.min(unit_mean_waveform, axis=0))
    peak_waveform = unit_mean_waveform[:, peak_channel_idx]

    # durations above the threshold are RS, all others FS
    if waveform_duration > threshold:
        selected_peak_waveform_rs.append(peak_waveform)
    else:
        selected_peak_waveform_fs.append(peak_waveform)
        
print('number of RS waveforms to be plotted: ', len(selected_peak_waveform_rs))
print('number of FS waveforms to be plotted: ', len(selected_peak_waveform_fs))   

# convert lists into arrays and transpose for the plots
selected_peak_waveform_fs = np.array(selected_peak_waveform_fs).transpose()
selected_peak_waveform_rs = np.array(selected_peak_waveform_rs).transpose()
number of RS waveforms to be plotted:  134
number of FS waveforms to be plotted:  27
# the sampling rate is a property of the recording device and is assumed to be the same
# for every probe in this experiment, so we read it from the first device
one_device = list(nwb.devices.values())[0]
Hz = getattr(one_device, "sampling_rate", 30000)  # default to 30000 Hz if sampling_rate is not found

fig, ax = plt.subplots(figsize=(10,6))

n_samples = selected_peak_waveform_fs.shape[0]
# sample k is recorded at k / Hz seconds; multiply by 1000 to get ms
time_axis = np.arange(n_samples) / Hz * 1000

ax.plot(time_axis, selected_peak_waveform_fs, color='k', alpha=0.1)
ax.plot(time_axis, np.mean(selected_peak_waveform_fs, axis=1), color='b', linewidth=2)

ax.set_xlabel("time (ms)")
ax.set_ylabel("spike amplitude (uV)")
ax.set_title("Unit Peak Waveform FS", fontsize=20)

plt.show()
[Figure: Unit Peak Waveform FS]
fig, ax = plt.subplots(figsize=(10,6))

# all peak waveforms have the same length in this dataset, so we can read it from the array shape
n_samples = selected_peak_waveform_rs.shape[0]
time_axis = np.arange(n_samples) / Hz * 1000

ax.plot(time_axis, selected_peak_waveform_rs, color='k', alpha=0.1)
ax.plot(time_axis, np.mean(selected_peak_waveform_rs, axis=1), color='r', linewidth=2)

ax.set_xlabel("time (ms)")
ax.set_ylabel("spike amplitude (uV)")
ax.set_title("Unit Peak Waveform RS", fontsize=20)

plt.show()
[Figure: Unit Peak Waveform RS]
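The FS and RS mean waveforms differ mainly in how soon the trough is followed by the subsequent peak. To our understanding, the AllenSDK documentation describes waveform_duration as the time from trough to peak in ms. A minimal sketch of that measurement on a single mean waveform is below; it assumes a 30 kHz sampling rate (the notebook's fallback value), uses a hypothetical helper name, trough_to_peak_ms, and is not necessarily the exact procedure used to compute the values stored in the NWB file.

```python
import numpy as np

def trough_to_peak_ms(waveform, sampling_rate_hz=30000.0):
    """Time in ms from the waveform's minimum (trough) to the largest value after it (peak)."""
    waveform = np.asarray(waveform, dtype=float)
    trough_idx = int(np.argmin(waveform))
    # search for the peak only at or after the trough
    peak_idx = trough_idx + int(np.argmax(waveform[trough_idx:]))
    return (peak_idx - trough_idx) / sampling_rate_hz * 1000.0

# Hypothetical waveforms: trough at sample 10, peak at sample 25 (RS-like) or 16 (FS-like)
rs_like = np.zeros(60)
rs_like[10], rs_like[25] = -5.0, 2.0
fs_like = np.zeros(60)
fs_like[10], fs_like[16] = -5.0, 2.0

rs_duration = trough_to_peak_ms(rs_like)  # 15 samples at 30 kHz
fs_duration = trough_to_peak_ms(fs_like)  # 6 samples at 30 kHz
print(round(rs_duration, 3), round(fs_duration, 3))  # 0.5 0.2
```

With the 0.4 ms threshold used above, the first waveform would be classified as RS and the second as FS.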