Introduction to Generalized Linear Models using Pynapple and NeMoS#

Authors: Camila Maura, Edoardo Balzani & Guillaume Viejo

In this notebook, we will use the Pynapple and NeMoS packages (supported by the Flatiron Institute) to model spiking neural data with Generalized Linear Models (GLMs). We will explain what GLMs are and what their components are, then use Pynapple and NeMoS to preprocess real data from the primary visual cortex (VISp) of mice and fit a GLM that predicts spiking activity as a function of passive visual stimuli. We will also show how, if you have simultaneous recordings from a large population of neurons, you can build connections between the neurons into the GLM in the form of coupling filters.

We will be analyzing data from the Visual Coding - Neuropixels dataset, published by the Allen Institute. This dataset uses extracellular electrophysiology probes to record spikes from multiple brain regions during passive visual stimulation. For simplicity, we will focus on the activity of neurons in the primary visual cortex (VISp) during passive exposure to full-field flashes, either black (coded as “-1.0”) or white (coded as “1.0”), on a gray background.

We have three main goals in this notebook:

  1. Introduce the key components of Generalized Linear Models (GLMs),

  2. Demonstrate how to pre-process real experimental data recorded from mice using Pynapple, and

  3. Use NeMoS to fit GLMs to that data and explore model-based insights.

By the end of this notebook, you should have a clearer understanding of the fundamental building blocks of GLMs, as well as how Pynapple and NeMoS can streamline the process of modeling and analyzing neural data, making it a much more accessible and efficient endeavor.

Background on GLMs#

A GLM is a regression model that learns a filter to predict an output as a function of one or more inputs. In the neuroscience context, we can use a particular type of GLM to predict spikes: the linear-nonlinear-Poisson (LNP) model. This type of model takes one or more inputs, sends them through a linear “filter” or transformation, passes the result through a nonlinearity to get the firing rate, and uses that firing rate as the mean of a Poisson distribution to generate spikes. We will go through each of these steps one by one:

LNP model schematic

LNP model schematic. Modified from Pillow et al. [2008] [1a].

  1. Sends the inputs through a linear “filter” or transformation

The inputs (also known as “predictors” or “features”) are first passed through a linear transformation:

    \[ \begin{aligned} L(X) = WX + c \end{aligned} \]

    Where \(X\) is the input (in matrix form), \(W\) is a weight matrix and \(c\) is a vector (the intercept).

    \(L\) scales (makes bigger or smaller) or shifts (up or down) the input. When there is zero input, this is equivalent to changing the baseline rate of the neuron, which is how the intercept should be interpreted. So far, this is the same treatment as in ordinary linear regression.

  2. Passes the transformation through a nonlinearity to get the firing rate.

    The aim of an LNP model is to predict the firing rate of a neuron and use it to generate spikes, but if we were only to keep \(L(X)\) as it is, we would quickly notice that we could obtain negative values for firing rates, which makes no sense! This is what the nonlinearity part of the model handles: by passing the linear transformation through an exponential function, we ensure that the resulting firing rate is always non-negative.

    As such, the firing rate in an LNP model is defined as:

    \[ \begin{aligned} \lambda = \exp(L(X)) \end{aligned} \]

    where \(\lambda\) is a vector containing the firing rates corresponding to each timepoint.

  3. Uses the firing rate as the mean of a Poisson distribution to generate spikes

    In this type of GLM, each spike train is modeled as a sample from a Poisson distribution whose mean is the firing rate — that is, the output of the linear-nonlinear components of the model.

    Spiking is a stochastic process. This means that a given firing rate can lead to many different possible spike trains. Since the model could generate an infinite number of spike train realizations, how do we evaluate how well it explains the single observed spike train? We do this by computing the log-likelihood: it quantifies how likely it is to observe the actual spike train given the predicted firing rate. If \( y(t) \) is the observed spike count and \( \lambda(t) \) is the predicted firing rate at time \( t \), then the log-likelihood at time \( t \) is:

    \[ \log P(y(t) \mid \lambda(t)) = y(t)\log\lambda(t) - \lambda(t) -\log(y(t)!) \]

    However, the term \( -\log(y(t)!) \) does not depend on \( \lambda \), and therefore is constant with respect to the model. As a result, it is usually dropped during optimization, leaving us with the simplified log-likelihood:

    \[ \log P(y(t) \mid \lambda(t)) = y(t) \log \lambda(t) - \lambda(t) \]

    This forms the loss function for LNP models. In practice, we aim to maximize this log-likelihood, which is equivalent to minimizing the negative log-likelihood; that is, we look for the firing rate \(\lambda(t)\) that makes the observed spike train as likely as possible under the model. The short sketch below walks through these three steps on toy data.
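To make these steps concrete, here is a minimal sketch using NumPy only; the predictors, weights, and intercept below are made up for illustration and are not part of the dataset we analyze later:

import numpy as np

rng = np.random.default_rng(0)

X = rng.normal(size=(100, 2))   # toy predictors: 100 time bins, 2 features
W = np.array([0.5, -0.3])       # toy filter weights
c = -1.0                        # intercept (baseline log-rate)

L_X = X @ W + c                 # 1. linear step
lam = np.exp(L_X)               # 2. nonlinearity: firing rate per bin (always non-negative)
y = rng.poisson(lam)            # 3. spike counts sampled from a Poisson with mean lam

# Simplified log-likelihood of the observed counts (dropping the constant -log(y!))
log_likelihood = np.sum(y * np.log(lam) - lam)
print(log_likelihood)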

Environment setup and library imports#

# Install requirements for the databook
try:
    from databook_utils.dandi_utils import dandi_download_open
except ImportError:
    !git clone https://github.com/AllenInstitute/openscope_databook.git
    %cd openscope_databook
    %pip install -e .
# Import libraries
import seaborn as sns
from scipy.stats import zscore
import numpy as np
import matplotlib.pyplot as plt
import pynapple as nap
import nemos as nmo


# Imports for ease of visualization
import warnings
import matplotlib as mpl
warnings.filterwarnings("ignore")
from matplotlib.ticker import MaxNLocator
from scipy.stats import gaussian_kde
from matplotlib.patches import Patch

# Parameters for plotting
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params)

Download data#

# Dataset information
dandiset_id = "000021"
dandi_filepath = "sub-726298249/sub-726298249_ses-754829445.nwb"
download_loc = "."

# Download the data using NeMoS
io = nmo.fetch.download_dandi_data(dandiset_id, dandi_filepath)

Now that we have downloaded the data, it is very simple to open the dataset with Pynapple

data = nap.NWBFile(io.read(), lazy_loading=False)
nwb = data.nwb

print(data)
754829445
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys                                               │ Type        │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units                                              │ TsGroup     │
│ static_gratings_presentations                      │ IntervalSet │
│ spontaneous_presentations                          │ IntervalSet │
│ natural_scenes_presentations                       │ IntervalSet │
│ natural_movie_three_presentations                  │ IntervalSet │
│ natural_movie_one_presentations                    │ IntervalSet │
│ gabors_presentations                               │ IntervalSet │
│ flashes_presentations                              │ IntervalSet │
│ drifting_gratings_presentations                    │ IntervalSet │
│ timestamps                                         │ Tsd         │
│ running_wheel_rotation                             │ Tsd         │
│ running_speed_end_times                            │ Tsd         │
│ running_speed                                      │ Tsd         │
│ raw_gaze_mapping/screen_coordinates_spherical      │ TsdFrame    │
│ raw_gaze_mapping/screen_coordinates                │ TsdFrame    │
│ raw_gaze_mapping/pupil_area                        │ Tsd         │
│ raw_gaze_mapping/eye_area                          │ Tsd         │
│ optogenetic_stimulation                            │ IntervalSet │
│ optotagging                                        │ Tsd         │
│ filtered_gaze_mapping/screen_coordinates_spherical │ TsdFrame    │
│ filtered_gaze_mapping/screen_coordinates           │ TsdFrame    │
│ filtered_gaze_mapping/pupil_area                   │ Tsd         │
│ filtered_gaze_mapping/eye_area                     │ Tsd         │
│ running_wheel_supply_voltage                       │ Tsd         │
│ running_wheel_signal_voltage                       │ Tsd         │
│ raw_running_wheel_rotation                         │ Tsd         │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙

Extraction, preprocessing and stimuli review#

Extracting spiking data#

We have a lot of information in data, but we are interested in the units.

(This might take a while the first time you run it - that’s okay, the dataset is quite big.)

units = data["units"]

# See the columns
print(f"columns : {units.metadata_columns}")

# See the dataset
print(units)
columns : ['rate', 'spread', 'velocity_below', 'silhouette_score', 'firing_rate', 'd_prime', 'nn_hit_rate', 'waveform_duration', 'amplitude', 'cluster_id', 'snr', 'local_index', 'peak_channel_id', 'PT_ratio', 'presence_ratio', 'max_drift', 'cumulative_drift', 'repolarization_slope', 'waveform_halfwidth', 'amplitude_cutoff', 'nn_miss_rate', 'quality', 'velocity_above', 'isolation_distance', 'l_ratio', 'recovery_slope', 'isi_violations']
Index      rate      spread    velocity_below    silhouette_score    firing_rate    d_prime    nn_hit_rate    ...
---------  --------  --------  ----------------  ------------------  -------------  ---------  -------------  -----
951763702  2.38003   30.0      nan               nan                 2.38           4.77       0.98           ...
951763707  0.01147   80.0      nan               0.03                0.01           3.48       0.0            ...
951763711  3.1503    50.0      nan               0.17                3.15           6.08       1.0            ...
951763715  6.53      40.0      nan               0.12                6.53           5.04       0.99           ...
951763720  2.00296   40.0      0.0               0.2                 2.0            6.45       0.99           ...
951763724  8.66233   60.0      -7.55             0.22                8.66           3.1        0.86           ...
951763729  11.13402  30.0      -0.69             0.01                11.13          4.61       0.98           ...
...        ...       ...       ...               ...                 ...            ...        ...            ...
951777559  0.02108   110.0     -2.59             nan                 0.02           2.95       0.0            ...
951777565  0.08143   140.0     0.46              nan                 0.08           4.37       0.29           ...
951777571  0.20088   70.0      0.69              nan                 0.2            6.03       0.82           ...
951777576  0.01085   80.0      -0.96             nan                 0.01           2.28       nan            ...
951777582  0.1457    140.0     -3.49             -0.08               0.15           5.2        0.43           ...
951777593  0.0464    90.0      nan               nan                 0.05           3.83       0.59           ...
951777600  0.0621    60.0      -0.69             nan                 0.06           6.12       0.25           ...

Taking a closer look at the columns, we can see there is a lot of information we do not need. We are solely interested in predicting the spiking activity of VISp neurons. Thus, we will remove all metadata columns except rate, quality (so we can filter out bad-quality neurons) and peak_channel_id (which contains the information we need to identify each unit’s brain area).

def restrict_cols(cols_to_keep, data):
    cols_to_remove = [col for col in data.metadata_columns if col not in cols_to_keep]
    data.drop_info(cols_to_remove)
    
# Choose which columns to remove and remove them
cols_to_keep = ['rate', 'quality','peak_channel_id']
restrict_cols(cols_to_keep,units)

# See the dataset
print(units)
Index      rate      peak_channel_id    quality
---------  --------  -----------------  ---------
951763702  2.38003   850135036          good
951763707  0.01147   850135036          noise
951763711  3.1503    850135038          good
951763715  6.53      850135038          good
951763720  2.00296   850135044          good
951763724  8.66233   850135044          noise
951763729  11.13402  850135044          noise
...        ...       ...                ...
951777559  0.02108   850139336          good
951777565  0.08143   850139526          noise
951777571  0.20088   850139738          good
951777576  0.01085   850139338          good
951777582  0.1457    850139622          good
951777593  0.0464    850139620          good
951777600  0.0621    850139642          good

We still do not have the brain area information here, so we need to extract it from the nwb object using the peak_channel_id metadata. Luckily, Pynapple stores the underlying nwb object as well.

# Units and brain areas those units belong to are in two different places. 
# With the electrodes table, we can map units to their corresponding brain regions.
def get_unit_location(unit_id):
    """Aligns location information from electrodes table with channel id from the units table
    """
    return channel_probes[int(units[unit_id].peak_channel_id)]

channel_probes = {}
electrodes = nwb.electrodes

for i in range(len(electrodes)):
    channel_id = electrodes["id"][i]
    location = electrodes["location"][i]
    channel_probes[channel_id] = location

# Add a new column to include location in our spikes TsGroup
units.brain_area = [channel_probes[int(ch_id)] for ch_id in units.peak_channel_id]

# Remove peak_channel_id because we already got the brain_area information
units.drop_info("peak_channel_id")

print(units)
Index      rate      quality    brain_area
---------  --------  ---------  ------------
951763702  2.38003   good       PoT
951763707  0.01147   noise      PoT
951763711  3.1503    good       PoT
951763715  6.53      good       PoT
951763720  2.00296   good       PoT
951763724  8.66233   noise      PoT
951763729  11.13402  noise      PoT
...        ...       ...        ...
951777559  0.02108   good       LP
951777565  0.08143   noise      DG
951777571  0.20088   good       VISpm
951777576  0.01085   good       LP
951777582  0.1457    good       CA1
951777593  0.0464    good       CA1
951777600  0.0621    good       CA1

Extracting trial structure#

Mice were exposed to a series of stimuli (gabor patches, flashes, natural images, etc.), out of which we are exclusively interested in flashes presentation for this tutorial.

Visual stimuli set

Visual stimuli set. Modified from Allen Institute for Brain Science [2019] [2].

During the flashes presentation trials, mice were exposed to white or black full-field flashes on a gray background, each lasting 250 ms and separated by a 2 second inter-trial interval. In total, they were exposed to 150 flashes (75 black, 75 white).

# Extract flashes as an Interval Set object
flashes = data["flashes_presentations"]

# Remove unnecessary columns, similarly to above
cols_to_keep = ['color']
restrict_cols(cols_to_keep, flashes)

print(flashes)
index    start           end             color
0        1285.600869922  1285.851080039  -1.0
1        1287.602559922  1287.852767539  -1.0
2        1289.604229922  1289.854435039  -1.0
3        1291.605889922  1291.856100039  -1.0
4        1293.607609922  1293.857807539  1.0
5        1295.609249922  1295.859455039  -1.0
6        1297.610959922  1297.861155039  1.0
...      ...             ...             ...
143      1571.840009922  1572.090212539  -1.0
144      1573.841669922  1574.091877539  1.0
145      1575.843359922  1576.093562539  1.0
146      1577.845019922  1578.095227539  -1.0
147      1579.846709922  1580.096915039  1.0
148      1581.848389922  1582.098595039  1.0
149      1583.850039922  1584.100247539  -1.0
shape: (150, 2), time unit: sec.

Create an object for white and a separate object for black flashes

flashes_white = flashes[flashes["color"] == "1.0"]
flashes_black = flashes[flashes["color"] == "-1.0"]


def plot_stimuli():
    n_flashes = 5
    n_seconds = 13
    offset = .5

    start = data["flashes_presentations"]["start"].min() - offset
    end = start + n_seconds

    fig, ax = plt.subplots(figsize = (17, 4))
    [ax.axvspan(s, e, color = "silver", alpha=.4, ec="black") for s, e in zip(flashes_white[:n_flashes].start, flashes_white[:n_flashes].end)]
    [ax.axvspan(s, e, color = "black", alpha=.4, ec="black") for s, e in zip(flashes_black[:n_flashes].start, flashes_black[:n_flashes].end)]

    plt.xlabel("Time (s)")
    plt.ylabel("Absent = 0, Present = 1")
    ax.set_title("Stimuli presentation")
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    plt.xlim(start-.1,end)
    plt.show()

And we can plot the stimuli

plot_stimuli()
Figure: stimulus presentation intervals for the first few white and black flashes.

To analyze how units’ activity evolves around the time of stimulus presentation, we can extend each stimulus interval to include some time before and after the flash. Specifically, we add 500 ms before the flash onset and 500 ms after the flash offset. This gives us a window that captures pre-stimulus baseline activity and any delayed neural responses.

dt = .50 # 500 ms
start = flashes.start - dt # Start 500 ms before stimulus presentation
end = flashes.end + dt # End 500 ms after stimulus presentation

extended_flashes = nap.IntervalSet(start,end, metadata=flashes.metadata) 
print(extended_flashes)
index    start           end             color
0        1285.100869922  1286.351080039  -1.0
1        1287.102559922  1288.352767539  -1.0
2        1289.104229922  1290.354435039  -1.0
3        1291.105889922  1292.356100039  -1.0
4        1293.107609922  1294.357807539  1.0
5        1295.109249922  1296.359455039  -1.0
6        1297.110959922  1298.361155039  1.0
...      ...             ...             ...
143      1571.340009922  1572.590212539  -1.0
144      1573.341669922  1574.591877539  1.0
145      1575.343359922  1576.593562539  1.0
146      1577.345019922  1578.595227539  -1.0
147      1579.346709922  1580.596915039  1.0
148      1581.348389922  1582.598595039  1.0
149      1583.350039922  1584.600247539  -1.0
shape: (150, 2), time unit: sec.

This extended IntervalSet, extended_flashes, will later allow us to restrict units’ activity to the periods surrounding each flash stimulus.

We now create one object for white and another for black extended flashes.

extended_flashes_white = extended_flashes[extended_flashes["color"]=="1.0"]
extended_flashes_black = extended_flashes[extended_flashes["color"]=="-1.0"]

Preprocessing spiking data#

There are multiple reasons for filtering units. Here, we will use four criteria: brain area, unit quality, firing rate, and responsiveness.

  1. Brain area: we are interested in analyzing VISp units for this tutorial

  2. Quality: we will only select “good” quality units

  3. Firing rate: overall, we want units with a firing rate larger than 2 Hz around the presentation of the stimuli

  4. Responsiveness: for the purposes of this tutorial, we will select the most responsive units (top 15%) and only use those for further analysis. We define responsiveness as the normalized difference between the post-stimulus and pre-stimulus average firing rates (see the short example after this list).
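As a quick illustration of this responsiveness index, here is a toy calculation with made-up pre- and post-stimulus rates (these numbers are not from the dataset):

# Toy example of the responsiveness index from criterion 4 (hypothetical rates)
pre_presentation_avg = 4.0    # Hz, average rate in the 250 ms before flash onset
post_presentation_avg = 12.0  # Hz, average rate during the 250 ms flash
responsiveness = abs(
    (post_presentation_avg - pre_presentation_avg)
    / (post_presentation_avg + pre_presentation_avg)
)
print(responsiveness)  # 0.5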

# Filter units according criteria 1 & 2
units = units[
    (units["brain_area"]=="VISp") & 
    (units["quality"]=="good")
] 

# Restrict around stimuli presentation
units = units.restrict(extended_flashes) 

# Filter according to criterion 3
units = units[(units["rate"]>2.0)]

print(units)
Index      rate      quality    brain_area
---------  --------  ---------  ------------
951765440  2.32495   good       VISp
951765454  22.6523   good       VISp
951765460  2.29829   good       VISp
951765467  25.80912  good       VISp
951765485  22.96158  good       VISp
951765547  2.83687   good       VISp
951765552  5.69507   good       VISp
...        ...       ...        ...
951768823  5.40712   good       VISp
951768830  7.58276   good       VISp
951768835  5.60442   good       VISp
951768881  4.01534   good       VISp
951768894  4.2713    good       VISp
951769295  2.23963   good       VISp
951769344  2.91152   good       VISp

Now, to calculate responsiveness, we need to do some preprocessing to align units’ spiking timestamps with the onset of the stimulus repetitions, and then take an average over them. For this, we will use the compute_perievent function, which allows us to re-center time series and timestamps around particular events and compute spike-triggered averages.

# Set perievent window: 250 ms before and 500 ms after the event
window_size = (-.250, .500) 

# Re-center timestamps for white stimuli
# +.50 s because the extended intervals start 500 ms before flash onset
peri_white = nap.compute_perievent(timestamps = units,
                                        tref = nap.Ts(extended_flashes_white.start +.50), 
                                        minmax = window_size
)

# Re-center timestamps for black stimuli
# +.50 s because the extended intervals start 500 ms before flash onset
peri_black = nap.compute_perievent(timestamps = units,
                                        tref = nap.Ts(extended_flashes_black.start +.50), 
                                        minmax = window_size
)

The output of compute_perievent is a dictionary of TsGroup objects, indexed by unit ID.

When we index an element of this dictionary, we retrieve the spike times aligned to stimulus onset for a single unit, across all repetitions of the stimulus. These spike times are centered around the stimulus within the specified window_size. You’ll notice that the ref_times in the perievent output correspond exactly to the start times of the stimulus presentations!

# Let's select one unit
example_id = 951765485 
print(f"Number of trials: {len(peri_black[example_id])}\n ")

# And print its per-trial rates
print(f"TsGroup of centered activity for unit {example_id}: \n {peri_black[example_id]}\n")

# Start times of black flashes presentation
print(f"black flashes start times: \n {flashes_black.starts}")
Number of trials: 75
 
TsGroup of centered activity for unit 951765485: 
 Index    rate      ref_times
-------  --------  -----------
0        18.66667  1285.6
1        18.66667  1287.6
2        12.0      1289.6
3        9.33333   1291.61
4        32.0      1295.61
5        20.0      1303.62
6        18.66667  1307.62
...      ...       ...
68       13.33333  1561.83
69       6.66667   1563.83
70       8.0       1565.83
71       12.0      1569.84
72       20.0      1571.84
73       12.0      1577.85
74       10.66667  1583.85

black flashes start times: 
 Time (s)
1285.600869922
1287.602559922
1289.604229922
1291.605889922
1295.609249922
1303.615919922
1307.619279922
...
1561.831629922
1563.833279922
1565.834999922
1569.838349922
1571.840009922
1577.845019922
1583.850039922
shape: 75

Let’s inspect our TsGroup objects of centered spikes a bit further. If we grab the FIRST element of peri_black[example_id], we get the spike times centered around the FIRST presentation of the stimulus.

print(peri_black[example_id][0])
Time (s)
-0.128226254
-0.12362626
0.070806812
0.132140063
0.223106608
0.252739902
0.270706544
0.305906497
0.308506494
0.330173131
0.401073036
0.434739658
0.455372963
0.471839608
shape: 14

Negative spike times are expected here because the spike times in peri_white and peri_black are aligned relative to stimulus onset. A negative time means the spike occurred before the stimulus was presented, while a positive time indicates the spike occurred after stimulus onset. This alignment allows us to analyze how neuronal activity changes around the time of stimulus presentation.

We can also visualize these aligned spike times to better understand the timing and rate of neural responses relative to the stimulus. This type of plot is known as a Peristimulus Time Histogram (PSTH), and it shows how spiking activity is distributed around the stimulus onset. Let’s generate a PSTH for the first 9 units to explore their response patterns.


def plot_raster_psth(peri_color, units, color_flashes, bin_size, n_units = 9, smoothing=0.015):
    """
    Plot perievent time histograms (PSTHs) and raster plots for multiple units.

    Parameters
    ----------
    peri_color : dict
        Dictionary mapping unit IDs to perievent TsGroup objects (spike times aligned to stimulus onset).
    units : TsGroup
        Group of neural units; only its unit IDs are used here.
    color_flashes : str
        A label indicating the flash color condition ('black' or 'white'), used for visual styling.
    bin_size : float
        Size of the bin used for spike count computation (in seconds).
    n_units : int
        Number of units to plot (one column per unit).
    smoothing : float
        Standard deviation (in seconds) for Gaussian smoothing of the PSTH traces.
    """

    # Layout setup: 9 columns (units), 2 rows (split vertically into PSTH and raster plot)
    n_cols = n_units
    n_rows = 2
    fig, ax = plt.subplots(n_rows, n_cols)
    fig.set_figheight(4)
    fig.set_figwidth(17)
    fig.tight_layout()

    # Use tab10 color palette for plotting different units
    colors = plt.cm.tab10.colors[:n_cols]

    # Extract unit names for iteration
    units_list = list(units.keys())

    start = 0
    end = int(n_rows / 2)  # Plot as many units as half the number of rows 
                            # each unit occupies 2 rows (one for psth and other for raster)

    for col in range(n_cols):
        for i, unit in enumerate(units_list[start:end]):
            u = peri_color[unit]
            line_color = colors[col]

            # Plot PSTH (smoothed firing rate)
            ax[2*i, col].plot(
                (np.mean(u.count(bin_size), 1) / bin_size).smooth(std=smoothing),
                linewidth=2,
                color=line_color
            )
            ax[2*i, col].axvline(0.0)  # Stimulus onset line
            span_color = "black" if color_flashes == "black" else "silver"
            ax[2*i, col].axvspan(0, 0.250, color=span_color, alpha=0.3, ec="black")  # Stim duration
            ax[2*i, col].set_xlim(-0.25, 0.50)
            ax[2*i, col].set_title(f'{unit}')

            # Plot raster
            ax[2*i+1, col].plot(u.to_tsd(), "|", markersize=1, color=line_color, mew=2)
            ax[2*i+1, col].axvline(0.0)
            ax[2*i+1, col].axvspan(0, 0.250, color=span_color, alpha=0.3, ec="black")
            ax[2*i+1, col].set_ylim(0, 75)
            ax[2*i+1, col].set_xlim(-0.25, 0.50)

            # Shift window for next units
            start += 1
            end += 1

    # Y-axis and title annotations
    ax[0, 0].set_ylabel("Rate (Hz)")
    ax[1, 0].set_ylabel("Trial")
    if n_rows > 2:
        ax[2, 0].set_ylabel("Rate (Hz)")
        ax[3, 0].set_ylabel("Trial")
    fig.text(0.5, 0.00, 'Time from stim(s)', ha='center')
    fig.text(0.5, 1.00, f'PSTH & Spike Raster Plot - {color_flashes} flashes', ha='center')
bin_size = 0.005 # Bin size

# Plot PSTH and spike raster plots
plot_raster_psth(peri_white, units, "white", bin_size)
plot_raster_psth(peri_black, units, "black", bin_size)
Figures: PSTH and spike raster plots for the first 9 units, white and black flashes.

In the plots above, we can see that some units (e.g., 951765552 in pink, or 951765557 in gray) are clearly more responsive than others (e.g., 951765454 in orange), which are apparently not modulated by the flashes. Thus, it makes sense to take a subset of the neurons, the most responsive ones, and model those.

We will now calculate responsiveness for each neuron as the normalized difference between average firing rate before and after stimulus presentation, and select the most responsive ones for further analyses.


def get_responsiveness(perievents, bin_size):
    """Calculate responsiveness for each neuron. This is
    computed as:

    post_presentation_avg  : 
        Average firing rate during presentation (250 ms) of stimulus across
        all instances of stimulus. 

    pre_presentation_avg :
        Average firing rate prior (250 ms) to the presentation of stimulus
        across all instances prior of stimulus. 

    responsiveness : 
        abs((post_presentation_avg - pre_presentation_avg) / (post_presentation_avg + pre_presentation_avg))

    Larger values indicate higher responsiveness to the stimuli.
        
    Parameters
    ----------
    perievents : TsGroup
        Contains perievent information of a subset of neurons
    bin_size : float
        Bin size for calculating spike counts

    Returns
    ----------   
    resp_array : np.array
        Array of responsiveness information.
    resp_dict : dict
        Dictionary of responsiveness information, indexed by neuron ID;
        contains responsiveness, pre_presentation_avg and post_presentation_avg for each unit

    """
    resp_dict = {}
    resp_array = np.zeros(len(perievents.keys()), dtype=float)

    for index,unit in enumerate(perievents.keys()):
        # Count the number of timestamps in each bin_size bin.
        peri_counts = perievents[unit].count(bin_size)

        # Get the firing rate
        peri_rate = peri_counts/bin_size

        # Trial-averaged firing rate, restricted to the
        # 250 ms before stimulus presentation
        pre_presentation = np.mean(peri_rate,1).restrict(nap.IntervalSet([-.25,0]))

        # Trial-averaged firing rate, restricted to the
        # 250 ms after stimulus onset
        post_presentation = np.mean(peri_rate,1).restrict(nap.IntervalSet([0,.25]))

        pre_presentation_avg = np.mean(pre_presentation)
        post_presentation_avg = np.mean(post_presentation)
        responsiveness = abs((post_presentation_avg - pre_presentation_avg) / (post_presentation_avg + pre_presentation_avg))

        resp_dict[unit] = {
            "responsiveness": responsiveness,
            "pre_presentation_avg": pre_presentation_avg,
            "post_presentation_avg": post_presentation_avg,
        }
        resp_array[index] = responsiveness

    return resp_array, resp_dict
# Calculate responsiveness
responsiveness_white,_ = get_responsiveness(peri_white, bin_size)
responsiveness_black,_ = get_responsiveness(peri_black, bin_size)

# Add responsiveness as metadata for units
units.set_info(responsiveness_white=responsiveness_white)
units.set_info(responsiveness_black=responsiveness_black)

# See metadata
print(units)
Index      rate      quality    brain_area    responsiveness_white    responsiveness_black
---------  --------  ---------  ------------  ----------------------  ----------------------
951765440  2.32495   good       VISp          0.14                    0.28
951765454  22.6523   good       VISp          0.02                    0.02
951765460  2.29829   good       VISp          0.18                    0.0
951765467  25.80912  good       VISp          0.17                    0.22
951765485  22.96158  good       VISp          0.15                    0.39
951765547  2.83687   good       VISp          0.39                    0.0
951765552  5.69507   good       VISp          0.75                    0.81
...        ...       ...        ...           ...                     ...
951768823  5.40712   good       VISp          0.03                    0.4
951768830  7.58276   good       VISp          0.15                    0.47
951768835  5.60442   good       VISp          0.07                    0.27
951768881  4.01534   good       VISp          0.5                     0.27
951768894  4.2713    good       VISp          0.09                    0.46
951769295  2.23963   good       VISp          0.2                     0.45
951769344  2.91152   good       VISp          0.68                    0.92

Now we can keep the top 15% most responsive units for the remaining analyses.

# Get threshold for top 15% most responsive
thresh_black = np.percentile(units["responsiveness_black"], 85)
thresh_white = np.percentile(units["responsiveness_white"], 85)

# Only keep units that are within the 15% most responsive for either black or white
units = units[(units["responsiveness_black"] > thresh_black) | (units["responsiveness_white"] > thresh_white)]
print(units)
print(f"\nRemaining units: {len(units)}")
Index      rate     quality    brain_area    responsiveness_white    responsiveness_black
---------  -------  ---------  ------------  ----------------------  ----------------------
951765552  5.69507  good       VISp          0.75                    0.81
951765732  4.31929  good       VISp          0.9                     1.0
951768154  2.82087  good       VISp          0.64                    0.89
951768278  3.43944  good       VISp          0.5                     0.66
951768285  4.06867  good       VISp          0.91                    0.3
951768291  2.77821  good       VISp          0.99                    0.5
951768307  2.32495  good       VISp          0.88                    0.5
...        ...      ...        ...           ...                     ...
951768586  2.82087  good       VISp          0.59                    0.52
951768621  3.41278  good       VISp          0.35                    0.59
951768632  5.93503  good       VISp          0.16                    0.79
951768749  2.04767  good       VISp          0.94                    0.91
951768754  2.76755  good       VISp          0.78                    1.0
951768815  2.23963  good       VISp          0.59                    0.76
951769344  2.91152  good       VISp          0.68                    0.92

Remaining units: 19

Reviewing stimuli and spiking data#

Now that we have selected the units we will use for our analyses, we can see how these look alongside the stimuli in a raster plot:


def raster_plot(data, units, n_neurons=len(units), n_flashes=5, n_seconds=13, offset=0.5):
    """
    Plot a raster of spiking activity for a subset of neurons during the initial stimulus presentations.

    This function visualizes spike times as a raster plot for a selected number of neurons
    over a specified time window. The spikes are aligned with stimulus presentations, and 
    flashes of different types (white or black) are shown as shaded regions in the background.

    Parameters
    ----------
    data : nap.TsGroup
        A `TsGroup` object which contains all info of all units' spikes without any filtering
    units : nap.TsGroup
        A `TsGroup` containing only the filtered neurons kept after preprocessing; only its indices are used to select units
    n_neurons : int, optional
        Number of neurons to include in the plot. Defaults to the total number of units.
    n_flashes : int, optional
        Number of white and black flash presentations to show. Defaults to 5.
    n_seconds : float, optional
        Total duration (in seconds) of the time window to display. Defaults to 13 seconds.
    offset : float, optional
        Time (in seconds) to start plotting before the first flash. Useful to visualize pre-stimulus activity.
        Defaults to 0.5 seconds.

    Returns
    -------
    None
        Displays a raster plot using `matplotlib`.

    Notes
    -----
    - The flashes are overlaid using `axvspan`, with white flashes in light gray and black flashes in black.
    - Each spike is drawn as a vertical line ('|') at its timestamp.
    - Spike times are converted to a single `Tsd` object with unique identifiers for each unit to facilitate plotting.
    """
    start = data["flashes_presentations"]["start"].min() - offset
    end = start + n_seconds

    # Select full spiking activity from units without restriction
    units_ = data["units"] 

    # Select a subset from those - the filtered ones
    units = units_[units.index]

    # Restrict the spike trains to the selected time window
    units = units.restrict(nap.IntervalSet(start, end))

    # Convert spike trains to a single Tsd object and label each neuron for plotting
    neurons_to_plot = units.to_tsd([i + 1 for i in range(n_neurons)])

    fig, ax = plt.subplots(figsize=(17, 4))
    ax.plot(neurons_to_plot, "|", markersize=2, mew=2)

    # Overlay white flashes
    for s, e in zip(flashes_white[:n_flashes].start, flashes_white[:n_flashes].end):
        ax.axvspan(s - .50,  e + .50, color="red", alpha=0.03, ec="black")
        ax.axvspan(s, e, color="silver", alpha=0.4, ec="black")
        ax.axvline(s - .50, color = "red")
        ax.axvline(e + .50, color = "red")

    # Overlay black flashes
    for s, e in zip(flashes_black[:n_flashes].start, flashes_black[:n_flashes].end):
        ax.axvspan(s - .50,  e + .50, color="red", alpha=0.03, ec="black")
        ax.axvspan(s, e, color="black", alpha=0.4, ec="black")
        ax.axvline(s - .50, color = "red")
        ax.axvline(e + .50, color = "red")

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Unit")
    ax.set_title("Primary Visual Cortex (VISp) units spikes and stimuli")
    ax.set_xlim(start, end)

    plt.show()
raster_plot(data, units)
Figure: raster plot of VISp unit spikes alongside black and white flash presentations.

Above we can see a spike raster plot from the selected VISp units displayed alongside the black and white flashes presented to the mice. Each row represents spike times from a different unit. Black and silver bars indicate the presentation of black and white flashes, respectively. The bright red vertical lines mark the windows of interest, 500 ms before stimulus onset and 500 ms after stimulus offset, which are shaded in light red. Spikes occurring within these windows (the shaded red areas) are the ones that will be used for model fitting.

We can also look at each unit’s activity centered around the flash presentations with a PSTH, as we did before.

# Get the perievent for a subset of the units (most responsive ones)
peri_white = {k: peri_white[k] for k in units.index if k in peri_white}
peri_black = {k: peri_black[k] for k in units.index if k in peri_black}

params_obs = [peri_white, 
              peri_black]

# Plot PSTH and spike raster plots
plot_raster_psth(peri_white, units, "white", bin_size)
plot_raster_psth(peri_black, units, "black", bin_size)
Figures: PSTH and spike raster plots for the most responsive units, white and black flashes.

Splitting the dataset in train and test#

We could train the model on the entire dataset. However, if we did so, we wouldn’t have a way to assess whether the model is truly capable of making accurate predictions or whether it is simply overfitting to the data. The simplest way around this is to reserve part of the data for testing.

Here, we will split the data in two: roughly 70% for training and 30% for testing. However, we can’t simply grab the first 70% of the time series - what if we bias our sample because some neurons respond only towards the beginning or the end of the recording? Instead, we will take one out of every three flashes for the testing set; the rest will go to the training set.

# We take one out of every three flashes for testing (~33% of all flashes)
flashes_test_white = extended_flashes_white[::3]
flashes_test_black = extended_flashes_black[::3]

Pynapple has a convenient function for this: set_diff. It returns all of the intervals that are not contained in the interval set passed as an argument.

# The remaining is separated for training
flashes_train_white = extended_flashes_white.set_diff(flashes_test_white)
flashes_train_black = extended_flashes_black.set_diff(flashes_test_black)

Combine the black and white intervals into single test and train sets using union.

# Merge both stimuli types in a single interval set
flashes_test = flashes_test_white.union(flashes_test_black)
flashes_train = flashes_train_white.union(flashes_train_black)

Now that we have our intervals, we can use restrict to get the test and train spike counts for our units.

# General spike counts
units_counts = units.count(bin_size, ep = extended_flashes)

# Restrict counts to test and train
units_counts_test = units_counts.restrict(flashes_test)
units_counts_train = units_counts.restrict(flashes_train)

Fitting a GLM#

Preparing the data for NeMoS#

Now that we have a good understanding of our data, and that we have split our dataset in the corresponding test and train subsets, we are almost ready to run our model. However, before we can construct it, we need to get our data in the right format.

When fitting a single neuron, NeMoS requires that the predictors and spike counts it operates on have the following properties:

  • predictors and spike counts must have the same number of time points.

  • predictors must be two-dimensional, with shape (n_time_bins, n_features). So far, we have two features in this tutorial: white and black flashes.

  • spike counts must be one-dimensional, with shape (n_time_bins,).

  • predictors and spike counts must be jax.numpy arrays, numpy arrays, Tsd or TsdFrame.

When fitting multiple neurons, spike counts must be two-dimensional: (n_time_bins, n_neurons). In that case, spike counts can also be provided as a TsGroup.

First, we can make sure that our predictors and our spike counts have the same number of time bins.

# Create a TsdFrame filled by zeros, for the size of units_counts
predictors = nap.TsdFrame(
    t=units_counts.t,
    d=np.zeros((len(units_counts), 2)), 
    columns = ['white', 'black']
)

At the moment, the flashes are stored in an IntervalSet; we need to turn them into time series of stimuli, separated into black and white (because we are interested in how neurons’ responses are modulated by each stimulus type separately).

# Check whether there is a flash within a given bin of spikes
# If there is not, put a nan in that index
idx_white = flashes_white.in_interval(units_counts)
idx_black = flashes_black.in_interval(units_counts)

# Replace everything that is not nan with 1 in the corresponding column
predictors.d[~np.isnan(idx_white), 0] = 1
predictors.d[~np.isnan(idx_black), 1] = 1

print(predictors)
Time (s)        white    black
--------------  -------  -------
1285.103369922  0.0      0.0
1285.108369922  0.0      0.0
1285.113369922  0.0      0.0
1285.118369922  0.0      0.0
1285.123369922  0.0      0.0
1285.128369922  0.0      0.0
1285.133369922  0.0      0.0
...             ...      ...
1584.567539922  0.0      0.0
1584.572539922  0.0      0.0
1584.577539922  0.0      0.0
1584.582539922  0.0      0.0
1584.587539922  0.0      0.0
1584.592539922  0.0      0.0
1584.597539922  0.0      0.0
dtype: float64, shape: (37500, 2)

predictors and units_counts match in the first dimension because they have the same number of timepoints, as intended. Meanwhile, in the second dimension, predictors has 2 columns because we have black and white flashes, and units_counts has 19 columns because we selected 19 units for this tutorial.

print(f"predictors shape: {predictors.shape}")
print(f"\ncount shape: {units_counts.shape}")

Just to make sure that we got the right output, let’s plot our new predictors TsdFrame as lines alongside our first plot.


def stimuli_plot(predictors, n_flashes = 5, n_seconds = 13, offset = .5):
    n_flashes = 5
    n_seconds = 13
    offset = .5

    # Start a little bit earlier than the first flash presentation
    start = data["flashes_presentations"]["start"].min() - offset 
    end = start + n_seconds

    fig, ax = plt.subplots(figsize = (17, 4))

    # Different coloured flashes
    [ax.axvspan(s, e, color = "silver", alpha=.4, ec="black") for s, e in zip(flashes_white[:n_flashes].start, flashes_white[:n_flashes].end)]
    [ax.axvspan(s, e, color = "black", alpha=.4, ec="black") for s, e in zip(flashes_black[:n_flashes].start, flashes_black[:n_flashes].end)]
    plt.plot(predictors["white"], "o-", color= "silver")
    plt.plot(predictors["black"], "o-", color= "black")
    
    plt.xlabel("Time (s)")
    plt.ylabel("Absent = 0, Present = 1")
    ax.set_title("Presented Stimuli")
    
    plt.xlim(start,end) 

    # Only use integer values for ticks
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))

    plt.show()
stimuli_plot(predictors)
Figure: predictor time series overlaid on the flash presentation intervals.

They match perfectly!

As a last preprocessing step, let’s just split predictors in train and test.

predictors_test = predictors.restrict(flashes_test)
predictors_train = predictors.restrict(flashes_train)

Constructing the design matrix using basis functions#

Right now, our predictors consist of the black and white flash values at each time point. However, this setup assumes that the neuron’s spiking behavior is driven only by the instantaneous flash presentation. In reality, neurons integrate information over time — so why not modify our predictors to reflect that?

We can achieve this by including variables that represent the history of exposure to the flashes. For this, we must decide the duration of time that we think is relevant: does the exposure to flashes 10 ms ago matter? What about 100 ms ago? 1 s? We should use a priori knowledge of our system to determine an initial value.

For this tutorial, we will use the whole duration of the stimuli as relevant history. That is, we will model each unit’s response to 250 ms full-field flashes by capturing how stimulus history over that duration influences spiking. We therefore define a 250 ms stimulus window, aligned with the flash onset, which spans the entire stimulus duration. This window enables the GLM to learn how the neuron’s firing rate evolves throughout the flash. Using a shorter window could omit delayed effects, while a longer window may incorporate unrelated post-stimulus activity.

To construct our stimulus history predictor, we could generate time-lagged copies of the stimulus input (in the form of a Hankel Matrix). Specifically, the value of the first predictor at time \( t \) would correspond to the stimulus at time \( t \), while the second predictor would capture the stimulus at time \( t - 1 \) , and so on, up to a maximum lag corresponding to the length of the stimulus integration window (250 ms in our case).

How do you build a Hankel matrix?

Every row is a shifted copy of the row above!

Hankel matrix schematic

Construction of Hankel Matrix. Modified from Pillow [2018] [3].

For an example on how to build a design matrix using the raw history as a predictor, see this GLM notebook or this NeMoS Fit GLMs for neural coupling tutorial.
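As a rough illustration (not part of the preprocessing we will actually use), here is how such a time-lagged design matrix could be built with NumPy; the array names below are made up for this sketch:

# Sketch: build a stimulus-history (Hankel-style) design matrix from a 1D stimulus
def build_history_matrix(stim, n_lags):
    """Row t contains [stim[t], stim[t-1], ..., stim[t-n_lags+1]], zero-padded at the start."""
    padded = np.concatenate([np.zeros(n_lags - 1), stim])
    return np.column_stack(
        [padded[n_lags - 1 - lag : len(padded) - lag] for lag in range(n_lags)]
    )

toy_stim = np.array([0., 0., 1., 1., 1., 0., 0.])   # hypothetical flash trace
X_history = build_history_matrix(toy_stim, n_lags=3)
print(X_history.shape)  # (7, 3): one column per time lag

# With a basis matrix B of shape (n_lags, n_basis_funcs), the reduced design
# matrix described in the next paragraph would simply be X_history @ B.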

However, modeling each time lag with an independent parameter leads to a high-dimensional filter that is prone to overfitting (given that we are using a bin size of 0.005 s, a 250 ms window would mean 50 lags = 50 parameters per flash color!). A better idea is to do some dimensionality reduction on these predictors by parametrizing them with basis functions. This will allow us to capture interesting non-linear effects with a relatively low-dimensional parametrization that preserves convexity.

The way you perform this dimensionality reduction should be carefully considered. Choosing the appropriate type of basis functions, deciding how many to include, and setting their parameters all depend on the specifics of your problem. It’s essential to reflect on which aspects of the stimulus history are worth retaining and how best to represent them. For instance, do you expect sharp transient responses right after stimulus onset? Or are you more interested in slower, sustained effects?

Note

NeMoS has a whole library of basis objects available at nmo.basis. You can explore the different options available there and think carefully about which one best matches your problem and assumptions.

Certain aspects of our units’ response dynamics suggest that applying multiple basis transformations to the stimulus can help better capture the structure that drives the units’ activity. In particular, some neurons exhibit a strong response immediately after flash onset, while others (or the same unit) show an additional peak in activity at flash offset.


# Plot perievent for a single unit as an example
def plot_peri_single_unit(peri_white, peri_black, unit_id):

    fig, ax = plt.subplots(1,2,figsize=(17, 4), sharey=True)
    ### white
    # observed
    peri_u = peri_white[unit_id]
    peri_u_count = peri_u.count(bin_size)

    peri_u_count_conv_mean = np.mean(peri_u_count, 1).smooth(std=0.015)
    peri_u_rate_conv = peri_u_count_conv_mean / bin_size

    ax[0].plot(peri_u_rate_conv, linewidth=2, color="black")
    ax[0].axvspan(0, 0.250, color="silver", alpha=0.3, ec="black")
    ax[0].set_xlim(-.25, .5)

    ax[0].set_title("White flashes")

    #### black
    # observed
    peri_u = peri_black[unit_id]
    peri_u_count = peri_u.count(bin_size)

    peri_u_count_conv_mean = np.mean(peri_u_count, 1).smooth(std=0.015)
    peri_u_rate_conv = peri_u_count_conv_mean / bin_size
    ax[1].plot(peri_u_rate_conv, linewidth=2, color="black", label = "observed")
    ax[1].axvspan(0, 0.250, color="black", alpha=0.3, ec="black")
    ax[1].set_xlim(-.25, .5)

    ax[1].set_title("Black flashes")
    ax[0].set_ylabel("Rate (Hz)")

    fig.text(0.5, -.05, 'Time from stim(s)', ha='center')
    fig.text(0.5, .95, f'PSTH unit {unit_id}', ha='center')
    plt.show()

# Choose an example unit
unit_id = 951768318

plot_peri_single_unit(peri_white, peri_black, unit_id)
Figure: PSTH for the example unit, white and black flashes.

These distinct temporal features motivate us to model the onset and offset components separately, using tailored basis functions for each. Considering that, we will apply three distinct transformations to our predictors, each targeting a specific portion or feature of the stimulus:

  1. Flash onset (beginning): We will convolve the early phase of the flash presentation with a basis function. This allows for fine temporal resolution immediately after stimulus onset, where rapid neural responses are often expected.

  2. Flash offset (end): We will convolve the later phase of the flash (around its end) with a different basis function. This emphasizes activity changes around stimulus termination.

  3. Full flash duration (smoothing): We will convolve the entire flash period with a third basis function, serving as a smoother to capture more sustained or slowly varying effects across the full stimulus window.

First, we will transform our predictors to get flash onset and offset time series. We do this for the training set

white_train_on = nap.Tsd(
    t=predictors_train.t, 
    d=np.hstack((0,np.diff(predictors_train["white"])==1)),
    time_support = units_counts_train.time_support
)

white_train_off = nap.Tsd(
    t=predictors_train.t, 
    d=np.hstack((0,np.diff(predictors_train["white"])==-1)),
    time_support = units_counts_train.time_support
)

# Black train
black_train_on = nap.Tsd(
    t=predictors_train.t, 
    d=np.hstack((0,np.diff(predictors_train["black"])==1)),
    time_support = units_counts_train.time_support
)
black_train_off = nap.Tsd(
    t=predictors_train.t, 
    d=np.hstack((0,np.diff(predictors_train["black"])==-1)),
    time_support = units_counts_train.time_support
)

and test set.

# White test
white_test_on = nap.Tsd(
    t=predictors_test.t, 
    d=np.hstack((0,np.diff(predictors_test["white"])==1)),
    time_support=units_counts_test.time_support
    
)
white_test_off = nap.Tsd(
    t=predictors_test.t, 
    d=np.hstack((0,np.diff(predictors_test["white"])==-1)),
    time_support=units_counts_test.time_support
)

# Black test
black_test_on = nap.Tsd(
    t=predictors_test.t, 
    d=np.hstack((0,np.diff(predictors_test["black"])==1)),
    time_support=units_counts_test.time_support
)
black_test_off = nap.Tsd(
    t=predictors_test.t, 
    d=np.hstack((0,np.diff(predictors_test["black"])==-1)),
    time_support=units_counts_test.time_support
)

Now that we have our predictors, it’s time to choose which basis functions are most suitable for our purposes.

For history-type inputs like the ones we’re discussing, the raised cosine log-stretched basis first described in Pillow et al. [2005] [4] is a good fit. This basis set has the nice property that its precision drops linearly with distance from the event, which makes sense for many history-related inputs in neuroscience: whether an input happened 1 or 5 ms ago matters a lot, whereas whether an input happened 51 or 55 ms ago is less important. We will apply this basis to the beginning of the flash.


# Duration of stimuli
stimulus_history_duration = 0.25 

# Window length in bin size units
window_len = int(stimulus_history_duration / bin_size)

basis_example = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
)
sample_points, basis_values = basis_example.evaluate_on_grid(100)

fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(sample_points, basis_values)
ax.set_title("Raised cosine log-stretched basis")
ax.set_ylabel("Amplitude")
ax.set_xlabel("Time Lag")
plt.show()
Figure: raised cosine log-stretched basis functions.

Another useful transformation we can apply to our predictors is the raised cosine linearly spaced basis, which covers the domain uniformly. This basis is interesting because it is symmetric. We will apply it to the end of the flash.


# Duration of stimuli
stimulus_history_duration = 0.25 

# Window length in bin size units
window_len = int(stimulus_history_duration / bin_size)

basis_example = nmo.basis.RaisedCosineLinearConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
)
sample_points, basis_values = basis_example.evaluate_on_grid(100)

fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(sample_points, basis_values)
ax.set_title("Raised cosine linear basis")
ax.set_ylabel("Amplitude")
ax.set_xlabel("Time Lag")
plt.show()
Figure: raised cosine linear basis functions.

To see how these look convolved with the stimuli, let’s create our basis objects!

When we instantiate a basis object, the only things we must specify are the number of functions we want and the mode of operation of the basis:

  • Number of functions: with more basis functions, we’ll be able to represent the effect of the corresponding input with higher precision, at the cost of additional parameters.

  • Mode of operation: either Conv for the convolutional or Eval for the evaluation form of the basis. This is determined by the type of feature we aim to represent. This is not a constructor parameter; instead, the chosen basis class will include Conv or Eval in its name.

When should I use the convolutional or evaluation form of the basis?

  • Evaluation bases transform the input through the non-linear function defined by the basis. This can be used to represent features such as spatial location and head direction.

  • Convolution bases apply a convolution of the input data with the bank of filters defined by the basis, and are particularly useful when analyzing data with inherent temporal dependencies, such as spike history or the history of flash exposure in this example. In convolution mode, we must additionally specify the window_size (the length of the filters in bins). A small sketch contrasting the two modes follows this list.
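For instance, a rough sketch contrasting the two modes (the Eval class name below follows the Conv/Eval naming convention described above, so check nmo.basis for the exact options available; the head-direction input is hypothetical and not part of this dataset):

# Convolutional form: convolve a time series (e.g. flash onsets) with the filter bank
conv_basis = nmo.basis.RaisedCosineLogConv(n_basis_funcs=5, window_size=window_len)
onset_features = conv_basis.compute_features(white_train_on)   # shape (n_time_bins, 5)

# Evaluation form: evaluate the basis at each sample value (e.g. head direction, position)
eval_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=5)
# hd_features = eval_basis.compute_features(head_direction)    # head_direction is hypothetical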

Since we are using Convolution bases, we need to specify the window_size. In this tutorial, we will use 250 ms.

# Duration of stimuli
stimulus_history_duration = 0.250

# Window length in bin size units
window_len = int(stimulus_history_duration / bin_size)

Now we can initialize our basis objects. As mentioned, for each flash type (white and black), we will create three separate basis objects: one for the onset of the flash, one for the offset, and one that spans the entire duration of the flash. In this tutorial, each basis object will have 5 basis functions.

# Initialize basis objects
# White
# Raised Cosine Log Stretched basis for "On"
basis_white_on = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
    label = "white_on"
)

# Raised Cosine Linear basis for "Off"
basis_white_off = nmo.basis.RaisedCosineLinearConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
    label = "white_off",
    conv_kwargs = {"predictor_causality":"acausal"}
)

# Raised Cosine Log Stretched basis for smoothing throughout the stimulus presentation
basis_white_stim = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
    label = "white_stim"
)

# Black
# Raised Cosine Log Stretched basis for "On"
basis_black_on = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
    label = "black_on"
)

# Raised Cosine Linear basis for "Off"
basis_black_off = nmo.basis.RaisedCosineLinearConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
    label = "black_off",
    conv_kwargs = {"predictor_causality":"acausal"}
)

# Raised Cosine Log Stretched basis for smoothing throughout the stimulus presentation
basis_black_stim = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs = 5, 
    window_size = window_len, 
    label = "black_stim"
)

Using the compute_features function, NeMoS convolves our input features (predictors) with the basis object to compress them. Let’s see how that looks!
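
Before the full plotting cell, here is a minimal sketch for a single basis (assuming the white_train_on onset predictor created earlier is in scope; the NaN padding at incomplete windows is an assumption about the convolution mode):

# compute_features returns a TsdFrame with one column per basis function (here 5),
# aligned to the input's time axis; samples where the convolution window is
# incomplete are typically NaN-padded
features_white_on = basis_white_on.compute_features(white_train_on)
print(features_white_on.shape)  # (n_time_bins, 5)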


def plot_basis_feature_summary(
    basis_object,
    predictors,
    interval,
    label,
    window_len,
    title
):
    """
    Plot summary of basis functions and convolved features for a given input.

    Parameters
    ----------
    basis_object : object
        A basis object implementing `compute_features()` and `evaluate_on_grid()`.
    predictors : Tsd or TsdFrame
        Time series of stimulus predictors to convolve.
    interval : nap.IntervalSet
        Interval to restrict the predictors and features to.
    label : str
        Label for the raw stimulus trace (e.g. "Flash").
    window_len : int
        Length of the filter window in bins, used to scale the basis time axis.
    title : str
        Title to display above the figure.
    """

    # Compute features
    features = basis_object.compute_features(predictors).restrict(interval)

    # Restrict raw stimulus as well
    restricted_input = predictors.restrict(interval)

    # Create time axis for basis
    time, basis = basis_object.evaluate_on_grid(basis_object.window_size)
    time *= window_len

    # Initialize plot
    fig, axes = plt.subplots(2, 3, sharey="row", figsize=(12, 2.5), tight_layout=True)

    # Plot raw predictors
    for ax in axes[1, :]:
        ax.plot(restricted_input, "k--", label="true")
        ax.set_xlabel("Time (s)")

    # Plot first basis function and its feature
    axes[0, 0].plot(time, basis, alpha=0.1)
    axes[0, 0].plot(time, basis[:, 0], "C0", alpha=1)
    axes[0, 0].set_xticks([])
    axes[0, 0].set_ylabel("Amp.")
    axes[0, 0].set_yticks([])
    axes[0, 0].set_title("Feature 1")

    axes[1, 0].plot(features[:, 0], label="conv.")
    axes[1, 0].set_ylabel(label)
    axes[1, 0].set_yticks([])

    # Plot last basis function and its feature
    last_idx = basis.shape[1] - 1
    color = f"C{last_idx}"
    axes[0, 1].plot(time, basis[:, last_idx], color, alpha=1)
    axes[0, 1].plot(time, basis, alpha=0.1)
    axes[0, 1].set_xticks([])
    axes[0, 1].set_title(f"Feature {basis.shape[1]}")

    axes[1, 1].plot(features[:, -1], color)

    # Plot all basis functions and features
    axes[0, 2].plot(time, basis)
    axes[0, 2].set_xticks([])
    axes[0, 2].set_title("All features")
    axes[1, 2].plot(features)
    
    fig.text(0.5, 1.00, title, ha='center')
    plt.show()
interval = flashes_train_white[0]

plot_basis_feature_summary(
    basis_white_on,
    white_train_on,
    interval,
    label="Flash",
    window_len=window_len,
    title="Flashes On - Raised Cosine Log-Stretched Conv"
)

On the top row, we can see the basis functions, the same as in the "Raised cosine log-stretched basis" plot above. On the bottom row, we show the beginning of one flash presentation as a dashed line, together with the corresponding features over a small window of time. These features are the result of convolving the basis functions on the top row with the black dashed line shown below. The basis functions get progressively wider and more delayed from the flash onset, so we can think of the features as weighted averages that get progressively later and smoother.

In the leftmost plot, we can see that the first feature almost perfectly tracks the input. Looking at the basis function above, that makes sense: this function’s max is at 0 and quickly decays. In the middle plot, we can see that the last feature has a fairly long lag compared to the flash beginning, and is a lot smoother. Looking at the rightmost plot, we can see that the other features vary between these two extremes, getting smoother and more delayed.

Now let's see how the convolved features look for the basis that spans the full flash duration:

plot_basis_feature_summary(
    basis_white_stim,
    predictors_train["white"],
    interval,
    label="Flash",
    window_len=window_len,
    title="Flash Presentation - Raised Cosine Log-Stretched Conv"
)

This is very similar to the Flashes On convolution, just a bit wider, since the flash lasts longer than the single time bin that marks its onset.

Finally, let’s see how our Raised Cosine Linear Conv basis is transforming our Flashes Off predictor.

plot_basis_feature_summary(
    basis_white_off,
    white_train_off,
    interval,
    label="Flash",
    window_len=window_len,
    title="Flashes Off - Raised Cosine Linear Conv"
)

This basis might look a bit different, and that's because we're using the "acausal" setting for the "predictor_causality" option. In this mode, the center of the convolution is aligned with the end of the flash, rather than strictly following the stimulus forward in time.

This acausal alignment allows the model to capture changes in firing rate that occur both just before and after the flash ends. This is particularly useful for smoothing transitions between basis-driven components: it helps avoid abrupt or artificial jumps in the predicted firing rate at stimulus offset. Instead, we can interpolate more smoothly across time, producing more interpretable predictions.
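
To build intuition, here is a conceptual numpy sketch of causal versus acausal alignment (an illustration only, not the NeMoS internals):

import numpy as np

x = np.zeros(20)
x[10] = 1.0                 # impulse marking the flash offset
f = np.ones(5) / 5          # toy 5-bin filter

causal = np.convolve(x, f, mode="full")[:len(x)]  # response starts at the event
acausal = np.convolve(x, f, mode="same")          # response centered on the event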


These are the elements of our feature matrix: representations of not just the instantaneous presentation of a flash, but also of its history. Let’s see what this looks like when we go to fit the model!

In our case, we want our feature matrix to include both the black and the white flash features. For that, we can build an additive basis, which concatenates the basis objects we have already declared.

# Define additive basis object
additive_basis = (
    basis_white_on + 
    basis_white_off + 
    basis_white_stim + 
    basis_black_on + 
    basis_black_off + 
    basis_black_stim
)

We can convolve our predictors with each basis within our additive basis by calling compute_features.

# Convolve basis with inputs - train set
X_train = additive_basis.compute_features(
    white_train_on,
    white_train_off,
    nap.Tsd(t= white_train_on.t,d=predictors_train["white"], time_support=units_counts_train.time_support),
    black_train_on,
    black_train_off,
    nap.Tsd(t= black_train_on.t,d=predictors_train["black"], time_support=units_counts_train.time_support)
)

# Convolve basis with inputs - test set
X_test = additive_basis.compute_features(
    white_test_on,
    white_test_off,
    nap.Tsd(t= white_test_on.t,d=predictors_test["white"], time_support=units_counts_test.time_support),
    black_test_on,
    black_test_off,
    nap.Tsd(t= black_test_on.t,d=predictors_test["black"], time_support=units_counts_test.time_support)
)

Initialize and fit a GLM: single unit#

Now we are finally ready to start our model!

First, we need to define our GLM model object. When initializing the model, we can specify the solver_name, the regularizer, the regularizer_strength and the observation_model. All of these are optional.

  • solver_name: this string specifies the solver algorithm. The default behavior depends on the regularizer, as each regularization scheme is only compatible with a subset of possible solvers.

  • regularizer: this string or object specifies the regularization scheme. Regularization modifies the objective function to reflect your prior beliefs about the parameters, such as sparsity. Regularization becomes more important as the number of input features, and thus model parameters, grows. NeMoS's regularizers can be found within the nemos.regularizer module.

  • observation_model: this object links the firing rate and the observed data (in this case spikes), describing the distribution of neural activity (and thus changing the log-likelihood). For spiking data, we use the Poisson observation model. A short sketch of passing it explicitly follows this list.
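
Here is a minimal sketch of passing the observation model explicitly (Poisson is also the default; the module path nmo.observation_models is an assumption, so check your NeMoS version):

# Sketch: make the Poisson observation model explicit
model_explicit = nmo.glm.GLM(
    observation_model=nmo.observation_models.PoissonObservations(),
    regularizer="Ridge",
    regularizer_strength=7.745e-06,
    solver_name="LBFGS",
)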

For this tutorial, we'll use the LBFGS solver with a Ridge regularizer and a regularizer_strength of 7.745e-06.

regularizer_strength = 7.745e-06
# Initialize model object of a single unit
model = nmo.glm.GLM(
    regularizer = "Ridge",
    regularizer_strength = regularizer_strength,
    solver_name="LBFGS", 
)

First let’s choose an example unit to fit.

# Choose an example unit
unit_id = 951768318

# Get counts for train and test for said unit
u_counts_train = units_counts_train.loc[unit_id]
u_counts_test = units_counts_test.loc[unit_id]

NeMoS models are intended to be used like scikit-learn estimators: a model instance is initialized with hyperparameters (regularization strength, solver, etc.), and is then fit to data by calling fit(). Since we have already created our model and prepared our data, we can go ahead and call fit().

model.fit(X_train, u_counts_train)
GLM(
    observation_model=PoissonObservations(inverse_link_function=exp),
    regularizer=Ridge(),
    regularizer_strength=7.745e-06,
    solver_name='LBFGS'
)

Now that we have fit our data, we can retrieve the resulting parameters. Similar to scikit-learn, these are stored in the coef_ and intercept_ attributes:

print(f"firing_rate(t) = exp({model.coef_} * flash(t) + {model.intercept_})")
firing_rate(t) = exp([ 0.42339513 -0.14622475 -1.6153557   2.4666765   2.4574258   1.0377612
  0.8538695   1.4687707   2.131266   -0.04929369  0.4735016   0.44654387
 -0.99847394  0.31412154 -0.1562404   0.73338383  0.03729352 -1.1021913
  1.0413783   2.3387983  -0.02767342  1.3869107  -0.33982408  0.52867544
  0.01454498  0.01976705  0.08946032 -0.17806664 -0.00803406  0.08180097] * flash(t) + [-3.94904])

Note

Note that model.coef_ has shape (n_features,), while model.intercept_ has shape (n_neurons,):

print(model.coef_.shape)
>>> (30,)

print(model.intercept_.shape)
>>> (1,)

Assess GLM performance: predict and PSTH#

Although it is helpful to examine the model parameters, they don’t tell us much about how well the model is performing. So how can we assess its quality?

One way is to use the model to predict firing rates and compare those predictions to the smoothed spike train. By calling predict, we obtain the model’s predicted firing rate for the input data — that is, the output of the nonlinearity.

# Use predict to obtain the firing rates
pred_unit = model.predict(X_test)

# Convert units from spikes/bin to spikes/sec
pred_unit = pred_unit/ bin_size

print(pred_unit)
Time (s)
--------------  ---
1285.103369922  nan
1285.108369922  nan
1285.113369922  nan
1285.118369922  nan
1285.123369922  nan
1285.128369922  nan
1285.133369922  nan
...
1576.560859922  nan
1576.565859922  nan
1576.570859922  nan
1576.575859922  nan
1576.580859922  nan
1576.585859922  nan
1576.590859922  nan
dtype: float32, shape: (12500,)

Now, we can use the Pynapple function compute_perievent_continuous to re-center the timestamps of the predicted rates around the stimulus presentations, in the same way as we did at the beginning of the tutorial. In contrast to compute_perievent, which works on spike times, compute_perievent_continuous allows us to center a continuous time series.

# Re-center timestamps around white stimuli
# +.50 because we subtracted .50 at the beginning of stimulus presentation
peri_white_pred_unit = nap.compute_perievent_continuous(
    timeseries = pred_unit, 
    tref = nap.Ts(flashes_test_white.start+.50),
    minmax=window_size
)  
# Re-center timestamps for black stimuli
# +.50 because we subtracted .50 at the beginning of stimulus presentation
peri_black_pred_unit = nap.compute_perievent_continuous(
    timeseries = pred_unit, 
    tref = nap.Ts(flashes_test_black.start+.50), 
    minmax=window_size
)  

# Print centered predicted rates
print(peri_white_pred_unit)
Time (s)    0        1        2        3        4        ...
----------  -------  -------  -------  -------  -------  -----
-0.25       nan      nan      3.85464  nan      nan      ...
-0.245      3.85464  3.85464  3.85464  3.85464  3.85464  ...
-0.24       3.85464  3.85464  3.85464  3.85464  3.85464  ...
-0.235      3.85464  3.85464  3.85464  3.85464  3.85464  ...
-0.23       3.85464  3.85464  3.85464  3.85464  3.85464  ...
-0.225      3.85464  3.85464  3.85464  3.85464  3.85464  ...
-0.22       3.85464  3.85464  3.85464  3.85464  3.85464  ...
...         ...      ...      ...      ...      ...      ...
0.47        3.82009  3.82009  3.83403  3.82009  3.82009  ...
0.475       3.83403  3.83403  3.84355  3.83403  3.83403  ...
0.48        3.84355  3.84355  3.84954  3.84355  3.84355  ...
0.485       3.84954  3.84954  3.85284  3.84954  3.84954  ...
0.49        3.85284  3.85284  3.85429  3.85284  3.85284  ...
0.495       3.85429  3.85429  3.85464  3.85429  3.85429  ...
0.5         3.85464  3.85464  3.85464  3.85464  3.85464  ...
dtype: float64, shape: (151, 25)

The resulting object is a Pynapple TsdFrame of shape (n_time_bins, n_trials) (we define one trial as one presentation of the stimulus).

With that, we can plot the PSTH of both the average firing rate of this unit and the average predicted rate.
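
The averaging done by the hidden plotting cell below boils down to the following sketch (assuming peri_white from the earlier PSTH section is still in scope):

# Trial-averaged predicted PSTH of this unit (already in Hz)
mean_pred_white = np.mean(peri_white_pred_unit, axis=1)

# Observed counterpart: bin the re-centered spikes, average across trials, convert to Hz
obs_counts = peri_white[unit_id].count(bin_size)
mean_obs_white = np.mean(obs_counts, axis=1) / bin_size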


def plot_peri_predict(
        peri_white_pred_unit, 
        peri_black_pred_unit, 
        peri_white, 
        peri_black,
        unit_id = unit_id
    
):
    fig, ax = plt.subplots(1,2,figsize=(17, 4), sharey=True)
    ### white
    # predicted
    ax[0].plot(np.mean(peri_white_pred_unit,axis=1), linewidth=1.5, color="red", label = "predicted")

    peri_u = peri_white[unit_id]
    peri_u_count = peri_u.count(bin_size)

    peri_u_count_conv_mean = np.mean(peri_u_count, 1).smooth(std=0.015)
    peri_u_rate_conv = peri_u_count_conv_mean / bin_size
    # observed
    ax[0].plot(peri_u_rate_conv, linewidth=2, color="black")
    ax[0].axvline(0.0)
    ax[0].axvspan(0, 0.250, color="silver", alpha=0.3, ec="black")
    ax[0].set_xlim(-.25, .5)

    ax[0].set_title("White flashes")

    #### black
    # predicted
    ax[1].plot(np.mean(peri_black_pred_unit,axis=1), linewidth=1.5, color="red")

    peri_u = peri_black[unit_id]
    peri_u_count = peri_u.count(bin_size)

    peri_u_count_conv_mean = np.mean(peri_u_count, 1).smooth(std=0.015)
    peri_u_rate_conv = peri_u_count_conv_mean / bin_size
    # observed
    ax[1].plot(peri_u_rate_conv, linewidth=2, color="black", label = "observed")
    ax[1].axvline(0.0)
    ax[1].axvspan(0, 0.250, color="black", alpha=0.3, ec="black")
    ax[1].set_xlim(-.25, .5)

    ax[1].set_title("Black flashes")
    ax[0].set_ylabel("Rate (Hz)")

    fig.text(0.5, -.05, 'Time from stim(s)', ha='center')
    fig.text(0.5, .95, f'PSTH unit {unit_id}', ha='center')
    fig.legend()
    plt.show()
plot_peri_predict(peri_white_pred_unit, 
        peri_black_pred_unit, 
        peri_white, 
        peri_black
)

Now, we can move to fit all neurons!

Initialize and fit a GLM: PopulationGLM#

NeMoS has a separate PopulationGLM object for fitting a population of neurons. This is equivalent to fitting each neuron individually in a loop, but faster. It operates very similarly to the GLM object we used a moment ago.
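
Conceptually, PopulationGLM does something like the following sketch: fit one GLM per unit and stack the parameters (up to numerical and regularization details, the stacked parameters match what a single PopulationGLM fit returns).

# Sketch: fit one GLM per unit and stack the parameters
coefs, intercepts = [], []
for uid in units_counts_train.columns:
    glm_u = nmo.glm.GLM(
        regularizer="Ridge",
        regularizer_strength=regularizer_strength,
        solver_name="LBFGS",
    ).fit(X_train, units_counts_train.loc[uid])
    coefs.append(glm_u.coef_)
    intercepts.append(glm_u.intercept_)

coefs = np.stack(coefs, axis=1)          # shape (n_features, n_units)
intercepts = np.concatenate(intercepts)  # shape (n_units,)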

The first step is initializing the model, as with the GLM object.

model_stimuli = nmo.glm.PopulationGLM(
    regularizer = "Ridge",
    regularizer_strength = regularizer_strength,
    solver_name="LBFGS" 
)

Our input for the PopulationGLM can be the same feature matrix (X_train) we used for fitting a single unit. Since we now want to fit all neurons, the counts for our model will be units_counts_train. With that, we call model_stimuli.fit() to fit our model.

model_stimuli.fit(
    X_train,
    units_counts_train
)
PopulationGLM(
    observation_model=PoissonObservations(inverse_link_function=exp),
    regularizer=Ridge(),
    regularizer_strength=7.745e-06,
    solver_name='LBFGS'
)

Same as before, our coefficients live in the coef_ attribute, while our intercept is stored in the intercept_ attribute.

However, since here we have fitted all units, the shape of our coef_ output will be (n_coefficients, n_units). Similarly, the shape of our intercept_ output will be (n_units,) because there is one intercept per unit.

print(model_stimuli.coef_.shape)
print(model_stimuli.intercept_.shape)
(30, 19)
(19,)

Assess PopulationGLM performance: PSTH#

To evaluate how well our PopulationGLM model captures the neural responses, we can visualize the activity of individual units using a PSTH.

For that, we first use the predict function.

# Predict spikes rate of all neurons in the population
predicted = model_stimuli.predict(X_test)

# Convert units from spikes/bin to spikes/sec
predicted = predicted/ bin_size

Then, we use the Pynapple function compute_perievent_continuous to re-center the timestamps of the observed and predicted rates around the stimulus presentations.

# Re-center timestamps of observed test-set counts around white flashes
peri_white_test = nap.compute_perievent_continuous(
    timeseries = units.restrict(flashes_test).count(bin_size),
    tref = nap.Ts(flashes_test_white.start+.50),
    minmax = (window_size)
)

# Re-center timestamps of observed test-set counts around black flashes
peri_black_test = nap.compute_perievent_continuous(
    timeseries =  units.restrict(flashes_test).count(bin_size),
    tref = nap.Ts(flashes_test_black.start+.50),
    minmax = (window_size)
)
# Re-center timestamps for predicted
peri_white_pred = nap.compute_perievent_continuous(
    timeseries = predicted, 
    tref = nap.Ts(flashes_test_white.start+.50),
    minmax=(window_size)
)  

# Re-center timestamps for predicted
peri_black_pred =  nap.compute_perievent_continuous(
    timeseries = predicted, 
    tref = nap.Ts(flashes_test_black.start+.50),
    minmax=(window_size)
)  

We can then plot the centered spikes in a PSTH! Let’s see how that looks for the first 9 units.


def plot_pop_psth(
        peri_color, 
        color_flashes, 
        bin_size, 
        smoothing=0.015,
        n_units = 9, 
        **peri_others
        ):
    """
    Plot perievent time histograms (PSTHs) and raster plots for multiple units.

    Parameters:
    -----------
    peri_color : dict
        Dictionary mapping unit names to binned spike count peri-stimulus data (e.g., binned time series).
    units : dict
        Dictionary of neural units, e.g., spike trains or trial-aligned spike events.
    color_flashes : str
        A label indicating the flash color condition ('black' or other), used for visual styling.
    bin_size : float
        Size of the bin used for spike count computation (in seconds).
    smoothing : float
        Standard deviation for Gaussian smoothing of the PSTH traces.
    """

    # Layout setup: one row of PSTHs, one column per unit
    n_cols = n_units
    n_rows = 1
    fig, ax = plt.subplots(n_rows, n_cols)
    fig.set_figheight(2.5)
    fig.set_figwidth(17)
    fig.tight_layout()

    # Use tab20 color palette for plotting different units
    colors = plt.cm.tab20.colors[:n_cols]

    for i in range(n_units):

        u = peri_color[:,:,i]

        # Plot PSTH (smoothed firing rate)
        if (i == 0):
            ax[i].plot(
            (np.mean(u, axis=1) / bin_size).smooth(std=smoothing),
            linewidth=2,
            color="black",
            label = "Observed"
        )
        else:
            ax[i].plot(
                (np.mean(u, axis=1) / bin_size).smooth(std=smoothing),
                linewidth=2,
                color="black",
            )

        ax[i].axvline(0.0)  # Stimulus onset line
        span_color = "black" if color_flashes == "black" else "silver"
        ax[i].axvspan(0, 0.250, color=span_color, alpha=0.3, ec="black")  # Stim duration
        ax[i].set_xlim(-0.25, 0.50)
        ax[i].set_title(f'{i}')
        for (label, color, peri_pred) in peri_others.values():
            u_pred = peri_pred[:,:,i]
            if (i == 0):
                ax[i].plot(
                (np.mean(u_pred, axis=1)),
                linewidth=1.5,
                color=color,
                label = label
            )
            else:
                ax[i].plot(
                (np.mean(u_pred, axis=1)),
                linewidth=1.5,
                color=color,
            )

    
    # Y-axis and title annotations
    ax[0].set_ylabel("Rate (Hz)")
    fig.legend()
    fig.text(0.5, 0.00, 'Time from stim(s)', ha='center')
    fig.text(0.5, 1.00, f'PSTH - {color_flashes} flashes', ha='center')
plot_pop_psth(
    peri_white_test, 
    "white", 
    bin_size,
    peri_pred_stimuli = ("Predicted", "red", peri_white_pred)
    )

plot_pop_psth(
    peri_black_test, 
    "black", 
    bin_size,
    peri_pred_stimuli = ("Predicted", "red", peri_black_pred)
    )

The model does pretty well! However, we can see some artifacts and unnatural peaks. What could we try in order to improve this model a little?

Adding coupling as a new predictor#

We can try extending the model in order to improve its performance. There are many ways one can do this: the iterative refinement and improvement of your model is an important part of the scientific process! In this tutorial, we’ll discuss one such extension, but you’re encouraged to try others.

Now, we’ll extend the model by adding coupling terms—that is, including the activity of other neurons as predictors—to account for shared variability within the network. It’s been shown by Pillow et al. [2008] [1b] that spike times can be predicted more accurately when taking into account the spiking of neighbouring units.

We start by creating a new basis object:

# New basis object for coupling
basis_coupling = nmo.basis.RaisedCosineLogConv(
    n_basis_funcs=8, window_size=window_len, label="spike_history"
)

We can add this new basis to our old additive basis:

# New additive basis with coupling term
additive_basis_coupling = additive_basis + basis_coupling

And use compute_features to convolve our input features with the basis object to compress them.

# Compute the features for train and test
X_coupling_train = additive_basis_coupling.compute_features(
    white_train_on,
    white_train_off,
    nap.Tsd(t= white_train_on.t,d=predictors_train["white"], time_support=units_counts_train.time_support),
    black_train_on,
    black_train_off,
    nap.Tsd(t= black_train_on.t,d=predictors_train["black"], time_support=units_counts_train.time_support),
    nap.TsdFrame(t=units_counts_train.t, d=units_counts_train, time_support=units_counts_train.time_support) # Our spike counts
)
X_coupling_test = additive_basis_coupling.compute_features(
    white_test_on,
    white_test_off,
    nap.Tsd(t= white_test_on.t,d=predictors_test["white"], time_support=units_counts_test.time_support),
    black_test_on,
    black_test_off,
    nap.Tsd(t= black_test_on.t,d=predictors_test["black"], time_support=units_counts_test.time_support),
    nap.TsdFrame(t=units_counts_test.t, d=units_counts_test, time_support=units_counts_test.time_support) # Our spike counts
)

What is the result of running compute_features with our raised cosine log-stretched basis and the spike counts?

compute_features convolves the spike counts of every neuron with the raised cosine log-stretched basis. We do this because coupling filters capture interactions between cells and can mimic the effects of shared input noise, as discussed in Pillow et al. [2008] [1c].
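
As a quick sanity check, the width of the new design matrix should be the 30 stimulus features from before plus 8 coupling features per unit (a sketch of the expected bookkeeping):

# Sketch: expected number of columns in the coupling design matrix
n_stim_features = 6 * 5                                # 6 stimulus bases x 5 functions each
n_coupling_features = 8 * units_counts_train.shape[1]  # 8 coupling functions per unit
print(n_stim_features + n_coupling_features)           # expected width
print(X_coupling_train.shape[1])                       # should match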

We initialize a new PopulationGLM object:

regularizer_strength = 0.005 

model_coupling = nmo.glm.PopulationGLM(
    regularizer = "Ridge",
    regularizer_strength = regularizer_strength,
    solver_name="LBFGS"
)

And we fit it by calling model_coupling.fit():

model_coupling.fit(X_coupling_train,units_counts_train)
PopulationGLM(
    observation_model=PoissonObservations(inverse_link_function=exp),
    regularizer=Ridge(),
    regularizer_strength=0.005,
    solver_name='LBFGS'
)

Assess coupling PopulationGLM performance: heatmap#

Another way to visually inspect how well our PopulationGLM model captures the neural responses is to summarize the activity from all the units using a heatmap. Here’s how we construct it:

  1. Predict: we get the predicted firing rate of each timepoint for each neuron using predict with our PopulationGLM model object.

  2. Re-center timestamps: we use the Pynapple function compute_perievent_continuous to re-center the activity around the presentation of stimuli.

  3. Z-scoring: We normalize the activity of each unit by converting it to z-scores. This removes differences in firing rate scale and allows us to focus on the relative response patterns across neurons.

  4. Sorting by peak time: We then sort neurons by the time at which they show their peak response in the observed data. This reveals any sequential or structured dynamics in the population response. We sort the observed data, and then use that sorting to order the prediction.

  5. Side-by-side comparison: Finally, we plot the observed and predicted population responses side by side. If the model captures the key features of the response, the predicted plot should resemble the observed one: we would expect to see a similar diagonal or curved band of activity, reflecting the ordered peak responses.

Step 1: The same way as before, we can obtain the predictions using predict

predicted = model_coupling.predict(X_coupling_test)/ bin_size

Step 2: we can center each unit's activity around the stimulus presentations with compute_perievent_continuous

# Re-center timestamps for predicted
peri_white_pred_coupling = nap.compute_perievent_continuous(
    timeseries = predicted, 
    tref = nap.Ts(flashes_test_white.start+.50),
    minmax=(window_size)
)  

peri_black_pred_coupling =  nap.compute_perievent_continuous(
    timeseries = predicted, 
    tref = nap.Ts(flashes_test_black.start+.50),
    minmax=(window_size)
)  

Steps 3 and 4: Z-scoring and sorting according to peak time
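
Stripped of the Pynapple bookkeeping, these two steps boil down to the following numpy sketch (mean_rates is a hypothetical (n_time_bins, n_units) array of trial-averaged responses):

# Sketch of steps 3 and 4 on a hypothetical `mean_rates` array
from scipy.stats import zscore

z = zscore(mean_rates, axis=0)             # z-score each unit independently across time
order = np.argsort(np.argmax(z, axis=0))   # order units by the time of their peak response
z_sorted = z[:, order]                     # rearrange columns by peak time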

def create_zscore_dic(
        peri_white_test, 
        peri_black_test,
        peri_white_pred,
        peri_black_pred,
        smoothing = 0.015):
    """
    Computes z-scored, time-aligned population responses for both observed 
    and predicted data and stores the outputs in separate dictionaries
    
    For each stimulus condition (white, black), the function:
    - Averages peri-stimulus time series across trials
    - Restricts to a fixed time window around stimulus onset
    - Applies z-scoring across time for each neuron
    - Sorts neurons by time of peak response (in observed data)
    - Returns sorted z-scored matrices for both observed and predicted data

    Parameters
    ----------
    peri_white_test : TsdFrame
        Observed responses to white stimuli (trials × time × neurons).
    peri_black_test : TsdFrame
        Observed responses to black stimuli.
    peri_white_pred : TsdFrame
        Predicted responses to white stimuli.
    peri_black_pred : TsdFrame
        Predicted responses to black stimuli.
    smoothing : float
        Standard deviation for Gaussian smoothing of the perievent traces.

    Returns
    -------
    dic_test : dict
        Dictionary containing:
            - 'z': z-scored and sorted observed activity (time × neurons)
            - 'order': neuron sorting indices based on peak response
    dic_pred : dict
        Dictionary containing:
            - 'z': z-scored predicted activity, sorted using test order
    """

    # Time window around the stimulus (240 ms before and 500 ms after onset)
    restriction = [-.24, .5]

    # Initialize dictionaries to store processed data
    dic_test = {
        "white": {"z": None, "order": None},  # Z-scored + sorted activity + sort order
        "black": {"z": None, "order": None}
    }
    dic_pred = {
        "white": {"z": None},  # Z-scored + sorted predicted activity
        "black": {"z": None}
    }

    # Process TEST data for each stimulus type
    for color, peri in zip(["white", "black"], [peri_white_test, peri_black_test]):
        # Restrict time window and average across trials
        mean_peri = np.mean(
            peri.restrict(nap.IntervalSet(restriction)), axis=1
        ).smooth(std=smoothing)
        # Z-score across time for each neuron (independently)
        z_mean_peri = zscore(mean_peri, axis=0)
        # Sort neurons by time of their peak response
        order = np.argsort(np.argmax(z_mean_peri, axis=0))
        # Apply sorting to z-scored data
        z_sorted = z_mean_peri[:, order]
        # Store results in dictionary
        dic_test[color]["z"] = z_sorted
        dic_test[color]["order"] = order

    # Process PREDICTED data
    for color, peri in zip(
        ["white", "black"], 
        [peri_white_pred, peri_black_pred]):  
        # Restrict time window and average across trials
        mean_peri = np.mean(
            peri.restrict(nap.IntervalSet(restriction)), axis=1
        )
        # Z-score across time for each neuron
        z_mean_peri = zscore(mean_peri, axis=0)
        # Use the same neuron ordering as in test data for comparison
        order = dic_test[color]["order"]
        # Sort predicted responses using test-data order
        z_sorted = z_mean_peri[:, order]
        # Store in dictionary
        dic_pred[color]["z"] = z_sorted

    return dic_test, dic_pred

Create our dictionaries of z-scored mean activity

dic_test, dic_pred_coupling = create_zscore_dic(
    peri_white_test,
    peri_black_test,
    peri_white_pred_coupling,
    peri_black_pred_coupling
)

dic_test, dic_pred_stimuli = create_zscore_dic(
    peri_white_test, 
    peri_black_test,
    peri_white_pred,
    peri_black_pred)

Step 5: Plot side by side comparison


def plot_zscores(dic_test, dic_pred_stimuli, dic_pred_coupling, bin_size = bin_size):
    """
    Plot heatmaps of z-scored neuronal responses for both observed and predicted data.

    For each stimulus type (white and black), the function:
    - Plots a heatmap of z-scored activity for each unit, sorted by time of peak response
    - Compares observed and predicted activity side-by-side
    - Adds time markers at stimulus onset and offset

    Parameters
    ----------
    dic_test : dict
        Dictionary with observed z-scored and sorted activity for 'white' and 'black' stimuli.
    dic_pred_stimuli : dict
        Dictionary with z-scored predictions of the Stimuli model, sorted using the same order as dic_test.
    dic_pred_coupling : dict
        Dictionary with z-scored predictions of the Stimuli + Coupling model, sorted using the same order as dic_test.
    bin_size : float
        Size of the time bin (in seconds), used to build the time axis.

    Returns
    -------
    None
        Displays matplotlib figures.
    """
    for color in ["white", "black"]:
        fig, ax = plt.subplots(1, 3)
        fig.tight_layout()
        fig.set_figheight(4)
        fig.set_figwidth(20)


        # Create time axis assuming bin size defined elsewhere
        time = np.arange(-0.24, 0.5, bin_size)

        # Image limits: [x_min, x_max, y_min, y_max]
        limits = [time[0], time[-1], 0, dic_test[color]["z"].shape[1]]

        # Plot observed activity
        im = ax[0].imshow(
            np.array(dic_test[color]["z"]).T,  # neurons on y-axis, time on x-axis
            aspect="auto",
            extent=limits
        )
        ax[0].set_title(f"{color.capitalize()} Observed")

        # Plot predicted activity with stimuli model
        im = ax[1].imshow(
            np.array(dic_pred_stimuli[color]["z"]).T,
            aspect="auto",
            extent=limits
        )
        ax[1].set_title(f"{color.capitalize()} Predicted (Stimuli)")

        # Plot predicted activity with coupling model
        im = ax[2].imshow(
            np.array(dic_pred_coupling[color]["z"]).T,
            aspect="auto",
            extent=limits
        )
        ax[2].set_title(f"{color.capitalize()} Predicted (Stimuli + Coupling)")

        # Add vertical lines for stimulus onset (0s) and offset (0.25s)
        for a in ax:
            a.axvline(0.0, color='k', linestyle='--')
            a.axvline(0.25, color='k', linestyle='--')
            a.set_ylabel("Unit")

        # Shared x-axis label
        fig.text(0.45, 0.00, 'Time from stim (s)', ha='center')

        # Colorbar
        fig.colorbar(im, ax=ax, location='right', label='Z-score')

        plt.show()
plot_zscores(dic_test, dic_pred_stimuli, dic_pred_coupling)

To the left of this plot we can see the observed z-scored activity, sorted by peak response time. In the middle, we can see the z-scored predictions of the model with Stimuli filters. To the right, we can see the z-scored predictions of the model with Stimuli and Coupling filters.

We can see that the average peak activity looks similar! Let’s compare the prediction of the Stimuli versus the Stimuli + Coupling model using a PSTH:

plot_pop_psth(
    peri_white_test, 
    "white", 
    bin_size,
    peri_pred_stimuli = ("Stimuli", "red", peri_white_pred),
    peri_pred_coupling = ("Stimuli + Coupling", "blue", peri_white_pred_coupling)
    )

plot_pop_psth(
    peri_black_test, 
    "black", 
    bin_size,
    peri_pred_stimuli = ("Stimuli", "red", peri_black_pred),
    peri_pred_coupling = ("Stimuli + Coupling", "blue", peri_black_pred_coupling)
    )

Evaluate model performance quantitatively: Pseudo-\(R^2\) McFadden#

Comparing the two models by examining their predictions is important, but you may also want a single number with which to evaluate and compare your models' performance. As discussed earlier, the GLM optimizes the log-likelihood to find the best-fitting weights, and we can calculate this quantity using the NeMoS score method.

This function takes the following as required inputs:

  • Predictors (in our case, the feature matrix X_coupling_test built from our additive basis, which includes the black and white flash features as well as the coupling features derived from the test-set spike counts)

  • Counts (in our case, units_counts_test, because we wish to evaluate how well the model predicts unseen data)

By default, score computes the mean log-likelihood. However, because the log-likelihood is un-normalized, it should not be compared across datasets (it won't, for example, account for differences in noise levels). NeMoS also provides the pseudo-\(R^2\) for this purpose; you only need to pass "pseudo-r2-McFadden" as the optional score_type argument, as shown in the cell below.
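
For reference, McFadden's pseudo-\(R^2\) compares the log-likelihood of the fitted model to that of a null, intercept-only model (one that only predicts the mean rate):

\[ \begin{aligned} R^2_{\text{McFadden}} = 1 - \frac{\log \mathcal{L}_{\text{model}}}{\log \mathcal{L}_{\text{null}}} \end{aligned} \]

It is 0 when the model does no better than the mean firing rate and approaches 1 as the fit improves.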

# Calculate the mean score for the Stimuli + Coupling model
# using pseudo-r2-McFadden
score_coupling = model_coupling.score(
    X_coupling_test,
    units_counts_test,
    score_type = "pseudo-r2-McFadden"
)

print(score_coupling)
0.20056623

We can also access each unit's score by passing a lambda function as the optional aggregate_sample_scores parameter:

# Obtain individual units' scores
score_units = model_coupling.score(
    X_coupling_test, 
    units_counts_test, 
    score_type = "pseudo-r2-McFadden",
    aggregate_sample_scores=lambda x:np.mean(x,axis=0), 
)

print(score_units)
[0.10131961 0.4054796  0.15242094 0.10345095 0.17804486 0.5537226
 0.23751181 0.1442216  0.09226435 0.04176563 0.15948278 0.14035589
 0.0955326  0.19902009 0.2875896  0.17398357 0.38125545 0.11390072
 0.36725706]

Let's calculate the score separately for white and black flashes, and for both models (Stimuli and Stimuli + Coupling), using helper functions.


def evaluate_model(model, X, y, score_type="pseudo-r2-McFadden"):
    """
    Evaluate a model's performance at the population and unit levels.

    Parameters
    ----------
    model : object
        A model object that implements a `.score()` method with arguments `X`, `y`, and `score_type`.
    X : array-like or pynapple-compatible object
        Input features used for evaluation.
    y : array-like or pynapple-compatible object
        Target outputs corresponding to `X`.
    score_type : str, optional
        The scoring metric to use (e.g., "log-likelihood" or "pseudo-r2-McFadden"). Passed to the model's `score` method.

    Returns
    -------
    score_pop : float
        The population-level score (aggregated across all units and samples).
    score_unit : np.ndarray
        The unit-level scores, computed by averaging over samples for each unit.
    """
    score_pop = model.score(
        X, 
        y, 
        score_type=score_type,
    )
    score_unit = model.score(
        X, 
        y, 
        aggregate_sample_scores=lambda x: np.mean(x, axis=0), 
        score_type=score_type,
    )
    return score_pop, score_unit

def evaluate_models_by_color(models, X_sets, y_sets, flashes_color, score_type="log-likelihood"):
    """
    Evaluate multiple models on a dataset filtered by flash color.

    Parameters
    ----------
    models : dict of str -> model
        Dictionary of model names and model objects to evaluate. Each model must implement `.score()`.
    X_sets : dict of str -> array-like or pynapple-compatible object
        Dictionary mapping model names to their corresponding input features.
    y_sets : array-like or pynapple-compatible object
        The target outputs to be used with all models.
    flashes_color : str
        The flash color condition used to restrict the data (e.g., "black" or "white").
    score_type : str, optional
        The scoring metric to use (e.g., "log-likelihood").

    Returns
    -------
    model_base_pop : float
        Population-level score of the first model in `models`.
    model_base_unit : np.ndarray
        Unit-level scores of the first model in `models`.
    model_hist_pop : float
        Population-level score of the second model in `models`.
    model_hist_unit : np.ndarray
        Unit-level scores of the second model in `models`.

    Notes
    -----
    The function assumes exactly two models in the `models` dictionary, 
    and returns their scores in fixed order based on insertion into the dictionary,
    first returning the population score, then the unit scores for each.

    """
    models_list = []
    for model_name, model in models.items():
        X = X_sets[model_name].restrict(flashes_color)
        y = y_sets.restrict(flashes_color)
        models_list.append(evaluate_model(model, X, y, score_type))
    return models_list[0][0], models_list[0][1], models_list[1][0], models_list[1][1]
# Define model dictionary
models = {
    "stimuli": model_stimuli,
    "coupling": model_coupling
}

# Calculate scores when predicting during white flashes
(score_white_stimuli_pop, 
 score_white_stimuli_unit, 
 score_white_coupling_pop,
 score_white_coupling_unit) = evaluate_models_by_color(
    models,
    {"stimuli": X_test, 
    "coupling": X_coupling_test},
    units_counts_test,
    flashes_test_white,
    "pseudo-r2-McFadden"
)

# Calculate scores when predicting during black flashes
(score_black_stimuli_pop, 
 score_black_stimuli_unit, 
 score_black_coupling_pop,
 score_black_coupling_unit) = evaluate_models_by_color(
    models,
    {"stimuli": X_test, 
    "coupling": X_coupling_test},
    units_counts_test,
    flashes_test_black,
    "pseudo-r2-McFadden"
)

We can also visualize the individual scores for each unit!


def half_violin(ax, data, center, side="left", color="blue", width=0.3, alpha=1):
    """Draw a half violin on a matplotlib Axes."""
    kde = gaussian_kde(data)
    x = np.linspace(min(data), max(data), 200)
    y = kde(x)
    y = y / y.max() * width  # normalize

    if side == "left":
        ax.fill_betweenx(x, center, center - y, facecolor=color, alpha=alpha, edgecolor="black")
    elif side == "right":
        ax.fill_betweenx(x, center, center + y, facecolor=color, alpha=alpha, edgecolor="black")

def plot_half_violin_scores():
    # Data
    stim_white = score_white_stimuli_unit
    coupling_white = score_white_coupling_unit
    stim_black = score_black_stimuli_unit
    coupling_black = score_black_coupling_unit

    positions = [0, 1]
    labels = ["White", "Black"]
    width = 0.3
    num_units = len(stim_white)

    fig, ax = plt.subplots(figsize=(10, 5))

    # Plot violins
    half_violin(ax, stim_white, center=positions[0], side="left", color="#1f77b4", width=width)
    half_violin(ax, coupling_white, center=positions[0], side="right", color="#ff7f0e", width=width)
    half_violin(ax, stim_black, center=positions[1], side="left", color="#1f77b4", width=width)
    half_violin(ax, coupling_black, center=positions[1], side="right", color="#ff7f0e", width=width)

    # Plot individual data points and connect within-color models
    for i in range(num_units):
        jitter = np.random.normal(scale=0.01)

        # White
        ax.plot([positions[0]-0.05+jitter, positions[0]+0.05+jitter],
                [stim_white[i], coupling_white[i]],
                color="gray", alpha=0.8, linewidth=0.8)
        ax.scatter([positions[0]-0.05+jitter, positions[0]+0.05+jitter],
                   [stim_white[i], coupling_white[i]],
                   color="black", s=8, alpha=.3)

        # Black
        ax.plot([positions[1]-0.05+jitter, positions[1]+0.05+jitter],
                [stim_black[i], coupling_black[i]],
                color="gray", alpha=.8, linewidth=0.8)
        ax.scatter([positions[1]-0.05+jitter, positions[1]+0.05+jitter],
                   [stim_black[i], coupling_black[i]],
                   color="black", s=8, alpha=.3)

    # Plot mean markers and annotate with values
    offset = 0.07
    mean_kwargs = dict(marker='d', color="black", markersize=6, zorder=3)

    def annotate_mean(x, y):
        ax.plot(x, y, **mean_kwargs)
        ax.text(x, y + 0.005, f"{y:.3f}", ha='center', va='bottom', fontsize=9)

    annotate_mean(positions[0] - offset, np.mean(stim_white))
    annotate_mean(positions[0] + offset, np.mean(coupling_white))
    annotate_mean(positions[1] - offset, np.mean(stim_black))
    annotate_mean(positions[1] + offset, np.mean(coupling_black))

    # Axes formatting
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_ylabel("Pseudo-$R^2$ McFadden")
    ax.set_title("Unit Scores per Flash Color and Model Type")

    # Legend
    legend_elements = [
        Patch(facecolor="#1f77b4", edgecolor='black', label="Stimuli"),
        Patch(facecolor="#ff7f0e", edgecolor='black', label="Stimuli + Coupling")
    ]
    ax.legend(
        handles=legend_elements,
        title="Model",
        loc='upper left',
        fontsize=9,
        title_fontsize=10,
        bbox_to_anchor=(.8, .2),
        borderaxespad=0.
    )

    plt.tight_layout()
    plt.show()
plot_half_violin_scores()

Although there is some variability between neurons, in general, the Stimuli + Coupling model is better at predicting spike trains than the model which only includes Stimuli filters. This makes a lot of sense! Noise is shared across neurons, and the information in a single cell's response is also encoded in the population activity, beyond the information provided by the stimuli alone (Pillow et al. [2008] [1]).

Food for thought#

We intentionally left out many details in this tutorial, as including everything would have resulted in a longer and more complex notebook. Our main goal was to introduce the key components of GLMs, walk through a real-data example, and demonstrate how using NeMoS and Pynapple can greatly simplify the modeling process. While going into every detail of model fitting was beyond the scope of this tutorial, it should not be beyond the scope of your own work. When applying GLMs to your own research questions, it’s crucial to be rigorous and intentional in your modeling choices.

In particular:

  • Explore different ways to split your data. Here, we used a train and a test set, but you could also use train, validation, and test sets, especially if you will be trying different models and tweaking parameters before finally assessing performance. Furthermore, different splitting strategies may be needed for different input statistics. For example, picking samples uniformly at random may be fine for independent samples, but is not recommended for time series (in which samples close in time are likely highly correlated).

  • Cross-validate the regularizer strength for each neuron individually, as using a fixed value across the population may lead to suboptimal fits. For example, the regularizer we used here does a reasonable job at capturing the activity of neurons that are strongly modulated by the flash (see units 1 and 5 in the PSTH of the Stimuli model). However, for neurons with weaker modulation (i.e., smaller changes in firing rate), the model tends to produce flattened predictions, possibly due to over-regularization (see units 3 or 4 in the PSTH of the Stimuli model). A sketch of a simple search over regularizer strengths follows this list.

  • Think carefully about and cross-validate the basis function parameters, including the type of basis and the number of components. These choices can greatly influence the model's performance, and the basis you choose imposes assumptions on your data, so it is key to be aware of them. For example, the raised cosine log-stretched basis assumes that the temporal precision of the basis decreases with distance from the event. This makes it great for modeling a rapid change in firing rate just after an event followed by a slow decay back to baseline, which may or may not match the dynamics of the neuron you want to fit. There is a helpful NeMoS notebook dedicated to tuning basis functions; we encourage you to check it out.

  • We made one specific improvement to our model, namely adding coupling filters. What do you think would be another reasonable improvement to add? (Hint: Pillow et al. [2008] [1d].)
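
As a starting point for the cross-validation bullet above, here is a minimal sketch of a search over regularizer strengths using only the fit() and score() calls shown earlier. For illustration it reuses the test split; in practice you would reserve a separate validation set, as noted in the first bullet.

# Sketch: pick the Ridge strength that scores best on held-out data (single unit)
candidate_strengths = np.logspace(-7, -3, 5)
validation_scores = {}
for strength in candidate_strengths:
    candidate = nmo.glm.GLM(
        regularizer="Ridge",
        regularizer_strength=strength,
        solver_name="LBFGS",
    ).fit(X_train, u_counts_train)
    validation_scores[strength] = candidate.score(
        X_test, u_counts_test, score_type="pseudo-r2-McFadden"
    )

best_strength = max(validation_scores, key=validation_scores.get)
print(best_strength, validation_scores[best_strength])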

References#

[1a] [1b] [1c] [1d] Pillow, J. W., Shlens, J., Paninski, L., Sher, A., Litke, A. M., Chichilnisky, E. J., & Simoncelli, E. P. (2008). Spatio-temporal correlations and visual signalling in a complete neuronal population. Nature, 454(7207), 995-999. https://doi.org/10.1038/nature07140

[2] Allen Institute for Brain Science. Allen Brain Observatory - Neuropixels Visual Coding - Technical White paper. Technical Report, Allen Institute for Brain Science, October 2019.

[3] Pillow, J. [Cosyne Talks]. (2018, March 1-4) Jonathan Pillow - Tutorial: Statistical models for neural data - Part 1 (Cosyne 2018) [Video]. Youtube

[4] Pillow, J. W., Paninski, L., Uzzell, V. J., Simoncelli, E. P., & Chichilnisky, E. J. (2005). Prediction and decoding of retinal ganglion cell responses with a probabilistic spiking model. The Journal of neuroscience : the official journal of the Society for Neuroscience, 25(47), 11003–11013. https://doi.org/10.1523/JNEUROSCI.3305-05.2005

Data citation#

The data used in this tutorial is from the Allen Brain Map, with the following citation:

Dataset: Allen Institute MindScope Program (2019). Allen Brain Observatory – Neuropixels Visual Coding [Dataset]. Available from brain-map.org/explore/circuits

Primary publication: Siegle, J. H., Jia, X., Durand, S., et al. (2021). Survey of spiking in the mouse visual system reveals functional hierarchy. Nature, 592(7612), 86-92. https://doi.org/10.1038/s41586-020-03171-x

Resources#

We have left some resources here and there throughout the notebook. Here is a complete list of all of them: