Sovereign AI on a Desktop, Part 2: Five Bugs in NVIDIA's Code

Mihai Chiorean | March 2026

Series: Sovereign AI on a Desktop

Part 1: The Stack -- What I'm running and why
Part 2: Five Bugs in NVIDIA's Code -- Fixing TensorRT-LLM for DGX Spark (you are here)
Part 3: The Autoresearcher -- An AI agent that optimizes its own inference
Part 4: 100K Context -- KV cache compression via TurboQuant
Part 5: The Bandwidth Wall -- What actually limits a $3,000 desktop


The DGX Spark shipped with a promise of local AI. The hardware can load a 229B model. But when I tried to run MiniMax M2.5 through NVIDIA's own TensorRT-LLM at native FP4 tensor core speed, it did not work.

Five bugs. Five layers of the stack. Each one small -- the largest fix is ~100 lines. But together they completely blocked MoE inference on consumer Blackwell hardware. Nobody at NVIDIA had run the full pipeline on a Spark or RTX 5090 before I did.

And after fixing all five: the tensor core path is still a dead end for this model. That is the punchline. I will explain why at the end.

The Motivation: 427 TFLOPS of FP4

The DGX Spark's GB10 GPU has 192 tensor cores capable of 427 TFLOPS of FP4 compute. TensorRT-LLM's NVFP4 format is designed specifically for Blackwell's hardware-accelerated 4-bit path -- no software dequantization, just native tensor core operations on 4-bit weights.

llama.cpp at Q3_K_XL gives me 24 tok/s with software dequantization on CUDA cores. The tensor cores sit mostly idle during decode. If NVFP4 through TRT-LLM could use those tensor cores, the performance ceiling goes up. That was the theory.

Bug 1: The Shared Memory Overflow Nobody Saw

PR #12141 | Merged

When I tried to run TensorRT-LLM's FP4 GEMM kernels on the DGX Spark, I got a cryptic kErrorInternal error. No useful message. Just a silent failure deep in the CUTLASS kernel dispatch.

After digging through NVIDIA's CUTLASS code, I found the problem: the FP4 GEMM tile configurations for SM120 (the Blackwell architecture family that includes RTX 5090 and DGX Spark) were copied from SM100 (the datacenter B200/B100) without accounting for a critical hardware difference.

  • SM100 (B200): ~227 KiB shared memory per block
  • SM12x (RTX 5090, DGX Spark): ~99 KiB shared memory per block

The tile sizes 128x128x256B and 256x128x128B were designed for SM100's generous shared memory budget. On SM12x, CUTLASS's StageCountAutoCarveout would compute pipeline stage sizes at compile time using the architecture's SharedMemoryCapacity, and the resulting SharedStorage struct simply did not fit. This is a compile-time structural failure -- you cannot fix it at runtime.
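As a rough sanity check on those numbers, here is a back-of-envelope sketch of the mainloop shared-memory footprint per pipeline stage for each tile shape. This is a toy model: it counts only the FP4 A and B operand tiles, while the real CUTLASS SharedStorage also holds block scale factors, the epilogue tile, and pipeline barriers, and the kernel schedule requires a minimum stage count -- which is what pushes the large tiles past the SM12x budget at compile time.

```python
KIB = 1024

def mainloop_bytes_per_stage(m, n, k, bits=4):
    """Approximate shared memory for one pipeline stage's A and B tiles
    (FP4 operands only; ignores scale factors, epilogue, and barriers)."""
    return (m * k + n * k) * bits // 8

SM100_SMEM = 227 * KIB  # approximate B200 per-block budget
SM12X_SMEM = 99 * KIB   # approximate RTX 5090 / DGX Spark per-block budget

for tile in [(128, 128, 256), (256, 128, 128), (128, 128, 128)]:
    per_stage = mainloop_bytes_per_stage(*tile)
    print(f"{tile}: {per_stage // KIB} KiB/stage -> "
          f"{SM100_SMEM // per_stage} stages on SM100, "
          f"{SM12X_SMEM // per_stage} on SM12x")
```

Even in this stripped-down model, the large tiles leave room for only a handful of stages on SM12x; once the omitted SharedStorage members and the schedule's required stage count are added back, they no longer fit at all, while 128x128x128B still does.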

The fix: add a CtaShape128x128x128B tile configuration that fits within SM12x's 99 KiB limit. This tile was already compiled in the codebase (for bf16, fp16, and fp32 output types) but was never wired into the SM120 dispatch path. Three changes:

  1. Added the 128x128x128B dispatch case to dispatchNVFP4xNVFP4GemmCTAShapeSm120
  2. Made it the default tile config for SM120+ (the autotuner still profiles all candidates)
  3. Left the SM100 path completely untouched -- datacenter GPUs use their own tile configs
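The shape of the change can be sketched in Python (the real code is C++ in dispatchNVFP4xNVFP4GemmCTAShapeSm120; the function name below and the choice of SM100 default tile are illustrative, not the actual dispatch table):

```python
def default_tile(sm_version: int) -> str:
    """Pick the default CTA tile per architecture (illustrative sketch)."""
    if sm_version >= 120:
        # Fits SM12x's ~99 KiB budget; the autotuner still profiles
        # the other candidates at runtime.
        return "CtaShape128x128x128B"
    # SM100/SM103 path untouched -- shown here with one of its
    # large-SMEM tiles as a stand-in default.
    return "CtaShape128x128x256B"

assert default_tile(121) == "CtaShape128x128x128B"
assert default_tile(100) == "CtaShape128x128x256B"
```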

An NVIDIA engineer (@eugr) caught a comment error in my first submission -- I had incorrectly labeled which arch had which SMEM size. Fixed in the same PR. NVIDIA's CI ran the full test suite twice; it passed and was merged on March 18.

This unblocks FP4 GEMM for every SM12x device: RTX 5090, RTX 5080, and DGX Spark.

Bug 2: The Unnecessary Memcpy

PR #12301 | Open

TensorRT-LLM's KV cache manager has a tiered memory system: a primary pool (GPU memory) and a secondary pool (CPU memory). When context gets long, it "offloads" KV cache blocks from GPU to CPU, and "onboards" them back when needed. On a datacenter GPU with separate VRAM and system RAM, this makes sense -- you are moving data between fast and slow memory.

On the DGX Spark, it is absurd. CPU and GPU share the same 128GB LPDDR5x pool. There is no "GPU memory" vs "CPU memory" -- it is all the same physical memory accessible via the same memory controller at the same bandwidth. The offload/onboard memcpy operations are copying data from a physical address to... the same physical address. Zero benefit, wasted bandwidth, added latency.

This matters because bandwidth is THE bottleneck on the Spark. At 273 GB/s, every unnecessary memory operation directly competes with the actual inference computation. On a bandwidth-limited device, wasting cycles on no-op copies is the difference between good and bad performance.

My fix adds runtime detection of unified memory using cudaDevAttrPageableMemoryAccess and three optimizations:

  1. Skip offload/onboard memcpy when unified memory is detected. Block metadata bookkeeping still happens -- just no pointless copy.
  2. Fold secondary cache into primary pool. Instead of two artificial tiers backed by the same memory, give the block manager one large pool. Eliminates eviction/promotion overhead entirely.
  3. Allocate secondary blocks as GPU memory. On unified memory, use BufferManager::managed instead of BufferManager::pinned, avoiding page-locking overhead.

All three optimizations are gated by a single runtime check. On discrete GPU systems, cudaDevAttrPageableMemoryAccess returns false and the code path is never entered. Zero risk of regression.
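The gating logic can be sketched like this (a simplified Python model of behavior that actually lives in TRT-LLM's C++ KV cache manager; the class and field names are mine): bookkeeping always runs, but the copy is skipped when the device reports unified memory.

```python
from dataclasses import dataclass, field

@dataclass
class KVCacheManager:
    # In the real code this flag comes from querying
    # cudaDevAttrPageableMemoryAccess at startup.
    unified_memory: bool
    offloaded: set = field(default_factory=set)
    copies_issued: int = 0

    def offload(self, block_id: int) -> None:
        self.offloaded.add(block_id)   # metadata bookkeeping always happens
        if not self.unified_memory:
            self.copies_issued += 1    # real GPU->CPU memcpy only on discrete GPUs

spark = KVCacheManager(unified_memory=True)   # DGX Spark: one LPDDR5x pool
dgpu = KVCacheManager(unified_memory=False)   # discrete GPU: real memory tiers
for blk in range(4):
    spark.offload(blk)
    dgpu.offload(blk)

print(spark.copies_issued, dgpu.copies_issued)  # 0 4
```

Both managers track the same offloaded blocks; only the discrete-GPU one pays for the copies.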

Bug 3: The Python Gate That Blocked All MoE Models

PR #12309 | Open

The CUTLASS FP4 kernels worked (thanks to Bug 1 being fixed). But the Python dispatch layer above them still threw NotImplementedError for SM120/SM121. The MoE routing code in fused_moe_trtllm_gen.py had an explicit gate: only SM100 and SM103 (datacenter Blackwell) were allowed. The consumer/desktop Blackwell chips were simply rejected before the code ever reached the working kernels.

This was not a hardware limitation -- it was an allowlist that nobody had updated. The fix extends can_implement() across four files to include SM120/SM121, and routes them to the TRTLLM MoE backend instead of letting them fall through to an incompatible CUTLASS path.
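In miniature, the fix is just an allowlist extension (the real change spans four files and the details of can_implement() are more involved; this is a hedged sketch with simplified names):

```python
# Before the fix: only datacenter Blackwell passed the gate.
SUPPORTED_SM = {100, 103}
# The fix: extend the allowlist to consumer/desktop Blackwell.
SUPPORTED_SM |= {120, 121}

def can_implement(sm_version: int) -> bool:
    """Gate for the TRTLLM MoE backend (simplified)."""
    return sm_version in SUPPORTED_SM

assert can_implement(121)      # DGX Spark now reaches the working kernels
assert not can_implement(90)   # other architectures take other backends
```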

This unblocks MiniMax M2.5, DeepSeek, Qwen, and every other MoE architecture on consumer Blackwell.

Bug 4: The Autotuner Crash

PR #12310 | Open

With the MoE gate fixed, TRT-LLM got further -- into the autotuner warmup phase, where it immediately crashed with IndexError: list assignment index out of range. The autotuner's _find_nearest_profile() assumes that every operation produces tensors with the same shape dimensions as the profiling spec expects. On SM121, some ops produce fewer or differently-shaped tensors, and the indexing goes out of bounds.

This is the kind of bug that only surfaces on hardware the test suite does not cover. The fix adds bounds checks before indexing into the base profile -- out-of-bounds specs are skipped with a debug log, in-bounds behavior is unchanged. Eight test cases covering normal, out-of-range, and mixed scenarios.
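The shape of the fix, reduced to a sketch (a hypothetical simplification of _find_nearest_profile's indexing; the real patch touches more surrounding logic):

```python
import logging

def apply_profile_specs(base_profile: list, specs: list) -> list:
    """Write each (index, value) spec into base_profile, skipping any
    index the op's actual tensors did not provide a dimension for."""
    for idx, value in specs:
        if idx >= len(base_profile):
            # The missing bounds check: on SM121 some ops produce fewer
            # tensor dims than the profiling spec expects.
            logging.debug("skipping out-of-range spec index %d", idx)
            continue
        base_profile[idx] = value
    return base_profile

print(apply_profile_specs([0, 0], [(0, 7), (3, 9)]))  # [7, 0]
```

Without the `idx >= len(...)` guard, the second spec above is exactly the `IndexError: list assignment index out of range` crash.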

Bug 5: The Wrong Attention Kernels

PR #12311 | Open

Past the autotuner, the next crash: the attention backend. TRT-LLM's trtllm-gen FMHA (Fused Multi-Head Attention) runner only has compiled cubins for SM100/SM103. SM120 uses different ISA instructions (mma.sync.aligned.block_scale vs tcgen05.mma), so SM100 cubins cannot simply be reused -- they are incompatible at the instruction set level.

The fix routes SM120/SM121 to the FMHA v2 fallback, which does work on SM12x. It is not the fastest path (trtllm-gen FMHA is optimized for SM100), but it runs correctly. SM100/SM103 keep using trtllm-gen FMHA with zero regression. The patch also adds clear warning messages instead of cryptic assertion failures, so the next person hitting this on a 5090 or Spark does not spend hours in a debugger.
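The routing decision, as a sketch (illustrative names; the real selection happens inside TRT-LLM's FMHA runner, and the real patch emits a warning rather than raising immediately):

```python
def select_fmha_backend(sm_version: int) -> str:
    """Pick an attention backend by architecture (simplified)."""
    if sm_version in (100, 103):
        return "trtllm-gen"   # cubins compiled for datacenter Blackwell ISA
    if sm_version in (120, 121):
        return "fmha_v2"      # slower fallback that actually runs on SM12x
    raise NotImplementedError(f"no FMHA backend wired up for SM{sm_version}")

assert select_fmha_backend(100) == "trtllm-gen"
assert select_fmha_backend(121) == "fmha_v2"
```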

The Pattern

Five PRs. Same root cause every time: TensorRT-LLM was built and tested on datacenter Blackwell (SM100/SM103). Consumer and desktop Blackwell (SM120/SM121) shares the architecture name but has different shared memory sizes, different ISA instructions, and different resource limits.

PR Layer Problem Status
#12141 CUTLASS GEMM Tile configs exceed SM12x shared memory Merged
#12301 KV cache No-op memcpy wastes bandwidth on unified memory Open
#12309 MoE dispatch Python allowlist blocks SM120/SM121 Open
#12310 Autotuner Index out of bounds on SM121 tensor shapes Open
#12311 Attention SM100 cubins incompatible with SM120 ISA Open

Each fix is small. But they span the entire stack from CUDA kernels to Python dispatch to the serving runtime. You cannot just fix one layer. You have to understand the full pipeline and where SM12x diverges from SM100 at each level.

The Convergence: NVIDIA Is Building the Same Week

While I was fixing SM121 blockers from the bottom up, NVIDIA engineers were building from the top down -- during GTC week. Three related PRs landed the same week as my five:

  • #12302: Core Qwen 3.5 model support (dense + MoE architectures)
  • #12265: Qwen 3.5 NVFP4 MoE performance (lm_head sharding, TP8 for the 400B variant)
  • #11997 (from @scottgl9): Ungate FusedMoE for SM120/SM121

The DGX Spark launched with a promise of local AI, but the software stack was not ready. Now the community and NVIDIA are converging on the same target from different directions.

The Dead End

Here is the part nobody wants to write: after fixing all five bugs, after getting MoE inference to run on SM121, after unblocking the tensor core FP4 path... the model does not fit.

MiniMax M2.5 at NVFP4 quantization is approximately 228GB. The DGX Spark has 128GB. The model is 100GB too large for the hardware.

NVFP4 is a 4-bit format. Q3_K_XL (what llama.cpp uses) is roughly a 3-bit format. That extra precision costs 133GB of additional memory -- the difference between "fits with room for KV cache" and "does not fit at all." Tensor cores cannot help below 4 bits: E2M1 FP4 is the lowest precision Blackwell's hardware datapath accepts.

This is not a bug to fix. It is a fundamental constraint: Blackwell's tensor cores speak FP4 at the lowest precision, and 4-bit M2.5 is too large for 128GB. The only inference path that fits a 229B model in 128GB is 3-bit or lower quantization, which means GGUF on llama.cpp, which means CUDA cores for decode, which means the tensor cores sit idle during generation.
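The arithmetic in one place, using only the sizes quoted in this post:

```python
SPARK_RAM = 128   # GB, unified LPDDR5x pool
NVFP4_GB = 228    # MiniMax M2.5 at NVFP4 (4-bit)
Q3_K_XL_GB = 95   # same model at Q3_K_XL (~3-bit GGUF)

assert NVFP4_GB - SPARK_RAM == 100    # how far over budget the FP4 path is
assert NVFP4_GB - Q3_K_XL_GB == 133   # what the extra precision costs
assert Q3_K_XL_GB <= SPARK_RAM        # only the 3-bit path fits at all
```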

The five bugs were real and worth fixing -- they unblock every SM12x user who wants to run TRT-LLM, and they unblock smaller MoE models that fit in 128GB at NVFP4. For the RTX 5090 (32GB VRAM + system RAM), the shared memory fix alone enables FP4 GEMM that was previously broken. For MiniMax M2.5 on the Spark specifically, the tensor core path is a dead end.

I spent two weeks on this dead end. I do not regret it. The fixes help the ecosystem, and the exercise taught me exactly where the hardware limits are. But the lesson is blunt: on a 128GB machine, memory is the constraint that dominates everything else. Tensor core throughput does not matter if the model does not fit.

What Actually Works

llama.cpp with Q3_K_XL at 95GB. CUDA cores for decode. INT8 tensor cores for prefill (already in use today, via the MMQ kernels). 24 tok/s. It is not glamorous. It is what fits.

With that established, the question became: given the llama.cpp path, how do I get the most out of it? That led to the autoresearcher -- an AI agent that optimizes its own inference configuration. Part 3 tells that story.


Next: Part 3: The Autoresearcher -- I spent two days manually tuning parameters. Then I automated myself out of the loop.


Mihai Chiorean is a software engineer in San Francisco. Previously CTO at Wendy Labs (edge OS on Yocto/Jetson), EM at Cash App (compliance rules engine, $100B+ txn volume), and engineer at Uber, Block/TBD, and InVision. He builds sovereign AI systems on NVIDIA hardware and contributes to TensorRT-LLM and NemoClaw.