Brain-computer interfaces decode neural signals into control commands. The classical approach — Common Spatial Patterns (CSP) fed into an SVM — has been the standard for motor-imagery classification for over a decade. It works, but it relies entirely on hand-crafted spatial filters and cannot learn from large unlabelled EEG corpora.
Recent work (LaBraM, BrainBERT, EEGFormer) demonstrates that transformer-based models pre-trained on large EEG datasets can learn general neural-signal representations that transfer across subjects and tasks. NeuroLLM implements this idea at a reproducible scale: pre-train on TUH, fine-tune on BCI Competition IV, compare honestly against classical and neural baselines.
Why pre-train on EEG
Motor-imagery datasets are small: BCI Competition IV 2a has 9 subjects with 288 trials each. Training a transformer from scratch on 2,592 total trials does not work: the model ends up memorising electrode noise. NLP had the same problem before BERT: too little task-specific data to learn representations from scratch. Pre-training on a large unlabelled corpus addresses this:
- TUH EEG Corpus: ~25,000 clinical EEG sessions, ~15,000 hours. Large enough to learn frequency patterns, cross-channel correlations, and artifact signatures.
- Self-supervised objective: Masked Channel Modeling — randomly mask 30% of EEG channels, reconstruct them from the remaining channels. No labels needed.
- Transfer: The pre-trained encoder captures general EEG structure. Fine-tuning adds a 4-class classification head and adapts with much less data.
Model architecture
Input (C channels × T samples) → Patch Embedding (1-D conv, P=50, d=256) → Positional Encoding (spatial + temporal, learnable) → Transformer Encoder (6 layers, 4 heads, d_ff=512) → [CLS] Token → MLP (4 classes: LH, RH, F, T)
The input is a multi-channel EEG signal (22 channels × 1000 samples at 250 Hz). Each channel is independently divided into non-overlapping temporal patches of size 50 (200 ms each), then projected to a 256-dimensional embedding via a 1-D convolution. This produces N = C × (T / P) = 22 × 20 = 440 tokens.
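The token count follows directly from the patch arithmetic. Here is a minimal pure-Python sketch, assuming non-overlapping patches and a channel-major token order. The helper name `patchify` is hypothetical and not taken from the NeuroLLM source:

```python
def patchify(signal, patch_size):
    """Split each channel into non-overlapping patches.

    signal: list of channels, each a list of T samples.
    Returns a list of (channel_index, patch_index, samples) tokens.
    """
    tokens = []
    for ch, samples in enumerate(signal):
        if len(samples) % patch_size != 0:
            raise ValueError("T must be a multiple of the patch size")
        for p in range(len(samples) // patch_size):
            start = p * patch_size
            tokens.append((ch, p, samples[start:start + patch_size]))
    return tokens


C, T, P, FS = 22, 1000, 50, 250          # values from the text
eeg = [[0.0] * T for _ in range(C)]      # placeholder signal, not real EEG
tokens = patchify(eeg, P)

n_tokens = len(tokens)                   # C * (T // P)
patch_ms = 1000 * P / FS                 # patch duration in milliseconds
print(n_tokens, patch_ms)
```

The 1-D convolution then maps each 50-sample patch to a 256-dimensional vector. A convolution with kernel size and stride both equal to P is one common way to do that, though the text does not state the stride.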
Positional encoding has two learnable components: spatial (one embedding per electrode) and temporal (one per patch position). These are added, not concatenated, keeping the token dimension at 256.
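To illustrate why adding keeps the width fixed, here is a small sketch with toy dimensions. The table names `spatial` and `temporal`, the random stand-in values, and the channel-major token order are assumptions for illustration:

```python
import random

random.seed(0)
C, N_PATCHES, D = 22, 20, 8   # D is shrunk from 256 to keep the example small

# One vector per electrode and one per patch position
# (random numbers standing in for trained weights).
spatial = [[random.gauss(0, 0.02) for _ in range(D)] for _ in range(C)]
temporal = [[random.gauss(0, 0.02) for _ in range(D)] for _ in range(N_PATCHES)]


def positional_encoding(token_index):
    """Return the summed spatial + temporal vector for one token."""
    ch, patch = divmod(token_index, N_PATCHES)   # assumes channel-major order
    return [s + t for s, t in zip(spatial[ch], temporal[patch])]


pe = positional_encoding(439)                  # last token: channel 21, patch 19
added_width = len(pe)                          # stays at D
concat_width = len(spatial[0] + temporal[0])   # concatenation would give 2 * D
```

Addition also means only C + T/P = 42 embedding vectors are stored, rather than one per token (440).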
The encoder is a standard pre-norm transformer: LayerNorm → Multi-Head Self-Attention → residual → LayerNorm → FFN → residual. Six layers, four heads. Total: ~10M parameters.
Masked Channel Modeling
The self-supervised pre-training objective is Masked Channel Modeling (MCM): randomly select 30% of channels, zero out all their patches, and train the model to reconstruct the original signal from the unmasked channels. The loss is MSE between predicted and original patch embeddings.
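```python
import torch
import torch.nn.functional as F

# MCM masking (simplified); x is indexed by channel along its first dimension
mask = torch.rand(n_channels) < 0.3              # each channel masked with p = 0.3
x_masked = x.clone()
x_masked[mask] = 0                               # zero all patches of masked channels
reconstructed = model.encoder(x_masked)          # forward through encoder
loss = F.mse_loss(reconstructed[mask], x[mask])  # reconstruct masked only
```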
This is conceptually similar to BERT’s masked language modeling, but operating on EEG channels rather than text tokens. The key insight: because EEG channels are spatially correlated (nearby electrodes record similar activity), the model is forced to learn cross-channel dependencies — exactly the spatial structure that CSP targets, but learned end-to-end.
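The masking and the masked-only loss can be sketched without any framework. This toy version assumes a fixed 30% of channels is masked per sample (the snippet above masks each channel independently with probability 0.3). It also uses a hypothetical stand-in reconstructor, the mean of the visible channels, in place of the encoder:

```python
import random


def mask_channels(x, ratio, rng):
    """Zero out round(ratio * C) randomly chosen channels (at least one)."""
    n_channels = len(x)
    n_masked = max(1, round(ratio * n_channels))
    masked = sorted(rng.sample(range(n_channels), n_masked))
    x_masked = [[0.0] * len(ch) if i in masked else list(ch)
                for i, ch in enumerate(x)]
    return x_masked, masked


def mean_of_visible(x_masked, masked):
    """Stand-in 'model': predict every masked channel as the visible mean."""
    visible = [ch for i, ch in enumerate(x_masked) if i not in masked]
    mean = [sum(col) / len(visible) for col in zip(*visible)]
    return {i: mean for i in masked}


def masked_mse(pred, x, masked):
    """MSE over masked channels only, as in the MCM objective."""
    errs = [(p - t) ** 2 for i in masked for p, t in zip(pred[i], x[i])]
    return sum(errs) / len(errs)


rng = random.Random(0)
# Toy data: 19 strongly correlated channels = shared source + small noise.
source = [rng.gauss(0, 1) for _ in range(200)]
x = [[s + rng.gauss(0, 0.1) for s in source] for _ in range(19)]

x_masked, masked = mask_channels(x, 0.3, rng)
loss = masked_mse(mean_of_visible(x_masked, masked), x, masked)
print(len(masked), round(loss, 4))
```

Because the toy channels share a source, even this trivial predictor reaches a loss close to the noise variance (about 0.01). That cross-channel redundancy is what the objective exploits.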
Fine-tuning for motor imagery
Fine-tuning loads the pre-trained encoder and adds a classification head. The channel count changes (19 channels in TUH → 22 in BCI-IV), so we do shape-safe checkpoint loading: only weights with matching shapes are loaded, the rest are randomly initialised.
- Freeze phase: First 5 epochs with encoder frozen, only the classification head trains.
- Unfreeze phase: Remaining epochs with full model, lower learning rate (5×10⁻⁵).
- Per-subject training: Each of the 9 BCI-IV subjects is trained independently (session T → train, session E → test).
```python
import torch

# Shape-safe checkpoint loading
encoder_state = torch.load("best_mcm.pt", map_location="cpu")
model_state = model.state_dict()
# Keep only tensors whose name and shape both match the fine-tuning model
compatible = {k: v for k, v in encoder_state.items()
              if k in model_state and model_state[k].shape == v.shape}
model.load_state_dict(compatible, strict=False)
# pos_enc.spatial.weight: (19, 256) skipped → (22, 256) kept random
```
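The two-phase schedule from the list above reduces to a small decision per epoch. In this sketch the function name, the returned dictionary, and the head learning rate of 1e-3 during the freeze phase are assumptions. Only the 5 frozen epochs and the 5e-5 rate come from the text:

```python
def finetune_phase(epoch, freeze_epochs=5, head_lr=1e-3, full_lr=5e-5):
    """Return the training configuration for a 0-indexed epoch."""
    if epoch < freeze_epochs:
        # Freeze phase: encoder weights fixed, only the head is updated.
        return {"train_encoder": False, "lr": head_lr}
    # Unfreeze phase: whole model trains at the lower learning rate.
    return {"train_encoder": True, "lr": full_lr}


schedule = [finetune_phase(e) for e in range(8)]
```

In PyTorch the `train_encoder` flag would typically be applied by setting `requires_grad` on the encoder parameters before building the optimiser for that phase.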
Frequency-band attention kernel
Standard self-attention treats the temporal dimension opaquely. But EEG has well-defined frequency bands (delta, theta, alpha, beta, gamma) with distinct functional significance. The custom kernel decomposes attention queries and keys into frequency bands via FFT, then adds learnable per-head band biases:
- Delta (0.5–4 Hz): deep sleep, cortical inhibition
- Theta (4–8 Hz): memory encoding, navigation
- Alpha (8–13 Hz): relaxed attention, sensorimotor idle
- Beta (13–30 Hz): motor planning, active thinking
- Gamma (30–45 Hz): binding, high-level processing
The Triton kernel fuses the band decomposition with scaled dot-product attention, avoiding extra HBM round-trips. On CPU it falls back to a pure-PyTorch implementation. The learnable band biases let the model discover that motor imagery primarily lives in the mu (8–12 Hz) and beta (13–30 Hz) ranges.
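How finely the bands can be separated depends on the length of the sequence the FFT runs over. The text does not say which axis the kernel transforms, so the sketch below only shows how real-FFT bins map onto the five bands for two assumed lengths. `band_bins` is a hypothetical helper, not part of the kernel:

```python
BANDS = {                    # band edges in Hz, taken from the list above
    "delta": (0.5, 4.0),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 13.0),
    "beta": (13.0, 30.0),
    "gamma": (30.0, 45.0),
}


def band_bins(n_samples, fs):
    """Group real-FFT bin indices by band, using half-open ranges [lo, hi)."""
    groups = {name: [] for name in BANDS}
    for k in range(n_samples // 2 + 1):
        freq = k * fs / n_samples          # frequency of bin k in Hz
        for name, (lo, hi) in BANDS.items():
            if lo <= freq < hi:
                groups[name].append(k)
    return groups


window = band_bins(1000, 250)   # full 4 s trial: 0.25 Hz resolution
patch = band_bins(50, 250)      # single 200 ms patch: 5 Hz resolution
window_counts = {name: len(bins) for name, bins in window.items()}
patch_counts = {name: len(bins) for name, bins in patch.items()}
```

With 1000 samples each band covers many bins. A single 50-sample patch has 5 Hz resolution and no bin inside the delta range, so the choice of transform length matters for the low bands.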
Benchmark results
All methods were evaluated on the same synthetic data, since the real datasets require registration and download. On synthetic random data every method converges to chance level (25% for 4 classes), which is the expected result. The value is in the pipeline and architecture, not the synthetic numbers.
| Method | Mean Accuracy | Params | Notes |
|---|---|---|---|
| CSP + SVM | ~25% | — | Classical baseline |
| EEGNet | ~25% | 2.6K | Compact CNN (Lawhern 2018) |
| Vanilla Transformer | ~25% | ~10M | Same arch, random init |
| NeuroLLM (pre-trained) | ~25% | ~10M | MCM pre-trained |
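The ~25% figures are what any classifier should score when the labels carry no information about the input. A quick sanity check, using a hypothetical helper and uniformly random predictions rather than any of the benchmarked models:

```python
import random


def chance_accuracy(n_trials, n_classes, seed=0):
    """Accuracy of uniformly random predictions against random labels."""
    rng = random.Random(seed)
    labels = [rng.randrange(n_classes) for _ in range(n_trials)]
    preds = [rng.randrange(n_classes) for _ in range(n_trials)]
    hits = sum(p == y for p, y in zip(preds, labels))
    return hits / n_trials


acc = chance_accuracy(n_trials=2592, n_classes=4)   # 9 subjects x 288 trials
print(f"{acc:.3f}")
```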
With real TUH pre-training and BCI-IV fine-tuning, the literature (LaBraM, EEGFormer) reports 75–85% accuracy on this benchmark. The architecture and pipeline here are identical — the missing ingredient is real data and GPU training budget.
Attention visualisation
One of the advantages of transformer-based EEG decoding: you can inspect attention maps to see which channels and time windows the model focuses on. Extracting attention weights from PyTorch’s TransformerEncoderLayer required a workaround, because the module hardcodes need_weights=False internally. The solution: manually iterate encoder layers and call self_attn() directly with need_weights=True.
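Here is a minimal sketch of that workaround for a pre-norm encoder built from stock `nn.TransformerEncoderLayer` modules. The function name `forward_with_attention` and the toy sizes are assumptions. The layer internals it relies on (`norm1`, `self_attn`, `linear1`, and so on) are attribute names of the current PyTorch implementation, not a documented extension point, so they may differ between versions:

```python
import torch
import torch.nn as nn


def forward_with_attention(encoder, x):
    """Re-run a pre-norm nn.TransformerEncoder layer by layer, keeping weights.

    x: (batch, tokens, d_model). Returns the output and a list with one
    (batch, heads, tokens, tokens) attention tensor per layer.
    """
    attn_maps = []
    for layer in encoder.layers:
        # Self-attention sub-block: LayerNorm -> attention -> residual
        h = layer.norm1(x)
        attn_out, weights = layer.self_attn(
            h, h, h, need_weights=True, average_attn_weights=False
        )
        x = x + layer.dropout1(attn_out)
        # Feed-forward sub-block: LayerNorm -> FFN -> residual
        h = layer.norm2(x)
        ff = layer.linear2(layer.dropout(layer.activation(layer.linear1(h))))
        x = x + layer.dropout2(ff)
        attn_maps.append(weights)
    if encoder.norm is not None:            # optional final LayerNorm
        x = encoder.norm(x)
    return x, attn_maps


torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(
    d_model=32, nhead=4, dim_feedforward=64, batch_first=True, norm_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)
encoder.eval()                              # disable dropout for the comparison

tokens = torch.randn(1, 12, 32)             # toy input: 12 tokens of width 32
out, attn_maps = forward_with_attention(encoder, tokens)
reference = encoder(tokens)                 # stock forward pass for comparison
```

Comparing `out` against `reference` is a cheap way to confirm that the manual loop reproduces the stock forward pass before trusting the extracted maps.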
The channel importance map is derived from attention weights: sum over heads and layers, then map tokens back to their source channels. On real motor-imagery data, you would expect C3 and C4 (the electrodes over sensorimotor cortex) to receive the highest attention — these are the channels where mu rhythm desynchronisation appears during hand movement imagination.
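The aggregation can be sketched on nested lists. The text specifies summing over heads and layers. Summing over query tokens as well (attention received), the channel-major token order, and the name `channel_importance` are assumptions made for this example:

```python
def channel_importance(attn, n_channels, n_patches):
    """Aggregate attention received by each channel.

    attn[layer][head][query][key] holds softmax weights over
    n_channels * n_patches tokens in channel-major order (an assumption).
    """
    scores = [0.0] * n_channels
    for layer in attn:
        for head in layer:
            for row in head:                       # one query token
                for key, weight in enumerate(row):
                    scores[key // n_patches] += weight
    total = sum(scores)
    return [s / total for s in scores]             # normalise to sum to 1


# Toy example: 3 channels x 2 patches, 1 layer, 1 head.
# Every query puts 70% of its attention on channel 1's two tokens.
row = [0.05, 0.05, 0.35, 0.35, 0.10, 0.10]
attn = [[[row[:] for _ in range(6)]]]
importance = channel_importance(attn, n_channels=3, n_patches=2)
```

In the toy input, channel 1 ends up with roughly 0.7 of the total. On real data the same aggregation would be applied to the per-layer maps extracted from the encoder.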
What I learned
- Shape mismatches are the main transfer learning pain. Going from 19-channel TUH to 22-channel BCI-IV breaks spatial positional encodings. Filtering by shape and loading with strict=False is the cleanest solution.
- PyTorch attention internals are not introspectable. TransformerEncoderLayer uses an optimised path that drops attention weights. Manual layer iteration is the only reliable way to extract per-head maps.
- Frequency structure is a strong prior. EEG has known frequency bands with known functional roles. Encoding this into the attention mechanism (via band biases) is a more principled approach than letting the model rediscover it from raw temporal data.
- Synthetic data tests the pipeline, not the science. All 192 tests pass on synthetic data. The architecture is correct, the training loops converge, the evaluation pipeline produces real metrics. The next step is real data.
- Small models, big pre-training works. ~10M parameters is tiny by LLM standards, but the pre-training corpus (15,000 hours of EEG) provides enough signal to learn representations that transfer. The BCI-IV dataset alone is not enough.
Full source, 192 tests across 10 releases, and benchmark pipeline at github.com/ajliouat/neurollm.