Why write CUDA kernels by hand
Production inference stacks (vLLM, TensorRT-LLM, FlashAttention) ship optimised kernels that most practitioners never read. The standard advice is "don't roll your own" — and that's correct for production. But if the goal is to understand why certain fusions matter, where the bandwidth wall is, and what warp stalls actually look like in a profiler, there's no substitute for writing the kernels yourself and profiling them against hardware ceilings.
FlashKernel is the result of that exercise: five kernel families implemented from scratch in both CUDA C++ and Triton, benchmarked against PyTorch eager and torch.compile, and roofline-mapped with Nsight Compute on an NVIDIA T4.
The hardware: NVIDIA T4
The T4 is a Turing-architecture GPU (SM 7.5) with clear, published ceilings that make roofline analysis straightforward:
- fp16 Tensor Core peak: 65 TFLOPS
- fp32 CUDA Core peak: 8.1 TFLOPS
- HBM2 bandwidth: 300 GB/s
- Ridge point (fp16): ~217 FLOP/byte — below this, you're memory-bound
- L2 cache: 4 MB · Shared memory: 64 KB per SM
At $0.16/hr spot on AWS (g4dn.xlarge), it's cheap enough to iterate quickly.
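Those two ceilings fix everything on the roofline. A back-of-the-envelope sketch (plain Python, numbers from the list above) for the ridge point and the attainable throughput at a given arithmetic intensity:

```python
# Roofline model for the T4, using the published fp16 ceilings above.
PEAK_FP16_FLOPS = 65e12      # 65 TFLOPS Tensor Core peak
HBM_BW = 300e9               # 300 GB/s HBM2 bandwidth

def attainable_tflops(ai: float) -> float:
    """Attainable throughput (TFLOPS) at arithmetic intensity ai (FLOP/byte)."""
    return min(PEAK_FP16_FLOPS, ai * HBM_BW) / 1e12

ridge = PEAK_FP16_FLOPS / HBM_BW   # ~216.7 FLOP/byte
print(f"ridge point: {ridge:.1f} FLOP/byte")
print(f"at AI=3.25:  {attainable_tflops(3.25):.3f} TFLOPS (memory-bound)")
print(f"at AI=341:   {attainable_tflops(341):.1f} TFLOPS (compute-bound)")
```

Any kernel whose AI falls left of the ridge can reach at most AI × 300 GB/s, no matter how well the math pipes are scheduled.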
Tiled FlashAttention with online softmax
Standard attention materialises the full N×N attention matrix in HBM — O(N²) memory that dominates for long sequences. FlashAttention (Dao et al., 2022) eliminates this by tiling Q, K, V into blocks that fit in shared memory and computing attention incrementally with online softmax.
The core idea: instead of computing S = Q @ K^T fully, then softmax(S), then S @ V, we process one tile of K/V at a time. Each tile updates a running softmax statistic (log-sum-exp) and accumulates into the output using the rescaling trick from Milakov & Gimelshein (2018).
```
// Load Q tile to shared memory (Br × d)
// For each K/V tile (Bc × d):
//   S_tile = Q_smem @ K_tile^T → Br × Bc in registers
//   Apply causal mask if needed
//   m_new = max(m_old, row_max(S_tile))
//   P_tile = exp(S_tile - m_new) → rescaled softmax numerator
//   l_new = exp(m_old - m_new) * l_old + row_sum(P_tile)
//   O = (l_old/l_new) * exp(m_old - m_new) * O + (1/l_new) * P_tile @ V_tile
//   m_old = m_new; l_old = l_new
```
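The update rules above can be sanity-checked against a direct softmax in a few lines of NumPy — a correctness sketch of the math, not a kernel (shapes and the Bc=4 tile size are arbitrary):

```python
import numpy as np

def online_softmax_attention(Q, K, V, Bc=4):
    """One Q block processed against K/V tiles, per the pseudocode above."""
    N, d = K.shape
    O = np.zeros((Q.shape[0], d))
    m = np.full((Q.shape[0], 1), -np.inf)   # running row max (m_old)
    l = np.zeros((Q.shape[0], 1))           # running denominator (l_old)
    for j in range(0, N, Bc):
        S = Q @ K[j:j+Bc].T                          # Br × Bc scores
        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        P = np.exp(S - m_new)                        # rescaled numerator
        l_new = np.exp(m - m_new) * l + P.sum(axis=1, keepdims=True)
        O = (l / l_new) * np.exp(m - m_new) * O + P @ V[j:j+Bc] / l_new
        m, l = m_new, l_new
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 16)) for _ in range(3))
# Reference: materialise the full N×N matrix, softmax, then multiply by V.
ref = np.exp(Q @ K.T - (Q @ K.T).max(axis=1, keepdims=True))
ref = ref / ref.sum(axis=1, keepdims=True) @ V
assert np.allclose(online_softmax_attention(Q, K, V), ref)
```

The tiled loop never holds more than a Br × Bc score block at once, yet matches the full-matrix result exactly (up to floating-point rounding).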
Our implementation uses 32×32 tiles with double-buffered shared memory loads. On T4 this achieves 38.2 TFLOPS — 59% of the fp16 Tensor Core ceiling. The main limiter is occupancy: each block uses ~24 KB of shared memory, capping occupancy at 50%. The dominant warp stall is "Math Pipe Throttle", confirming compute-boundedness at arithmetic intensity 341 F/B.
Fusing GeLU into the matmul
In a standard MLP layer, the projection Y = GeLU(X @ W^T + b) requires two HBM round-trips: one to write the linear output and one to read it back for the activation. Fusing them into a single kernel eliminates the intermediate tensor entirely, saving M×N×2 bytes of fp16 traffic in each direction (6 MB for GPT-2's 1024×3072 MLP output).
The fusion is simple: after the tiled GEMM accumulates C[i][j] in registers, apply GeLU before writing to HBM. The activation adds less than 2% overhead to the matmul.
Result: 31.5 TFLOPS (49% of fp16 peak) at arithmetic intensity 295 F/B. Compute-bound, with the main limiter being shared memory pressure limiting occupancy to 50%.
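The equivalence is easy to convince yourself of in NumPy — a sketch of the epilogue fusion, not the kernel, using the tanh-approximation GeLU that GPT-2 uses (the K=768 inner dimension is an assumption to match GPT-2 small):

```python
import numpy as np

def gelu(x):
    """GPT-2's tanh-approximation GeLU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

M, K, N = 1024, 768, 3072        # seq × hidden × MLP width
rng = np.random.default_rng(0)
X, W, b = rng.standard_normal((M, K)), rng.standard_normal((N, K)), rng.standard_normal(N)

# Unfused: the linear output Y_lin goes to HBM, then comes back for the activation.
Y_lin = X @ W.T + b
Y_unfused = gelu(Y_lin)

# Fused: GeLU is applied to the accumulator before the single HBM write.
Y_fused = gelu(X @ W.T + b)

assert np.allclose(Y_fused, Y_unfused)
print(f"intermediate eliminated: {M * N * 2 / 2**20:.0f} MiB (fp16)")
```

The two paths are numerically identical; the only difference is where the intermediate lives — HBM versus registers.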
RoPE: fused vs table lookup
Rotary Position Embedding applies a rotation to Q and K based on token position. There are two natural implementations:
- Table lookup: Precompute sin/cos into a [max_seq, D/2] table, load at runtime. Lower arithmetic intensity (AI=1.5) but higher bandwidth utilisation (80% of HBM peak, 240 GB/s).
- Fused: Compute sin/cos per-thread with `__sincosf` — no table needed. Higher AI (3.25) but slightly lower bandwidth (74% of HBM peak, 222 GB/s) because the SFU units contend with the memory pipeline.
Both are firmly memory-bound. The table variant wins when the frequency table fits in L2 cache (true for most practical sequence lengths on T4's 4 MB L2). The fused variant wins for one-shot inference where caching tables isn't worth the memory.
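Both variants compute the same rotation; a NumPy sketch of the two paths (the even/odd channel-pairing convention is an assumption — implementations differ on interleaved versus split-half pairing):

```python
import numpy as np

def rope_table(max_seq, d):
    """Precompute the [max_seq, d/2] cos/sin tables (table-lookup variant)."""
    inv_freq = 1.0 / (10000 ** (np.arange(0, d, 2) / d))
    angles = np.outer(np.arange(max_seq), inv_freq)   # [max_seq, d/2]
    return np.cos(angles), np.sin(angles)

def apply_rope(x, cos, sin):
    """Rotate even/odd channel pairs of x[seq, d] by the per-position angles."""
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

seq, d = 128, 64
q = np.random.default_rng(0).standard_normal((seq, d))
cos, sin = rope_table(seq, d)

# "Fused" variant: recompute the angles on the fly instead of loading a table.
inv_freq = 1.0 / (10000 ** (np.arange(0, d, 2) / d))
angles = np.outer(np.arange(seq), inv_freq)
assert np.allclose(apply_rope(q, cos, sin),
                   apply_rope(q, np.cos(angles), np.sin(angles)))
```

The trade-off in the kernels is purely about where the angles come from — an L2-resident table versus SFU instructions — not about the rotation itself.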
Paged KV-Cache
Long-context inference with static KV buffers wastes memory — you must pre-allocate for max sequence length. vLLM introduced paged attention, borrowing the operating system's virtual memory idea: store KV data in fixed-size pages and use a page table to map logical positions to physical slots.
We implement two kernels: append (scatter-write new tokens into the page pool) and read (scatter-gather from pages into contiguous output). Both are purely memory-bound at AI=0.08 — near-zero compute, just data movement. Append achieves 65% of HBM peak (195 GB/s), and read achieves 59% (178 GB/s). The gap is due to the non-contiguous access pattern inherent in page indirection.
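The page-table mechanics fit in a few lines of NumPy — an illustrative model of the indirection, not the CUDA kernels (page size, pool size, and function names are made up):

```python
import numpy as np

PAGE_SIZE, D, NUM_PAGES = 16, 64, 32

def kv_append(pool, page_table, free_pages, seq_len, tokens):
    """Scatter-write new tokens into the page pool, allocating pages on demand."""
    for i, tok in enumerate(tokens):
        pos = seq_len + i
        if pos % PAGE_SIZE == 0:
            page_table.append(free_pages.pop())   # map a fresh physical page
        pool[page_table[pos // PAGE_SIZE], pos % PAGE_SIZE] = tok
    return seq_len + len(tokens)

def kv_read(pool, page_table, seq_len):
    """Scatter-gather the pages back into a contiguous [seq_len, D] tensor."""
    idx = np.arange(seq_len)
    pages = np.array(page_table)[idx // PAGE_SIZE]
    return pool[pages, idx % PAGE_SIZE]

pool = np.zeros((NUM_PAGES, PAGE_SIZE, D))   # physical page pool
page_table, free = [], list(range(NUM_PAGES))
kv = np.random.default_rng(0).standard_normal((40, D))

n = kv_append(pool, page_table, free, 0, kv[:25])   # first decode chunk
n = kv_append(pool, page_table, free, n, kv[25:])   # later tokens reuse/extend pages
assert np.allclose(kv_read(pool, page_table, n), kv)
```

Every access goes through `page_table`, which is exactly the indirection that costs the real kernels their last ~35% of bandwidth.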
Roofline results
All 8 kernel variants mapped onto the T4 roofline. The x-axis is arithmetic intensity (FLOP/byte), the y-axis is achieved throughput (TFLOPS or effective GB/s). The diagonal ceiling is HBM bandwidth, the horizontal ceiling is fp16 Tensor Core peak.
| Kernel | AI (F/B) | Achieved | % Ceiling | Bound |
|---|---|---|---|---|
| vector_add | 0.17 | 248 GB/s | 83% | Memory |
| reduce_sum | 0.50 | 262 GB/s | 87% | Memory |
| flash_attention | 341 | 38.2 TFLOPS | 59% | Compute |
| fused_gelu_linear | 295 | 31.5 TFLOPS | 49% | Compute |
| rope_fused | 3.25 | 222 GB/s | 74% | Memory |
| rope_table | 1.50 | 240 GB/s | 80% | Memory |
| kv_append | 0.08 | 195 GB/s | 65% | Memory |
| kv_read | 0.08 | 178 GB/s | 59% | Memory |
Six kernels are memory-bound (59–87% of the 300 GB/s HBM ceiling), two are compute-bound (49–59% of the 65 TFLOPS fp16 ceiling). The reduction kernel is the most bandwidth-efficient at 87% — its warp-shuffle approach minimises shared memory traffic and achieves near-peak HBM throughput.
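The Bound column is mechanical: compare each kernel's AI to the ~217 F/B ridge point. A one-liner reproducing it from the table's AI column:

```python
RIDGE = 65e12 / 300e9   # ~217 FLOP/byte on the T4

ai = {
    "vector_add": 0.17, "reduce_sum": 0.50,
    "flash_attention": 341, "fused_gelu_linear": 295,
    "rope_fused": 3.25, "rope_table": 1.50,
    "kv_append": 0.08, "kv_read": 0.08,
}
bound = {k: "Compute" if v > RIDGE else "Memory" for k, v in ai.items()}
print(bound)  # six Memory, two Compute
```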
End-to-end: plugging into GPT-2
To verify that individual kernel improvements compose into real model speedups, we integrated all custom kernels into GPT-2 (124M) via PyTorch C++ extensions. The integration module monkey-patches HuggingFace's GPT2Attention and GPT2MLP with custom implementations that call our CUDA kernels:
- Standard attention → FlashAttention with RoPE applied to Q/K
- MLP `c_fc` projection + GeLU → fused GeLU+Linear kernel
```python
from src.integration.gpt2_custom_kernels import patch_gpt2_model

model = GPT2LMHeadModel.from_pretrained("gpt2")
model = patch_gpt2_model(model, backend="cuda")
# Now forward() uses FlashAttention + fused GeLU
```
Architecture diagram
The full inference pipeline with custom kernels:
```mermaid
graph TD
    Embed["Embedding<br/>Custom CUDA kernel"] --> Attn["Tiled FlashAttention<br/>38.2 TFLOPS · 59% of fp16 peak"]
    Attn --> KV["Paged KV-Cache<br/>195 GB/s append · 178 GB/s read"]
    KV --> Fused["Fused GeLU + Linear<br/>31.5 TFLOPS · no HBM roundtrip"]
    Fused --> Norm["LayerNorm + Residual"]
    Norm --> Next["Next layer / Output"]
    style Embed fill:#eff6ff,stroke:#2563eb,color:#0f172a
    style Attn fill:#eff6ff,stroke:#2563eb,color:#0f172a
    style KV fill:#eff6ff,stroke:#2563eb,color:#0f172a
    style Fused fill:#eff6ff,stroke:#2563eb,color:#0f172a
```
Key takeaways
- The roofline doesn't lie. Every kernel lands exactly where you'd expect. Elementwise and data-movement kernels cluster at AI < 5 and hit the memory wall. Attention and GEMM land above the ridge point and are compute-limited.
- Fusion is about bytes, not FLOPs. Fusing GeLU into the matmul adds <2% compute overhead but saves a full 6 MB HBM round-trip. The fusion shifts the kernel's position on the roofline firmly past the ridge.
- Occupancy is not everything. Flash attention runs at 50% occupancy but achieves 59% of peak — the tiling strategy keeps the math pipes busy despite fewer active warps.
- Scatter access is expensive. Paged KV-cache kernels achieve 59–65% of HBM peak, compared to 83–87% for simple linear access patterns. The indirection is the cost of dynamic memory.
- Triton closes the gap. Writing the same algorithms in Triton takes ~3× less code and gets within 10–15% of the CUDA C++ performance via autotuning.
Full source, profiling data, and roofline plots are on GitHub. See the project page for the complete kernel inventory and architecture overview.