BCI × Deep Learning

Neural signal decoding via transformer pre-training

Notes on building NeuroLLM: a foundation-model approach to EEG decoding where a small transformer pre-trained on large-scale clinical EEG data is fine-tuned for motor-imagery classification.


Brain-computer interfaces decode neural signals into control commands. The classical approach — Common Spatial Patterns (CSP) fed into an SVM — has been the standard for motor-imagery classification for over a decade. It works, but it relies entirely on hand-crafted spatial filters and cannot learn from large unlabelled EEG corpora.
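For reference, the CSP half of that baseline can be sketched in a few lines of PyTorch. This is a minimal two-class version that assumes band-pass-filtered trials shaped (trials, channels, samples). The helper names `mean_cov`, `csp_filters` and `log_var_features` are hypothetical. The four-class BCI-IV task would need a one-vs-rest extension, and the SVM stage is not shown.

```python
import torch

def mean_cov(x: torch.Tensor) -> torch.Tensor:
    """x: (trials, C, T) -> trace-normalised spatial covariance averaged over trials."""
    x = x - x.mean(-1, keepdim=True)                          # remove per-channel mean
    c = x @ x.transpose(-1, -2)                               # (trials, C, C)
    c = c / c.diagonal(dim1=-2, dim2=-1).sum(-1)[:, None, None]
    return c.mean(0)

def csp_filters(x1: torch.Tensor, x2: torch.Tensor, n_pairs: int = 3) -> torch.Tensor:
    """Return 2 * n_pairs spatial filters (one per row) for a two-class problem."""
    c1, c2 = mean_cov(x1), mean_cov(x2)
    evals, evecs = torch.linalg.eigh(c1 + c2)
    whiten = evecs @ torch.diag(evals.rsqrt()) @ evecs.T      # maps (c1 + c2) to the identity
    lam, rot = torch.linalg.eigh(whiten @ c1 @ whiten.T)      # lam in [0, 1], ascending
    filters = rot.T @ whiten                                  # W c1 W^T = diag(lam), W c2 W^T = diag(1 - lam)
    # smallest lam favours class 2, largest lam favours class 1
    keep = torch.cat([torch.arange(n_pairs), torch.arange(len(lam) - n_pairs, len(lam))])
    return filters[keep]

def log_var_features(x: torch.Tensor, filters: torch.Tensor) -> torch.Tensor:
    """Log-variance of each spatially filtered signal: the usual SVM input, (trials, 2 * n_pairs)."""
    return torch.log((filters @ x).var(-1))
```

The spatial filters are fixed linear projections derived from class covariances. That is exactly the hand-crafted part a pre-trained encoder is meant to learn instead.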

Recent work (LaBraM, BrainBERT, EEGFormer) demonstrates that transformer-based models pre-trained on large EEG datasets can learn general neural-signal representations that transfer across subjects and tasks. NeuroLLM implements this idea at a reproducible scale: pre-train on TUH, fine-tune on BCI Competition IV, compare honestly against classical and neural baselines.

Why pre-train on EEG

Motor-imagery datasets are small: BCI Competition IV 2a has 9 subjects with 288 trials per session. Training a transformer from scratch on 2,592 trials (one session per subject) doesn't work; the model ends up memorising electrode noise. NLP faced the same problem before BERT: not enough task-specific data to learn representations from scratch.

Model architecture

Pipeline:
EEG signal (C channels × T samples)
→ Patch embedding (1-D conv, P = 50, d = 256)
→ Positional encoding (spatial + temporal, learnable)
→ Transformer encoder (6 layers, 4 heads, d_ff = 512)
→ [CLS] token → MLP → 4 classes (LH, RH, F, T)

The input is a multi-channel EEG signal (22 channels × 1000 samples at 250 Hz). Each channel is independently divided into non-overlapping temporal patches of size 50 (200 ms each), then projected to a 256-dimensional embedding via a 1-D convolution. This produces N = C × (T / P) = 22 × 20 = 440 tokens.
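Here is a minimal PyTorch sketch of this tokenisation. It assumes one `Conv1d` shared by all channels, with kernel size and stride both equal to the patch size. It also assumes a channel-major token order: all 20 patches of channel 0, then channel 1, and so on. The class name `PatchEmbed` is hypothetical.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split each channel into non-overlapping patches and project each patch to d_model."""

    def __init__(self, patch_size: int = 50, d_model: int = 256):
        super().__init__()
        # kernel == stride gives non-overlapping patches; one input channel so every electrode shares weights
        self.proj = nn.Conv1d(1, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, t = x.shape                            # (batch, channels, samples)
        z = self.proj(x.reshape(b * c, 1, t))        # (b * c, d_model, t // patch_size)
        z = z.transpose(1, 2)                        # (b * c, n_patches, d_model)
        return z.reshape(b, c * z.shape[1], -1)      # channel-major tokens: (b, c * n_patches, d_model)
```

For a (2, 22, 1000) input this yields (2, 440, 256). Because the convolution has a single input channel, its weights do not depend on the electrode count.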

Positional encoding has two learnable components: spatial (one embedding per electrode) and temporal (one per patch position). These are added, not concatenated, keeping the token dimension at 256.
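The two additive components can be sketched as two `nn.Embedding` tables indexed by each token's electrode and patch position. This assumes the channel-major token order (token c · n_patches + p comes from channel c, patch p). The class name `SpatioTemporalPE` is hypothetical. Its attribute `spatial` is named so that its weight has the (22, 256) shape referred to in the checkpoint-loading comment further down.

```python
import torch
import torch.nn as nn

class SpatioTemporalPE(nn.Module):
    """Learnable spatial (per electrode) + temporal (per patch) embeddings, added to the tokens."""

    def __init__(self, n_channels: int = 22, n_patches: int = 20, d_model: int = 256):
        super().__init__()
        self.spatial = nn.Embedding(n_channels, d_model)    # one vector per electrode
        self.temporal = nn.Embedding(n_patches, d_model)    # one vector per patch position
        self.n_channels, self.n_patches = n_channels, n_patches

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        dev = tokens.device
        # channel index and patch index of every token in channel-major order
        ch = torch.arange(self.n_channels, device=dev).repeat_interleave(self.n_patches)
        t = torch.arange(self.n_patches, device=dev).repeat(self.n_channels)
        return tokens + self.spatial(ch) + self.temporal(t)  # added, so width stays d_model
```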

The encoder is a standard pre-norm transformer: LayerNorm → Multi-Head Self-Attention → residual → LayerNorm → FFN → residual. Six layers, four heads. Total: ~10M parameters.
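This configuration maps directly onto PyTorch's built-in layers. The sketch below assumes the default dropout (0.1) and ReLU activation, and no final LayerNorm after the stack. With these hyperparameters the six layers hold 3,162,624 parameters (about 3.2M). Most of the quoted ~10M must therefore sit outside this stack, or the actual configuration differs from these numbers.

```python
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=256,
    nhead=4,
    dim_feedforward=512,
    norm_first=True,    # pre-norm: LayerNorm before self-attention and before the FFN
    batch_first=True,   # inputs shaped (batch, tokens, d_model)
)
# TransformerEncoder deep-copies the layer six times; nested tensors are disabled for pre-norm layers
encoder = nn.TransformerEncoder(layer, num_layers=6, enable_nested_tensor=False)

n_params = sum(p.numel() for p in encoder.parameters())
```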

Masked Channel Modeling

The self-supervised pre-training objective is Masked Channel Modeling (MCM): randomly select 30% of channels, zero out all their patches, and train the model to reconstruct the masked channels' original signal from the unmasked ones. The loss is the MSE between reconstructed and original patches, computed on the masked channels only.

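import torch
import torch.nn.functional as F

# MCM masking (simplified); x: (batch, n_channels, n_samples)
n_masked = max(1, round(0.3 * n_channels))             # exactly 30% of channels, at least one
mask = torch.zeros(n_channels, dtype=torch.bool)
mask[torch.randperm(n_channels)[:n_masked]] = True
x_masked = x.clone()
x_masked[:, mask] = 0                                   # zero every patch of the masked channels

reconstructed = model.encoder(x_masked)                 # assumed to end in a reconstruction head: same shape as x
loss = F.mse_loss(reconstructed[:, mask], x[:, mask])   # reconstruct masked channels only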

This is conceptually similar to BERT’s masked language modeling, but operating on EEG channels rather than text tokens. The key insight: because EEG channels are spatially correlated (nearby electrodes record similar activity), the model is forced to learn cross-channel dependencies — exactly the spatial structure that CSP targets, but learned end-to-end.

Fine-tuning for motor imagery

Fine-tuning loads the pre-trained encoder and adds a classification head. The channel count changes (19 channels in TUH → 22 in BCI-IV), so we do shape-safe checkpoint loading: only weights with matching shapes are loaded, the rest are randomly initialised.

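import torch

# Shape-safe checkpoint loading: copy only tensors whose name and shape both match
encoder_state = torch.load("best_mcm.pt", map_location="cpu")
model_state = model.state_dict()
compatible = {k: v for k, v in encoder_state.items()
              if k in model_state and model_state[k].shape == v.shape}
result = model.load_state_dict(compatible, strict=False)
print("kept at init:", result.missing_keys)     # parameters the checkpoint did not supply
# pos_enc.spatial.weight: (19, 256) skipped → (22, 256) kept random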

Frequency-band attention kernel

Standard self-attention has no explicit notion of frequency: it treats each temporal patch as an opaque vector. EEG, however, has well-defined frequency bands (delta, theta, alpha, beta, gamma) with distinct functional significance. The custom kernel decomposes attention queries and keys into frequency bands via FFT, then adds learnable per-head band biases.

The Triton kernel fuses the band decomposition with scaled dot-product attention, avoiding extra HBM round-trips; on CPU it falls back to a pure-PyTorch implementation. The learnable band biases give the model a direct way to emphasise the mu (8–12 Hz) and beta (13–30 Hz) ranges, where motor-imagery activity is concentrated.
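The kernel's exact formula is not reproduced here. As a hedged illustration of the general idea, here is one simple, unfused variant in pure PyTorch. It computes the relative band power of each token's raw 200 ms patch with `torch.fft.rfft`, then adds a learnable per-head bias, weighted by the key token's band powers, to the attention logits. This variant decomposes the raw patches rather than q and k, so it is not the project's kernel or its fallback. The band edges and the helper names `band_power` and `band_biased_attention` are assumptions. A 50-sample patch at 250 Hz has 5 Hz frequency resolution, so the delta band gets no bins.

```python
import math
import torch

# Hypothetical band edges in Hz: delta, theta, alpha/mu, beta, gamma
BANDS = [(0.5, 4.0), (4.0, 8.0), (8.0, 13.0), (13.0, 30.0), (30.0, 100.0)]

def band_power(patches: torch.Tensor, fs: float = 250.0) -> torch.Tensor:
    """patches: (..., P) raw samples per token -> (..., n_bands) relative band power."""
    spec = torch.fft.rfft(patches, dim=-1).abs() ** 2            # power spectrum of each patch
    freqs = torch.fft.rfftfreq(patches.shape[-1], d=1.0 / fs)    # bin frequencies in Hz
    p = torch.stack([spec[..., (freqs >= lo) & (freqs < hi)].sum(-1) for lo, hi in BANDS], dim=-1)
    return p / p.sum(-1, keepdim=True).clamp_min(1e-12)         # normalise to relative power

def band_biased_attention(q, k, v, key_band_power, band_bias):
    """q, k, v: (B, H, N, d); key_band_power: (B, N, n_bands); band_bias: (H, n_bands), learnable."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])    # standard scaled dot product
    # per-head, per-key bias: sum over bands of band_bias[h, b] * power[key, b]
    bias = torch.einsum("hc,bnc->bhn", band_bias, key_band_power)
    attn = torch.softmax(scores + bias.unsqueeze(2), dim=-1)     # same bias for every query
    return attn @ v
```

With `band_bias` at zero this reduces exactly to ordinary softmax attention. A positive bias on the mu column shifts attention toward keys whose patches are dominated by 8–13 Hz activity.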

Benchmark results

All methods were evaluated on the same synthetic data, because the real datasets require registration and download. On synthetic random data every method converges to chance level (25% for 4 classes), which is the expected and honest result: the value here lies in the pipeline and architecture, not in the synthetic numbers.

| Method | Mean accuracy | Params | Notes |
|---|---|---|---|
| CSP + SVM | ~25% | | Classical baseline |
| EEGNet | ~25% | 2.6K | Compact CNN (Lawhern 2018) |
| Vanilla Transformer | ~25% | ~10M | Same arch, random init |
| NeuroLLM (pre-trained) | ~25% | ~10M | MCM pre-trained |
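Whether a ~25% score is statistically indistinguishable from chance depends on the number of test trials. A common check is the binomial upper bound on chance-level accuracy. Below is a minimal sketch, assuming balanced classes and independent trials. It uses 288 trials (one BCI-IV 2a session) only as an example test-set size, not as the evaluation split used here. `chance_upper_bound` is a hypothetical helper.

```python
import math

def chance_upper_bound(n_trials: int, n_classes: int = 4, alpha: float = 0.05) -> float:
    """Smallest accuracy k/n with P(X >= k) <= alpha, where X ~ Binomial(n, 1/n_classes)."""
    p = 1.0 / n_classes
    tail = 0.0  # invariant: tail == P(X >= k + 1) <= alpha at the top of each iteration
    for k in range(n_trials, -1, -1):
        # binomial pmf in log space so large n does not overflow
        log_pk = (math.lgamma(n_trials + 1) - math.lgamma(k + 1) - math.lgamma(n_trials - k + 1)
                  + k * math.log(p) + (n_trials - k) * math.log(1 - p))
        if tail + math.exp(log_pk) > alpha:
            return (k + 1) / n_trials
        tail += math.exp(log_pk)
    return 0.0

print(f"{chance_upper_bound(288):.3f}")
```

The bound shrinks toward 25% as the test set grows, so the same accuracy can be significant on a large test set and not on a small one.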

With real TUH pre-training and BCI-IV fine-tuning, the literature (LaBraM, EEGFormer) reports 75–85% accuracy on this benchmark. The architecture and pipeline would stay the same for a real run; the missing ingredients are the real data and a GPU training budget.

Attention visualisation

One of the advantages of transformer-based EEG decoding: you can inspect attention maps to see which channels and time windows the model focuses on. Extracting attention weights from PyTorch’s TransformerEncoderLayer required a workaround — the module hardcodes need_weights=False internally. The solution: manually iterate encoder layers and call self_attn() directly with need_weights=True.
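That workaround can be sketched as follows, assuming a `batch_first` `nn.TransformerEncoder` run in eval mode. In a pre-norm layer (`norm_first=True`) the attention block sees `norm1(x)`, so the sketch applies that norm before calling `self_attn`. Passing `average_attn_weights=False` keeps one map per head. `collect_attention` is a hypothetical helper.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def collect_attention(encoder: nn.TransformerEncoder, x: torch.Tensor):
    """Run the encoder layer by layer; return (output, maps) with maps: (layers, batch, heads, N, N)."""
    maps = []
    for layer in encoder.layers:
        h = layer.norm1(x) if layer.norm_first else x          # input the attention block actually sees
        _, w = layer.self_attn(h, h, h, need_weights=True, average_attn_weights=False)
        maps.append(w)                                         # (batch, heads, N, N)
        x = layer(x)                                           # normal forward gives the next layer's input
    if encoder.norm is not None:
        x = encoder.norm(x)
    return x, torch.stack(maps)
```

Attention is computed twice per layer (once for the weights, once inside the normal forward). That is acceptable for visualisation and leaves the module itself untouched.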

The channel importance map is derived from attention weights: sum over heads and layers, then map tokens back to their source channels. On real motor-imagery data, you would expect C3 and C4 (the electrodes over sensorimotor cortex) to receive the highest attention — these are the channels where mu rhythm desynchronisation appears during hand movement imagination.
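A minimal sketch of that reduction follows. It assumes attention maps shaped (layers, batch, heads, N, N), a [CLS] token at index 0, and the channel-major token layout (channel c, patch p at position 1 + c · n_patches + p). It also sums over query positions, so each token is scored by the total attention it receives; that choice is an assumption. `channel_importance` is a hypothetical helper.

```python
import torch

def channel_importance(maps: torch.Tensor, n_channels: int, n_patches: int,
                       has_cls: bool = True) -> torch.Tensor:
    """maps: (layers, batch, heads, N, N) -> (batch, n_channels), normalised to sum to 1."""
    received = maps.sum(dim=(0, 2, 3))            # attention each key token receives: (batch, N)
    if has_cls:
        received = received[:, 1:]                # drop the [CLS] token (assumed at index 0)
    per_channel = received.reshape(received.shape[0], n_channels, n_patches).sum(-1)
    return per_channel / per_channel.sum(-1, keepdim=True)
```

On real motor-imagery data, this is where the C3/C4 pattern would show up as the largest entries.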

What I learned

Full source, 192 tests across 10 releases, and benchmark pipeline at github.com/ajliouat/neurollm.