GPU compute

Writing CUDA kernels for transformer inference on T4

Notes on implementing FlashAttention, fused activations, RoPE, and paged KV-cache from scratch in CUDA C++ and Triton — and what the roofline model reveals about each kernel.

Why write CUDA kernels by hand

Production inference stacks (vLLM, TensorRT-LLM, FlashAttention) ship optimised kernels that most practitioners never read. The standard advice is "don't roll your own" — and that's correct for production. But if the goal is to understand why certain fusions matter, where the bandwidth wall is, and what warp stalls actually look like in a profiler, there's no substitute for writing the kernels yourself and profiling them against hardware ceilings.

FlashKernel is the result of that exercise: five kernel families implemented from scratch in both CUDA C++ and Triton, benchmarked against PyTorch eager and torch.compile, and roofline-mapped with Nsight Compute on an NVIDIA T4.

The hardware: NVIDIA T4

The T4 is a Turing-architecture GPU (SM 7.5) with clear, published ceilings that make roofline analysis straightforward:

- fp16 Tensor Core peak: 65 TFLOPS
- Memory: 16 GB GDDR6, 320 GB/s nominal bandwidth (the rooflines below use 300 GB/s as the practical ceiling)
- L2 cache: 4 MB

At $0.16/hr spot on AWS (g4dn.xlarge), it's cheap enough to iterate quickly.

Tiled FlashAttention with online softmax

Standard attention materialises the full N×N attention matrix in HBM — O(N²) memory that dominates for long sequences. FlashAttention (Dao et al., 2022) eliminates this by tiling Q, K, V into blocks that fit in shared memory and computing attention incrementally with online softmax.

The core idea: instead of computing S = Q @ K^T fully, then softmax(S), then S @ V, we process one tile of K/V at a time. Each tile updates a running softmax statistic (log-sum-exp) and accumulates into the output using the rescaling trick from Milakov & Gimelshein (2018).

CUDA — tiled attention inner loop (simplified)
// Load Q tile to shared memory (Br × d)
// For each K/V tile (Bc × d):
//   S_tile = Q_smem @ K_tile^T        → Br × Bc in registers
//   Apply causal mask if needed
//   m_new = max(m_old, row_max(S_tile))
//   P_tile = exp(S_tile - m_new)       → rescaled softmax numerator
//   l_new = exp(m_old - m_new) * l_old + row_sum(P_tile)
//   O = (l_old/l_new) * exp(m_old - m_new) * O + (1/l_new) * P_tile @ V_tile
//   m_old = m_new; l_old = l_new
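The rescaling algebra above is easy to get subtly wrong, so it helps to check it numerically. A minimal NumPy reference — a sketch for verifying the update rules, not the CUDA kernel itself — processes K/V one tile at a time and confirms the result matches full softmax attention:

```python
import numpy as np

def tiled_attention(Q, K, V, Bc=32):
    """Online-softmax attention over K/V tiles; never materialises the N x N matrix."""
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row max (log-sum-exp statistic)
    l = np.zeros(N)           # running softmax denominator
    for j in range(0, N, Bc):
        S = Q @ K[j:j+Bc].T                       # Br x Bc score tile
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])            # rescaled softmax numerator
        scale = np.exp(m - m_new)                 # correction for the old statistics
        l_new = scale * l + P.sum(axis=1)
        O = (scale * l / l_new)[:, None] * O + (P @ V[j:j+Bc]) / l_new[:, None]
        m, l = m_new, l_new
    return O

def full_attention(Q, K, V):
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
assert np.allclose(tiled_attention(Q, K, V), full_attention(Q, K, V))
```

The tiled version keeps only O(N·d) state — the output, one max, and one sum per row — exactly the invariant the shared-memory kernel maintains per Q tile.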

Our implementation uses 32×32 tiles with double-buffered shared memory loads. On T4 this achieves 38.2 TFLOPS — 59% of the fp16 Tensor Core ceiling. The main limiter is occupancy: each block uses ~24 KB of shared memory, capping occupancy at 50%. The dominant warp stall is "Math Pipe Throttle", confirming compute-boundedness at arithmetic intensity 341 F/B.
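The 50% occupancy figure follows directly from the shared-memory budget. A back-of-envelope sketch — assuming Turing's 64 KB of shared memory per SM, a 32-warp-per-SM limit, and 256-thread blocks (launch-config figures not stated above, so treat them as assumptions):

```python
SMEM_PER_SM = 64 * 1024        # Turing shared memory available per SM (assumed)
MAX_WARPS_PER_SM = 32          # Turing warp slots per SM (assumed)
smem_per_block = 24 * 1024     # ~24 KB per block, as measured above
threads_per_block = 256        # assumed launch configuration

blocks_per_sm = SMEM_PER_SM // smem_per_block             # 2 blocks fit
warps_resident = blocks_per_sm * threads_per_block // 32  # 16 warps
occupancy = warps_resident / MAX_WARPS_PER_SM
print(f"{occupancy:.0%}")  # 50% — matches the profiler's report
```

With only 16 resident warps the scheduler has fewer candidates to hide latency with, which is why shared-memory footprint, not register count, is the lever to pull here.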

Fusing GeLU into the matmul

In a standard MLP layer, the projection Y = GeLU(X @ W^T + b) requires two HBM round-trips: one to write the linear output and one to read it back for the activation. Fusing them into a single kernel eliminates that intermediate write — saving M×N×2 bytes (6 MB for GPT-2's 1024×3072 MLP).

The fusion is simple: after the tiled GEMM accumulates C[i][j] in registers, apply GeLU before writing to HBM. The activation adds less than 2% overhead to the matmul.

Result: 31.5 TFLOPS (49% of fp16 peak) at arithmetic intensity 295 F/B. Compute-bound, with the main limiter being shared memory pressure limiting occupancy to 50%.
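The 6 MB figure quoted above is quick to reproduce. A sketch of the traffic arithmetic for GPT-2's MLP up-projection shapes in fp16:

```python
M, K, N = 1024, 768, 3072   # sequence length, hidden dim, MLP dim (GPT-2)
bytes_per_elem = 2          # fp16

# Unfused: the linear kernel writes Y = X @ W^T + b (M*N elements) to HBM,
# then a separate GeLU kernel reads those same bytes back.
intermediate = M * N * bytes_per_elem
print(intermediate / 2**20, "MiB")  # 6.0 MiB intermediate tensor
```

Fusing the activation into the GEMM epilogue removes both the write of that tensor and the read-back, at the cost of a few extra FLOPs per output element that the Tensor Cores barely notice.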

RoPE: fused vs table lookup

Rotary Position Embedding applies a rotation to Q and K based on token position. There are two natural implementations:

- Fused: compute the sin/cos rotation angles on the fly inside the kernel, from the token position and channel index.
- Table lookup: precompute sin/cos for all positions into a frequency table once, and have the kernel read the angles back.

Both are firmly memory-bound. The table variant wins when the frequency table fits in L2 cache (true for most practical sequence lengths on T4's 4 MB L2). The fused variant wins for one-shot inference where caching tables isn't worth the memory.
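A NumPy reference of the two variants makes the trade concrete — a sketch assuming the pairwise (interleaved) rotation layout; real kernels sometimes use the half-split layout instead:

```python
import numpy as np

def rope_fused(x, pos, base=10000.0):
    """Compute rotation angles on the fly (no precomputed table)."""
    d = x.shape[-1]
    i = np.arange(d // 2)
    theta = pos * base ** (-2.0 * i / d)       # per-pair rotation angle
    cos, sin = np.cos(theta), np.sin(theta)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin       # 2D rotation of each channel pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def make_table(max_pos, d, base=10000.0):
    """Precompute cos/sin for every (position, pair) — the L2-resident table."""
    i = np.arange(d // 2)
    theta = np.outer(np.arange(max_pos), base ** (-2.0 * i / d))
    return np.cos(theta), np.sin(theta)

def rope_table(x, pos, cos_t, sin_t):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos_t[pos] - x2 * sin_t[pos]
    out[..., 1::2] = x1 * sin_t[pos] + x2 * cos_t[pos]
    return out

q = np.random.default_rng(1).standard_normal(64)
cos_t, sin_t = make_table(2048, 64)
assert np.allclose(rope_fused(q, 17), rope_table(q, 17, cos_t, sin_t))
```

The fused variant trades transcendental instructions for the table's memory traffic — a good trade exactly when, as the kernel results show, the table would miss in L2.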

Paged KV-Cache

Long-context inference with static KV buffers wastes memory — you must pre-allocate for max sequence length. vLLM introduced paged attention, borrowing the operating system's virtual memory idea: store KV data in fixed-size pages and use a page table to map logical positions to physical slots.

We implement two kernels: append (scatter-write new tokens into the page pool) and read (scatter-gather from pages into contiguous output). Both are purely memory-bound at AI=0.08 — near-zero compute, just data movement. Append achieves 65% of HBM peak (195 GB/s), and read achieves 59% (178 GB/s). The gap is due to the non-contiguous access pattern inherent in page indirection.
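The indirection scheme is easiest to see in a toy Python model — dozens of lines standing in for what the CUDA kernels do one thread per element (page size and class names here are illustrative, not the project's API):

```python
import numpy as np

PAGE_SIZE = 16   # tokens per page (assumed block size)

class PagedKV:
    """Toy page pool: logical token position -> (page, slot) via a page table."""
    def __init__(self, num_pages, d):
        self.pool = np.zeros((num_pages, PAGE_SIZE, d))  # physical page pool
        self.page_table = []                  # logical page -> physical page
        self.free = list(range(num_pages))
        self.length = 0

    def append(self, kv):
        """Scatter-write one token into the pool, allocating pages on demand."""
        if self.length % PAGE_SIZE == 0:
            self.page_table.append(self.free.pop())
        page = self.page_table[self.length // PAGE_SIZE]
        self.pool[page, self.length % PAGE_SIZE] = kv
        self.length += 1

    def read(self):
        """Scatter-gather every page back into one contiguous (length, d) array."""
        out = np.concatenate([self.pool[p] for p in self.page_table])
        return out[:self.length]

cache = PagedKV(num_pages=8, d=4)
for t in range(20):                  # 20 tokens -> exactly 2 pages allocated
    cache.append(np.full(4, t))
assert len(cache.page_table) == 2 and cache.read().shape == (20, 4)
```

Every access goes through the page table, which is why neither kernel can issue fully coalesced loads — the source of the gap below HBM peak.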

Roofline results

All 8 kernel variants mapped onto the T4 roofline. The x-axis is arithmetic intensity (FLOP/byte), the y-axis is achieved throughput (TFLOPS or effective GB/s). The diagonal ceiling is HBM bandwidth, the horizontal ceiling is fp16 Tensor Core peak.

Kernel              AI (F/B)   Achieved       % of ceiling   Bound
vector_add          0.17       248 GB/s       83%            Memory
reduce_sum          0.50       262 GB/s       87%            Memory
flash_attention     341        38.2 TFLOPS    59%            Compute
fused_gelu_linear   295        31.5 TFLOPS    49%            Compute
rope_fused          3.25       222 GB/s       74%            Memory
rope_table          1.50       240 GB/s       80%            Memory
kv_append           0.08       195 GB/s       65%            Memory
kv_read             0.08       178 GB/s       59%            Memory

Six kernels are memory-bound (59–87% of the 300 GB/s HBM ceiling), two are compute-bound (49–59% of the 65 TFLOPS fp16 ceiling). The reduction kernel is the most bandwidth-efficient at 87% — its warp-shuffle approach minimises shared memory traffic and achieves near-peak HBM throughput.
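The Bound column is just a comparison against the ridge point of the roofline — the arithmetic intensity at which the bandwidth and compute ceilings meet. A few lines reproduce the classification:

```python
PEAK_FP16_TFLOPS = 65.0   # T4 Tensor Core fp16 peak
PEAK_BW_GBS = 300.0       # HBM ceiling used in the table

# Ridge point: below this AI a kernel cannot reach compute peak.
ridge = PEAK_FP16_TFLOPS * 1e12 / (PEAK_BW_GBS * 1e9)   # ~216.7 FLOP/byte

kernels = {"vector_add": 0.17, "reduce_sum": 0.50, "flash_attention": 341,
           "fused_gelu_linear": 295, "rope_fused": 3.25, "rope_table": 1.50,
           "kv_append": 0.08, "kv_read": 0.08}

bound = {k: "Compute" if ai > ridge else "Memory" for k, ai in kernels.items()}
compute_bound = [k for k, b in bound.items() if b == "Compute"]
# Only the two matmul-heavy kernels sit right of the ridge point.
```

With a ridge near 217 F/B, everything short of a large GEMM is memory-bound on T4 — which is the whole argument for fusion.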

End-to-end: plugging into GPT-2

To verify that individual kernel improvements compose into real model speedups, we integrated all custom kernels into GPT-2 (124M) via PyTorch C++ extensions. The integration module monkey-patches HuggingFace's GPT2Attention and GPT2MLP with custom implementations that call our CUDA kernels:

python — GPT-2 integration
from transformers import GPT2LMHeadModel

from src.integration.gpt2_custom_kernels import patch_gpt2_model

model = GPT2LMHeadModel.from_pretrained("gpt2")
model = patch_gpt2_model(model, backend="cuda")
# Now forward() uses FlashAttention + fused GeLU
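The patch itself is ordinary attribute swapping on live module instances. A schematic of the pattern — dummy classes standing in for the HuggingFace modules and the CUDA extension, with hypothetical names throughout:

```python
import types

class DummyAttention:                      # stand-in for GPT2Attention
    def forward(self, x):
        return "eager attention"

def make_flash_forward(backend):
    def forward(self, x):                  # a real patch would call the CUDA extension
        return f"flash attention ({backend})"
    return forward

def patch_module(module, backend="cuda"):
    """Rebind forward() on an existing module instance — the monkey-patch pattern."""
    module.forward = types.MethodType(make_flash_forward(backend), module)
    return module

attn = patch_module(DummyAttention())
assert attn.forward(None) == "flash attention (cuda)"
```

Patching instances rather than classes keeps the original HuggingFace code importable and lets the backend be chosen per model.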

Architecture diagram

The full inference pipeline with custom kernels:

mermaid
flowchart TB
    Input["Input tokens"] --> Embed["Token + RoPE embedding<br/>Custom CUDA kernel"]
    Embed --> Attn["Tiled FlashAttention<br/>38.2 TFLOPS · 59% of fp16 peak"]
    Attn --> KV["Paged KV-Cache<br/>195 GB/s append · 178 GB/s read"]
    KV --> Fused["Fused GeLU + Linear<br/>31.5 TFLOPS · no HBM roundtrip"]
    Fused --> Norm["LayerNorm + Residual"]
    Norm --> Next["Next layer / Output"]
    style Embed fill:#eff6ff,stroke:#2563eb,color:#0f172a
    style Attn fill:#eff6ff,stroke:#2563eb,color:#0f172a
    style KV fill:#eff6ff,stroke:#2563eb,color:#0f172a
    style Fused fill:#eff6ff,stroke:#2563eb,color:#0f172a

Key takeaways

- Most inference kernels are memory-bound on T4: six of the eight sit on the bandwidth ceiling, so avoiding HBM round-trips (fusion, paging) buys more than raw FLOPS.
- The two matmul-heavy kernels reach 49–59% of fp16 Tensor Core peak, with shared-memory pressure capping occupancy at 50% in both cases.
- Roofline analysis with Nsight Compute turns "is this kernel fast?" into a measurable question: compare achieved throughput against the relevant ceiling at the kernel's arithmetic intensity.

Full source, profiling data, and roofline plots are on GitHub. See the project page for the complete kernel inventory and architecture overview.