TBQ4_0: TurboBlockQuant CUDA Flash Attention Implementation Guide¶
TL;DR¶
TBQ4_0 compresses the KV cache in LLM inference from 16 bits per value down to ~4.1 bits, using a fixed rotation matrix + Lloyd-Max codebook. This lets you fit ~4x more context in the same memory. The quality cost is tiny: +0.87% perplexity on Llama-3.1-8B (well within the 2% acceptable threshold).
This guide explains how TBQ4_0 works, how it integrates with llama.cpp's CUDA flash attention, and how to build, run, and modify the code. Everything lives on one branch and compiles to one binary.
Branch: feat/tbq4-cuda-fa-sm121
Paper: TurboQuant (arXiv 2504.19874)
Discussion: llama.cpp #20969
Upstream PR: llama.cpp #21089
Table of Contents¶
- What is TurboQuant?
- The Algorithm
- The Data Format (block_tbq4_0)
- CUDA Flash Attention Integration
- File Map
- How to Build and Test
- Key Concepts Glossary
- Common Questions
- What's Next: TBQ3_0
1. What is TurboQuant?¶
The Problem: KV Cache Eats Your Memory¶
When an LLM generates text, it stores the Key and Value tensors from every previous token in a structure called the KV cache. This is what lets the model "remember" the conversation. The catch: each token adds data to the cache for every layer and every attention head. At 16 bits per value (fp16), a 32-layer model with 32 heads and head_dim=128 needs:
32 layers x 32 heads x 128 dims x 2 bytes x 2 (K+V) = 512 KB per token
At 65K tokens, that is 33 GB just for the KV cache. On a system with 128 GB total (like the DGX Spark), the model weights already consume ~95 GB for a large MoE model, leaving precious little room for context.
In plain English: The KV cache is a running transcript of everything the model has seen. Compress the transcript, fit more conversation.
The Tradeoff: Memory vs Quality¶
Standard approaches like q4_0 (uniform 4-bit quantization) are simple -- divide values into
blocks of 32, find a scale factor, and round each value to one of 16 levels. This works, but the
rounding errors accumulate, especially in the Key cache where errors directly affect which tokens
get attended to.
TurboQuant takes a different approach: instead of brute-force rounding, it first rotates the data into a domain where the values are more uniform (closer to a Gaussian distribution), then applies an optimal codebook designed specifically for Gaussian data. The rotation is a fixed, precomputed orthogonal matrix -- no per-token overhead.
Result: TBQ4_0 achieves +0.87% PPL at 4.125 bits per value, compared to q4_0's +2.05% at 4.5 bits per value. Better quality AND better compression.
How Does TBQ4_0 Compare?¶
| Format | Bits/value | PPL vs f16 | Block size | Technique |
|---|---|---|---|---|
| f16 | 16.0 | baseline | -- | No compression |
| q8_0 | 8.5 | ~0% | 32 | Uniform 8-bit |
| q4_0 | 4.5 | +2.05% | 32 | Uniform 4-bit |
| TBQ4_0 | 4.125 | +0.87% | 128 | Rotation + Lloyd-Max 4-bit |
| TBQ3_0 (upstream CPU) | 3.06 | +0.81% | 256 | Rotation + Lloyd-Max 3-bit |
2. The Algorithm¶
TBQ4_0 quantizes a 128-dimensional vector (one attention head's worth of data for one token) in three steps. Here is the full pipeline:
Step 1: Normalize and Rotate¶
Input: x = [0.23, -0.15, 0.87, ...] (128 floats from one attention head)
┌─────────────┐
x -> │ Compute norm │ -> norm = ||x|| = 2.34
└──────┬──────┘
v
┌─────────────┐
│ Normalize │ -> u = x / norm = [0.098, -0.064, 0.372, ...]
└──────┬──────┘
v
┌─────────────────────────────────────────┐
│ Rotate: r = R * u │
│ │
│ R is a fixed 128x128 orthogonal matrix │
│ (precomputed from seed, same on all │
│ devices, deterministic forever) │
└──────┬──────────────────────────────────┘
v
┌──────────────┐
│ Scale up by │ -> r_scaled = r * sqrt(128)
│ sqrt(128) │ (maps unit-normalized values to N(0,1) range)
└──────┬───────┘
v
r_scaled = [-1.42, 0.89, 2.31, ...] (now roughly standard-normal)
Why rotate? Attention head values are not uniformly distributed -- some dimensions carry much more signal than others. A random orthogonal rotation "spreads" the information evenly across all 128 dimensions, making the distribution closer to Gaussian. This is the key insight from the TurboQuant paper.
In plain English: Imagine you have a photo where all the detail is in one corner. Rotating the image spreads the detail everywhere, so when you compress it, you lose less important information.
The rotation matrix R is generated once from a fixed seed (0x517cc1b727220a95) using Householder
QR decomposition of a Gaussian random matrix. Because the seed is fixed, every device generates the
exact same matrix. It is orthogonal, meaning R^T * R = I (the inverse is just the transpose).
Reference: Hadamard Transform (Wikipedia) -- note: our implementation uses a random orthogonal matrix rather than the Walsh-Hadamard transform, but the principle is the same: spread information uniformly before quantizing.
Step 2: Lloyd-Max Quantization¶
After rotation, the 128 values are approximately standard-normal. We quantize each value to one of 16 levels (4 bits) using an optimal codebook:
r_scaled = [-1.42, 0.89, 2.31, ...]
|
v
┌─────────────────────────────────────────────────────┐
│ For each value, find the nearest centroid: │
│ │
│ Centroids (16 levels): │
│ idx: 0 1 2 3 4 5 6 7 │
│ val: -2.73 -2.07 -1.62 -1.26 -0.94 -0.66 -0.39 -0.13 │
│ │
│ idx: 8 9 10 11 12 13 14 15 │
│ val: 0.13 0.39 0.66 0.94 1.26 1.62 2.07 2.73 │
│ │
│ Decision boundaries (15 midpoints): │
│ -2.40 -1.84 -1.44 -1.10 -0.80 -0.52 -0.26 0.00 │
│ 0.26 0.52 0.80 1.10 1.44 1.84 2.40 │
│ │
│   Example:  -1.42 -> idx=3  (between -1.44 and -1.10) │
│ 0.89 -> idx=11 (between 0.80 and 1.10) │
│ 2.31 -> idx=14 (between 1.84 and 2.40) │
└─────────────────────────────────────────────────────┘
|
v
indices = [2, 11, 14, ...] (128 x 4-bit values)
These 16 centroids are the Lloyd-Max optimal quantizer for the standard normal distribution. "Lloyd-Max optimal" means: if your input data follows N(0,1), these 16 levels minimize the mean squared error of quantization. No other set of 16 levels can do better.
In plain English: Standard quantization divides the number line into equal slices. Lloyd-Max puts more slices where the data is dense (near zero) and fewer where it is sparse (the tails). Like having more ruler markings in the range you actually use.
Quantization is branchless: The GPU code uses 15 comparisons (no branches, no loops):
static __device__ __forceinline__ uint8_t tbq4_quantize_gpu(float val) {
uint8_t idx = 0;
idx += (val >= TBQ4_MIDPOINTS[ 0]); idx += (val >= TBQ4_MIDPOINTS[ 1]);
idx += (val >= TBQ4_MIDPOINTS[ 2]); // ... 15 total comparisons
// Each comparison adds 0 or 1. Result: idx = number of midpoints <= val
return idx;
}
Reference: Lloyd's Algorithm (Wikipedia)
Step 3: Store the Norm¶
The original vector's L2 norm is stored alongside the indices as a single fp16 value. During
dequantization, each centroid is multiplied by norm / sqrt(128) to reconstruct the original
scale.
Stored block (66 bytes total):
qs[64 bytes]: 128 x 4-bit indices, packed 2 per byte
d[2 bytes]: fp16 norm of original vector
Reconstruction:
value[i] = centroid[idx[i]] * (1/sqrt(128)) * norm
The 1/sqrt(128) factor (= 0.08838834764831845) undoes the scale-up from Step 1. In the code,
it is precomputed into the centroid table to save one multiply per element:
constexpr float C[16] = {
    -2.7326f * 0.08838834764831845f,  // = -0.2415
-2.0690f * 0.08838834764831845f, // = -0.1829
// ... pre-multiplied centroids
};
The Full Round-Trip¶
Quantize (SET_ROWS kernel, runs once when K/V is stored):
x -> normalize -> rotate(R) -> scale_up -> quantize_4bit -> pack -> store
Dequantize (inside flash attention, runs every token):
load -> unpack_4bit -> lookup_centroid * norm -> (still in rotated domain)
The rotated-domain trick:
K is stored rotated. At attention time, Q is also rotated.
<R*K, R*Q> = <K, Q> because R is orthogonal (R^T * R = I).
V accumulation happens in rotated domain; output is inverse-rotated once.
In plain English: Instead of un-rotating every K vector (expensive, per-token), we rotate Q once (cheap, one vector) and do all the math in the rotated world. At the end, we un-rotate the output once. Net cost: two 128x128 matrix-vector multiplies per attention call, regardless of how many KV tokens there are.
3. The Data Format (block_tbq4_0)¶
Struct Definition¶
#define QK_TBQ4 128
typedef struct {
uint8_t qs[QK_TBQ4 / 2]; // 64 bytes: 4-bit codebook indices (2 per byte)
ggml_half d; // 2 bytes: block L2 norm
} block_tbq4_0; // 66 bytes total = 4.125 bits per value
Byte Layout Diagram¶
Offset Size Contents
------ ---- --------
0 64B qs[0..63]: packed 4-bit indices
Each byte holds 2 indices:
low nibble = element 2*i
high nibble = element 2*i+1
64 2B d: fp16 norm (L2 norm of original 128-element vector)
------ ----
66B total per 128 elements
Example: qs[5] = 0xB3
Element 10: index = 0x3 = 3 -> centroid = -1.2562
Element 11: index = 0xB = 11 -> centroid = 0.9424
Bit Packing Detail¶
The packing is standard llama.cpp 4-bit nibble packing:
byte index: qs[0] qs[1] qs[2] ... qs[63]
┌────┬────┐┌────┬────┐┌────┬────┐ ┌────┬────┐
elements: │ e0 │ e1 ││ e2 │ e3 ││ e4 │ e5 │ │e126│e127│
└────┴────┘└────┴────┘└────┴────┘ └────┴────┘
bits: 3..0 7..4 3..0 7..4 3..0 7..4 3..0 7..4
Extract element i:
if (i % 2 == 0): idx = qs[i/2] & 0x0F // low nibble
if (i % 2 == 1): idx = (qs[i/2] >> 4) & 0x0F // high nibble
Comparison with q4_0¶
| Property | q4_0 | TBQ4_0 |
|---|---|---|
| Block size | 32 elements | 128 elements |
| Storage per block | 18 bytes (16B data + 2B scale) | 66 bytes (64B data + 2B norm) |
| Bits per value | 4.5 | 4.125 |
| Metadata | fp16 scale (max absolute value) | fp16 norm (L2 magnitude) |
| Codebook | Uniform (scale * [-8..7]) | Lloyd-Max optimal for N(0,1) |
| Rotation | None | 128x128 orthogonal matrix |
| PPL vs f16 | +2.05% | +0.87% |
In plain English: q4_0 is like a ruler with 16 evenly spaced marks. TBQ4_0 is like a ruler with marks concentrated where the measurements cluster, plus a rotation step that makes the measurements cluster predictably.
4. CUDA Flash Attention Integration¶
How llama.cpp's Flash Attention Works¶
llama.cpp implements flash attention as a set of CUDA kernel templates in fattn-vec.cuh. The
"vec" variant handles single-token decode (the common case during text generation). Here is the
simplified flow:
flash_attn_ext_vec<type_K, type_V, D, ...>():
1. Load Q (query) into registers
2. For each KV token (the inner loop):
a. Compute KQ[i] = dot(K[i], Q) // calls vec_dot_KQ()
b. Compute softmax incrementally
c. Accumulate V[i] weighted by attention // calls dequantize_V()
3. Write final output
The kernel is templated on the K and V types. Adding a new quantization format requires:
- A vec_dot_fattn_vec_KQ_<type>() function (K dot product)
- A dequantize_V_<type>() function (V element access)
- Registering these in the get_vec_dot_KQ() and get_dequantize_V() dispatch tables
- A template instance file that instantiates the kernel for the specific K/V type combination
What We Added for TBQ4_0¶
Four things, each in a specific location:
1. Q Rotation (fattn-vec.cuh, before the KV loop)¶
Before the main attention loop begins, we rotate Q into the same domain as the stored K:
if constexpr (type_K == GGML_TYPE_TBQ4_0) {
// Use shared memory (KQ buffer, which is unused at this point)
// to perform a 128x128 matrix-vector multiply:
// Q_rotated = R * Q
//
// Each thread writes its Q elements to shared memory,
// then all threads cooperate on the matrix multiply,
// then each thread reads back its rotated Q elements.
}
This runs once per attention call, not once per KV token. Cost: one 128x128 matvec = 16K multiply-adds, which takes ~0.5 microseconds on SM121. Compare to the KV loop which does 128 multiplies per token times potentially thousands of tokens.
The shared memory trick: The rotation needs all 128 Q elements visible to all threads
simultaneously. The kernel's KQ shared memory buffer (used later for softmax) is repurposed here
since it is not yet in use. No extra shared memory allocation needed.
2. Codebook Lookup in K Dot Product (fattn-common.cuh)¶
template<int D, int nthreads>
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_tbq4_0(...) {
// For each pair of elements:
// 1. Load one byte from K block (contains two 4-bit indices)
// 2. Look up centroids: cn[idx0], cn[idx1]
// 3. Multiply by rotated Q: sum += cn[idx0]*Q.x + cn[idx1]*Q.y
//
// cn[16] = centroid * norm, cached per-block (norm changes once per 128 elements)
}
The centroid table cn[16] is recomputed only when the block changes (every 128 elements). Within
a block, it is a simple table lookup + multiply-add. This is the hot path -- it runs for every KV
token in the sequence.
3. V Dequantization (fattn-common.cuh)¶
template <typename T, int ne>
static __device__ __forceinline__ void dequantize_V_tbq4_0(...) {
// For each V element needed:
// 1. Load byte, extract nibble
// 2. value = centroid[idx] * norm
// Output stays in rotated domain (inverse rotation happens later)
}
V dequantization is simpler than K because there is no dot product -- we just need the raw (rotated) values to accumulate the weighted sum.
4. Output Inverse Rotation (fattn-vec.cuh, after the KV loop)¶
if constexpr (type_V == GGML_TYPE_TBQ4_0) {
// The accumulated output is in rotated domain.
// Apply R^T (transpose = inverse for orthogonal R):
// output_final = R^T * output_rotated
//
// Same shared memory trick as Q rotation.
}
This also runs once per attention call. The mathematical justification:
Standard attention: O = sum_j (a_j * V_j)
With rotation: O = sum_j (a_j * R * V_orig_j) // V stored as R*V
= R * sum_j (a_j * V_orig_j) // R is linear
= R * O_original
So: O_original = R^T * O // R^T = R^-1
Why This is O(1) Per-Token Overhead¶
The rotation cost is:
- Q rotation: 128 x 128 = 16,384 multiply-adds (once)
- Output inverse rotation: 16,384 multiply-adds (once)
- Total: 32,768 multiply-adds, independent of sequence length

The per-token cost in the KV loop is:
- K dot product: 128 multiply-adds + 64 byte loads + 16 table lookups (same as q4_0)
- V dequant: 128 multiplies + 64 byte loads (same as q4_0)
In plain English: We pay a small fixed tax (two matrix multiplies) to avoid paying a per-token tax. At 1000 tokens of context, the fixed cost is < 1% of total work. At 10,000 tokens it is < 0.1%. The longer the context, the more "free" the rotation becomes.
5. File Map¶
All files are on the branch feat/tbq4-cuda-fa-sm121. Paths are relative to the repo root.
Core Type Definition¶
| File | What it does |
|---|---|
| ggml/include/ggml.h | Defines GGML_TYPE_TBQ4_0 = 43 in the type enum |
| ggml/src/ggml-common.h | Defines the block_tbq4_0 struct, QK_TBQ4 = 128 |
| ggml/src/ggml.c | Registers type traits (name, block size, quantize/dequant functions) |
| ggml/src/ggml-quants.h | Declares quantize_row_tbq4_0_ref, dequantize_row_tbq4_0, quantize_tbq4_0 |
| common/arg.cpp | Adds "tbq4_0" to the --cache-type-k / --cache-type-v CLI options |
CPU Quantization (Reference Implementation)¶
| File | What it does |
|---|---|
| ggml/src/ggml-turbo-quant.c | CPU quantize + dequant for TBQ4_0 (and turbo3/turbo4). Contains the rotation matrix, codebook, and scalar quantizer. This is the "ground truth" implementation. |
| ggml/src/tbq-rotation-128.h | 128x128 rotation matrix as a C array (shared between CPU and CUDA) |
CUDA Kernels¶
| File | What it does |
|---|---|
| ggml/src/ggml-cuda/tbq-quant.cu | SET_ROWS kernel: GPU quantization (normalize, rotate, quantize, pack). Also: dequant kernels for convert ops (contiguous + non-contiguous). Loads the rotation matrix to device memory on first use. |
| ggml/src/ggml-cuda/tbq-rotation-128.h | Same 128x128 rotation matrix, with __device__ qualifier for CUDA. Included by tbq-quant.cu (device global memory) and fattn-common.cuh (compile-time constant for FA). |
| ggml/src/ggml-cuda/fattn-common.cuh | Flash attention building blocks: vec_dot_fattn_vec_KQ_tbq4_0 (K dot product) and dequantize_V_tbq4_0 (V element access). Also registers TBQ4_0 in the type dispatch tables. |
| ggml/src/ggml-cuda/fattn-vec.cuh | Flash attention vec kernel: Q pre-rotation (before the KV loop) and output inverse-rotation (after the KV loop). Also: thread count configuration for TBQ4_0, Q_q8_1 exclusion. |
| ggml/src/ggml-cuda/fattn.cu | Flash attention dispatch: routes TBQ4_0 to the vec kernel, adds FATTN_VEC_CASE(128, TBQ4_0, TBQ4_0). Also contains the (bypassed) shadow cache infrastructure. |
Dispatch and Wiring¶
| File | What it does |
|---|---|
| ggml/src/ggml-cuda/ggml-cuda.cu | Adds TBQ4_0 to supports_op for FLASH_ATTN_EXT, GET_ROWS, and SET_ROWS |
| ggml/src/ggml-cuda/convert.cu | Adds TBQ4_0 to all dequant dispatch tables (fp16, fp32, bf16, contiguous, non-contiguous) |
| ggml/src/ggml-cuda/set-rows.cu | Adds TBQ4_0 dispatch: calls ggml_cuda_op_set_rows_tbq4 |
| ggml/src/ggml-cuda/turbo-quant.cuh | Declares ggml_cuda_op_set_rows_tbq4 and tbq_ensure_rotation_loaded |
| ggml/src/ggml-cuda/CMakeLists.txt | Adds fattn-vec-instance-tbq4_0-tbq4_0.cu to the build |
Template Instances¶
| File | What it does |
|---|---|
| ggml/src/ggml-cuda/template-instances/fattn-vec-instance-tbq4_0-tbq4_0.cu | Instantiates flash_attn_ext_vec<128, TBQ4_0, TBQ4_0>. Only D=128 is supported. |
KV Cache Setup¶
| File | What it does |
|---|---|
| src/llama-kv-cache.cpp | Tensor overhead calculation includes +2 for the turbo rotation tensors |
6. How to Build and Test¶
Prerequisites¶
- NVIDIA GPU with SM121 (DGX Spark GB10) or compatible architecture
- CUDA Toolkit 13.0+ (13.2 recommended)
- CMake 3.21+, Ninja
- A GGUF model file (Llama-3.1-8B-Instruct Q4_K_M recommended for testing)
Build Commands¶
cd ~/workspace/llama-cpp-turboquant
# Clean build for SM121 with flash attention quant support
cmake -B build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_CUDA=ON \
-DCMAKE_CUDA_ARCHITECTURES=121 \
-DGGML_CUDA_FA_ALL_QUANTS=ON
cmake --build build -j$(nproc)
Key CMake flags:
- CMAKE_CUDA_ARCHITECTURES=121: Target SM121 (DGX Spark). Change to 89 for RTX 4090, 90 for H100, etc.
- GGML_CUDA_FA_ALL_QUANTS=ON: Required. Enables flash attention template instances for non-standard quant types (turbo3, turbo4, TBQ4_0). Without this, only f16/q4_0/q8_0 FA kernels are compiled.
Run Perplexity (Quality Test)¶
./build/bin/llama-perplexity \
-m ~/workspace/models/Llama-3.1-8B-Instruct-Q4_K_M.gguf \
-ctk tbq4_0 -ctv tbq4_0 \
-c 2048 --chunks 4 \
-ngl 99
Expected output (Llama-3.1-8B):
Final estimate: PPL = 7.6852 +/- 0.0xxx
Compared to the f16 baseline (PPL = 7.6186), that is +0.87% -- well within the 2% threshold.
To get the f16 baseline:
./build/bin/llama-perplexity \
-m ~/workspace/models/Llama-3.1-8B-Instruct-Q4_K_M.gguf \
-ctk f16 -ctv f16 \
-c 2048 --chunks 4 \
-ngl 99
Run llama-bench (Speed Test)¶
./build/bin/llama-bench \
-m ~/workspace/models/Llama-3.1-8B-Instruct-Q4_K_M.gguf \
-ctk tbq4_0 -ctv tbq4_0 \
-ngl 99 \
-p 512 -n 128
This measures:
- pp512: Prefill throughput (tokens/sec for 512-token prompt processing)
- tg128: Text generation throughput (tokens/sec for 128-token generation)
Expected Results (DGX Spark, Llama-3.1-8B Q4_K_M)¶
| Config (K/V) | PPL | vs f16 | pp512 tok/s | tg128 tok/s |
|---|---|---|---|---|
| f16 / f16 | 7.6186 | -- | 3059 | 43.65 |
| tbq4_0 / tbq4_0 | 7.6852 | +0.87% | ~1400 | ~35 |
| q4_0 / q4_0 | 7.7748 | +2.05% | ~3000 | ~42 |
| q8_0 / turbo3 | 7.6459 | +0.36% | 1577 | 38.25 |
Note: TBQ4_0 prefill is slower than q4_0 because the rotation adds work during the SET_ROWS (quantization) phase. Decode throughput is comparable. The payoff is quality: TBQ4_0 is significantly better than q4_0 at similar or better compression.
Other Useful Configs¶
# Asymmetric: q8_0 keys (better quality) + tbq4_0 values (compressed)
# Not yet implemented as a template instance -- would need:
# fattn-vec-instance-q8_0-tbq4_0.cu
# But you can test turbo3 values with q8_0 keys (already implemented):
./build/bin/llama-perplexity \
-m ~/workspace/models/Llama-3.1-8B-Instruct-Q4_K_M.gguf \
-ctk q8_0 -ctv turbo3 \
-c 2048 --chunks 4 -ngl 99
# Expected: PPL 7.6459 (+0.36%)
7. Key Concepts Glossary¶
Flash Attention¶
An algorithm for computing attention without materializing the full N x N attention matrix. Instead of storing all attention scores in memory (O(N^2) space), it processes KV tokens in tiles and accumulates results incrementally. llama.cpp's implementation uses CUDA kernel templates with different specializations per quantization type. Reference: Flash Attention paper (arXiv 2205.14135)
KV Cache¶
The stored Key and Value tensors from all previous tokens in a generation sequence. Each new token
adds one K vector and one V vector per layer per attention head. The KV cache is the main memory
bottleneck for long-context inference.
In llama.cpp: Managed by llama-kv-cache.cpp. Type controlled by --cache-type-k and
--cache-type-v CLI flags.
Walsh-Hadamard Transform (WHT)¶
A fast orthogonal transform that can be computed in O(n log n) using a butterfly pattern (similar to FFT). The earlier turbo3/turbo4 implementations use WHT for rotation. TBQ4_0 uses a full random orthogonal matrix instead (O(n^2) but only n=128 so it does not matter). The WHT is relevant to the turbo3 code that coexists on this branch. Reference: Hadamard Transform (Wikipedia)
Lloyd-Max Quantization¶
An optimal scalar quantizer for a known distribution. Given that you have N bits (2^N levels) and the input follows distribution P, Lloyd-Max finds the N levels and decision boundaries that minimize mean squared error. For TBQ4_0, the input is standard-normal (after rotation), and the 16 optimal levels are precomputed. Reference: Lloyd's Algorithm (Wikipedia)
SM121¶
NVIDIA's compute capability for the GB10 Grace Blackwell chip in the DGX Spark. SM121 has ~99 KB of shared memory per block (not 228 KB -- that is SM100/B200). It supports all standard CUDA features needed for flash attention. Reference: DGX Spark specs
bpw (Bits Per Weight / Bits Per Value)¶
A measure of compression. fp16 = 16 bpw. q4_0 = 4.5 bpw (4 bits data + amortized scale). TBQ4_0 = 4.125 bpw (4 bits data + amortized norm over 128 elements). Lower bpw = more compression = more context fits in memory.
Shared Memory¶
Fast on-chip SRAM in CUDA GPUs, shared among all threads in a block. Much faster than global
memory (~100x lower latency). Used in TBQ4_0 for the Q rotation and output inverse rotation, by
repurposing the KQ buffer that the flash attention kernel already allocates.
Register Pressure¶
Each CUDA thread has a limited number of registers (~255 on most architectures). Storing a full 128-element float array requires 128 x 4 = 512 bytes = 128 registers, which is more than available. The compiler "spills" excess to local memory (slower). The SET_ROWS kernel has high register pressure because it holds the full vector during rotation. The FA kernel avoids this by using shared memory for the rotation step.
Orthogonal Matrix¶
A square matrix R where R^T * R = I (identity). This means R preserves lengths and
angles: ||Rx|| = ||x|| and <Rx, Ry> = <x, y> for all vectors x and y. The inverse of
an orthogonal matrix is just its transpose, which is what makes the rotated-domain
trick cheap to undo.
Template Instance (llama.cpp specific)¶
llama.cpp compiles flash attention kernels as C++ templates parameterized by K type, V type, and
head dimension D. Each combination needs an explicit instantiation in a .cu file. TBQ4_0 has one:
fattn-vec-instance-tbq4_0-tbq4_0.cu for D=128 with both K and V as TBQ4_0.
8. Common Questions¶
Q: Why not just use q4_0? It is simpler.¶
q4_0 has +2.05% PPL degradation, which is right at the community's "acceptable" threshold. TBQ4_0 achieves +0.87% at slightly better compression (4.125 vs 4.5 bpw). For production use where quality matters, TBQ4_0 is strictly superior. The complexity cost is real but contained -- all the rotation logic is in ~200 lines of CUDA code.
Q: Why is prefill slower with TBQ4_0?¶
Prefill processes many tokens at once, and each token's KV vector must be quantized (rotated + codebook-encoded) via the SET_ROWS kernel. This 128x128 matrix multiply per vector adds latency during prefill. During decode (the steady state), only one new token is quantized, so the overhead is negligible. The decode speed is limited by the KV loop, which is comparable to q4_0.
Q: Can I use TBQ4_0 for K and q4_0 for V (or other combinations)?¶
Not yet with this branch. Each K/V type combination needs a template instance. Currently only
tbq4_0/tbq4_0 is instantiated. Adding tbq4_0/q4_0 or q8_0/tbq4_0 requires adding a new
.cu file with the template instantiation and a dispatch entry in fattn.cu. The kernel code
itself is generic -- it is just a compilation issue.
Q: What happens if head_dim is not 128?¶
TBQ4_0 currently requires head_dim = 128 (the rotation matrix is 128x128). This covers most
popular models (Llama, Mistral, MiniMax M2.5, etc.). Models with head_dim=64 (some small models)
or head_dim=256 (DeepSeek) would need different rotation matrices. The algorithm generalizes; the
code does not yet.
Q: Why store the rotation matrix in the code instead of in the model file?¶
The rotation matrix is deterministic (generated from a fixed seed). Storing it in the code means:
- No model file format changes needed
- Every device generates the same matrix
- No risk of mismatched rotation matrices between quantizer and dequantizer
The tradeoff is ~64 KB of compiled-in data (128 x 128 x 4 bytes). This is negligible for a GPU binary.
Q: Is TBQ4_0 the same as the upstream PR #21089?¶
Similar but not identical. The upstream PR uses 256-element blocks (QK_K=256) and 3.06 bpw (3-bit codebook). Our TBQ4_0 uses 128-element blocks (matching head_dim=128 for simpler rotation handling) and a 4-bit codebook (4.125 bpw). The upstream format is CPU-only; ours has full CUDA flash attention support. The rotation matrix generation uses a different seed than upstream's WHT-based approach, but the mathematical principle is the same.
Q: What is the difference between TBQ4_0 and turbo3/turbo4?¶
turbo3 and turbo4 are earlier implementations on this same branch. They use a Walsh-Hadamard Transform (fast, O(n log n)) with smaller block sizes and a 3-bit codebook. turbo3 works correctly for V but has quality issues for K (+5.12%). turbo4 adds a QJL (Quantized Johnson-Lindenstrauss) residual but has a broken native vec kernel.
TBQ4_0 is the clean redesign: larger blocks (128), 4-bit codebook, full random rotation, and a simpler/more robust implementation. It matches the upstream paper's approach more faithfully.
Q: Why is there both a CPU and GPU implementation?¶
The CPU implementation in ggml-turbo-quant.c serves as the reference/fallback. It is used when:
- KV cache is not offloaded to GPU (--no-kv-offload)
- Running on CPU-only systems
- Validating correctness (CPU results must match GPU results)
The CUDA kernels in tbq-quant.cu and fattn-common.cuh are the production path for GPU
inference.
9. What's Next: TBQ3_0¶
The current TBQ4_0 implementation validates the architecture. The next step is TBQ3_0 -- a 3-bit variant that would achieve ~3 bpw, enabling significantly more context.
Projected Impact¶
| Config | bpw | Est. Context (M2.5, 128GB) | PPL vs f16 |
|---|---|---|---|
| q4_0/q4_0 (current) | 4.50 | ~65K | +2.05% |
| tbq4_0/tbq4_0 | 4.125 | ~72K | +0.87% |
| tbq4_0/tbq3_0 | ~3.5 | ~85K | ~+2-3% (estimated) |
| tbq3_0/tbq3_0 | ~3.0 | ~96K | ~+7% (needs optimization) |
What Needs to Happen¶
- 3-bit packing: Pack 128 x 3-bit indices into 48 bytes. The bit packing is trickier than 4-bit (3 does not divide 8 evenly). The upstream PR (#21089) handles this with a specific layout that packs indices in groups.
- 3-bit codebook: 8 Lloyd-Max optimal centroids for N(0,1). Already defined in the codebase as CENTROIDS_3BIT.
- FA kernel template: New vec_dot_fattn_vec_KQ_tbq3_0 and dequantize_V_tbq3_0 functions, plus template instances.
- Quality validation: The upstream CPU-only TBQ3_0 achieves +0.81% PPL for q8_0/tbq3_0 and +7.01% for tbq3_0/tbq3_0. The sweet spot may be tbq4_0 K + tbq3_0 V (compressed V is more forgiving than compressed K).
Roadmap¶
- Phase 1 (done): TBQ4_0 CUDA FA -- validates rotation + codebook in GPU flash attention
- Phase 2 (next): Asymmetric configs -- q8_0/tbq4_0, tbq4_0/q8_0 template instances
- Phase 3: TBQ3_0 format + CUDA FA kernels
- Phase 4: Test on MiniMax M2.5 at 85K+ context
- Phase 5: MMA kernel support (currently vec-only; MMA would improve prefill throughput)
Appendix: Mathematical Proof of Rotated-Domain Correctness¶
For readers who want to verify the math:
K dot product (attention scores):
Standard: score = <K, Q>
Stored K: K_stored = R * (K_normalized) * norm_K
Rotated Q: Q_rot = R * Q
score = <K_stored, Q_rot>
= sum_i (R * K_norm * norm_K)_i * (R * Q)_i
= norm_K * (R * K_norm)^T * (R * Q)
= norm_K * K_norm^T * R^T * R * Q // matrix transpose
= norm_K * K_norm^T * I * Q // R^T * R = I (orthogonal)
= norm_K * <K_norm, Q>
= <K, Q> // K = norm_K * K_norm
V accumulation (attention output):
Standard: O = sum_j a_j * V_j
Stored V: V_stored_j = R * V_norm_j * norm_j
Rotated O: O_rot = sum_j a_j * V_stored_j
= sum_j a_j * R * V_norm_j * norm_j
= R * sum_j a_j * V_j // R is linear, V_j = norm_j * V_norm_j
= R * O
Therefore: O = R^T * O_rot // multiply both sides by R^T
This is why the output inverse rotation (R^T * O_rot) recovers the correct attention output.
Document generated 2026-03-27. Based on code at commit 3cf3a716e on branch feat/tbq4-cuda-fa-sm121.