vLLM’s hash chain: automatic prefix detection via content hashing
vLLM’s Automatic Prefix Caching (APC) uses a content-hashing technique that eliminates the need for explicit prefix tracking. The KV cache is divided into fixed-size blocks (typically 16 tokens on CUDA). Each block is hashed with SHA-256, and the hash of block N is computed from the hash of block N-1 plus block N’s own token content, so it transitively depends on every preceding block (0 through N-1). This creates a hash chain: a content-dependent fingerprint for every block that encodes the full token history up to that point.
When a new request arrives, vLLM hashes each block-sized chunk of the input and looks up each hash in a cached block map. If a matching block exists, the engine reuses the precomputed KV cache for that block. If not, it allocates a new block. The hash table gives O(1) lookup per block. The doubly-linked list embedded in each KVCacheBlock gives O(1) LRU eviction.
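The chain and the lookup can be sketched in a few lines of Python. This is a minimal illustration, not vLLM’s code: the function names, the token serialization, and the plain dict standing in for the cached block map are all assumptions made for the example.

```python
import hashlib
from typing import Dict, List

BLOCK_SIZE = 16  # assumption: the typical CUDA block size mentioned above


def block_hashes(token_ids: List[int], block_size: int = BLOCK_SIZE) -> List[bytes]:
    """Return one chained SHA-256 digest per *full* block of tokens."""
    hashes: List[bytes] = []
    parent = b""  # the first block has no parent hash
    for i in range(len(token_ids) // block_size):
        block = token_ids[i * block_size:(i + 1) * block_size]
        h = hashlib.sha256()
        h.update(parent)  # fold in the previous block's hash
        h.update(",".join(map(str, block)).encode())  # hypothetical serialization
        parent = h.digest()
        hashes.append(parent)
    return hashes


def count_cached_blocks(hashes: List[bytes], cached_block_map: Dict[bytes, int]) -> int:
    """Count leading blocks whose hash is already cached; stop at the first miss."""
    hits = 0
    for h in hashes:
        if h not in cached_block_map:
            break
        hits += 1
    return hits


# Request A fills the cache; request B shares the first two blocks only.
req_a = list(range(48))
req_b = list(range(32)) + [999] * 16
cache = {h: block_id for block_id, h in enumerate(block_hashes(req_a))}
reusable = count_cached_blocks(block_hashes(req_b), cache)
```

Here `reusable` counts the blocks request B can take from the cache. A trailing partial block is never hashed, so it is never reused.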
Because each block’s hash depends on all preceding content, two requests with identical blocks at position N will produce the same hash at position N only if all preceding tokens are also identical. This means the system automatically detects shared prefixes without any explicit prefix-matching logic. The lookup is content-addressable, but it is not position-independent: the hash at any position is a function of the full preceding sequence, not just the block’s local content.
It is still prefix-bound
The hash chain is elegant, but it does not escape the prefix constraint. The dependency on all preceding blocks means that any divergence at any position invalidates every block after it. If two requests share the first 3,000 tokens but differ at token 3,001, the hash chain breaks at the block containing token 3,001, and every subsequent block gets a different hash. The reuse stops at the divergence point, exactly like traditional prefix caching.
This matters because real workloads are not always prefix-shaped. Consider a RAG pipeline where the system prompt is stable but the retrieved chunks vary between requests and appear in different orders. Even if two requests contain the same five documents, swapping the order of documents 2 and 3 breaks the hash chain from that point forward. The KV cache for documents 4 and 5 cannot be reused despite being token-for-token identical, because the preceding hash is different.
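The reordering case is easy to reproduce with a toy chain. The sketch below assumes each document is exactly two blocks of made-up token IDs; `chain` and `doc` are hypothetical helpers written for this example, not vLLM functions.

```python
import hashlib


def chain(tokens, block_size=16):
    """Chained SHA-256 digests over the full blocks of a token list."""
    parent, out = b"", []
    for i in range(0, len(tokens) - block_size + 1, block_size):
        payload = parent + repr(tokens[i:i + block_size]).encode()
        parent = hashlib.sha256(payload).digest()
        out.append(parent)
    return out


def doc(doc_id, n_blocks=2, block_size=16):
    """Hypothetical document: n_blocks blocks of tokens unique to doc_id."""
    return [doc_id * 1000 + t for t in range(n_blocks * block_size)]


system = doc(0)
req_a = system + doc(1) + doc(2) + doc(3) + doc(4) + doc(5)
req_b = system + doc(1) + doc(3) + doc(2) + doc(4) + doc(5)  # docs 2 and 3 swapped

cached = set(chain(req_a))
hits = [h in cached for h in chain(req_b)]
```

In this setup `hits` is true for the four blocks covering the system prompt and document 1, and false for the remaining eight, including the blocks for documents 4 and 5 whose tokens sit at the same offsets in both requests.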
The same limitation applies to tool-calling scenarios where tool results are injected at variable positions, multi-turn conversations where the user message changes the middle of the context, and any workload where shared content appears after a point of divergence. The hash chain detects shared prefixes efficiently. It cannot detect shared suffixes or shared segments in non-prefix positions.
There are two additional constraints worth noting:
Block-aligned matching. Reuse only happens at block boundaries. If two requests share 1,000 tokens and the block size is 16, vLLM reuses 62 blocks (992 tokens) and recomputes the last 8. The waste is at most block_size - 1 tokens per request. For long prefixes this is negligible (under 0.4% for a 4K prefix). For short or variable-length shared segments it is more visible.
No sub-block matching. Two blocks that share 15 of 16 tokens but differ in one token produce completely different hashes. The hash is all-or-nothing at the block level. There is no partial reuse within a block.
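Both constraints reduce to simple arithmetic and a property of the hash. The snippet below is a sketch with hypothetical helper names (`reuse_split`, `block_digest`), assuming a block size of 16.

```python
import hashlib


def reuse_split(shared_tokens: int, block_size: int = 16):
    """Return (full blocks reused, tokens reused, tokens recomputed)."""
    full_blocks = shared_tokens // block_size
    reused = full_blocks * block_size
    return full_blocks, reused, shared_tokens - reused


def block_digest(tokens):
    """SHA-256 over one block's tokens (hypothetical serialization)."""
    return hashlib.sha256(repr(list(tokens)).encode()).hexdigest()


# Block-aligned matching: 1,000 shared tokens with 16-token blocks.
blocks, reused, recomputed = reuse_split(1000)  # (62, 992, 8)

# Worst-case waste for a 4K prefix is block_size - 1 tokens.
worst_case_waste = (16 - 1) / 4096  # roughly 0.37%

# No sub-block matching: one differing token changes the whole digest.
a = list(range(16))
b = a[:15] + [999]
same_block = block_digest(a) == block_digest(b)  # False
```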
The bottom line: vLLM’s hash chain is a well-engineered implementation of prefix caching, not a generalization beyond it. It automatically discovers shared prefixes via content hashing, which eliminates the need for explicit prefix tracking. But the reuse model is the same: shared prefixes are reusable, shared non-prefix content is not. SGLang’s radix tree, discussed below, has the same fundamental constraint.
Research directions like CacheBlend aim to close this gap by enabling chunk-level reuse beyond prefix positions. The first post in this series covers that in detail.
How vLLM implements it
The implementation lives in vllm/v1/core/kv_cache_manager.py. The key components:
- KVCacheBlock: contains a block_id (immutable physical address), a block_hash (assigned when the block is full, cleared on eviction), a reference count, and doubly-linked list pointers for the free queue.
- BlockPool: manages all physical blocks. Maintains a free block queue (doubly linked list in LRU order) and a cached block map (hash table mapping BlockHashWithGroupId to KVCacheBlock).
- hash_request_tokens(): in kv_cache_utils.py, computes SHA-256 hashes for each block of tokens in a request. The hash for each block includes the parent block’s hash, creating the chain.
The design is intentionally simple. Blocks are either matched or not. Eviction is LRU. The hash table provides O(1) lookup. Fixed-size blocks map cleanly to GPU memory without fragmentation. The design documentation describes the full implementation.
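To make the interaction between the free queue and the cached block map concrete, here is a toy pool. It is a sketch under stated assumptions, not vLLM’s BlockPool: `ToyBlockPool` and its methods are hypothetical, an OrderedDict stands in for the doubly-linked free queue, and the hash is assigned at allocation time rather than when the block fills.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Block:
    block_id: int
    block_hash: Optional[bytes] = None
    ref_cnt: int = 0


class ToyBlockPool:
    def __init__(self, num_blocks: int):
        self.blocks = [Block(i) for i in range(num_blocks)]
        # OrderedDict gives O(1) pop-from-front, append, and removal by key,
        # which is what the doubly-linked free queue provides.
        self.free: "OrderedDict[int, Block]" = OrderedDict(
            (b.block_id, b) for b in self.blocks
        )
        self.cached: Dict[bytes, Block] = {}

    def get_cached(self, block_hash: bytes) -> Optional[Block]:
        block = self.cached.get(block_hash)
        if block is None:
            return None
        if block.ref_cnt == 0:
            self.free.pop(block.block_id)  # in use again, not evictable
        block.ref_cnt += 1
        return block

    def allocate(self, block_hash: bytes) -> Block:
        # Take the least recently freed block; raises KeyError if none is free.
        _, block = self.free.popitem(last=False)
        if block.block_hash is not None:
            del self.cached[block.block_hash]  # evict the old contents
        block.block_hash = block_hash
        block.ref_cnt = 1
        self.cached[block_hash] = block
        return block

    def release(self, block: Block) -> None:
        block.ref_cnt -= 1
        if block.ref_cnt == 0:
            self.free[block.block_id] = block  # tail = most recently freed


pool = ToyBlockPool(num_blocks=2)
a = pool.allocate(b"A")
b = pool.allocate(b"B")
pool.release(a)
pool.release(b)
hit = pool.get_cached(b"A")  # cache hit pulls block A out of the free queue
c = pool.allocate(b"C")      # reuses block B's slot and evicts hash B
```

A freed block keeps its hash until its slot is reallocated, which is why a later request can still hit on it.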
References: RFC #2614 is the original design discussion for automatic prefix caching. Issue #16016 discusses cache salting (allowing users to partition the cache namespace).
RadixAttention: token-level radix tree matching
SGLang takes a different approach. Instead of hashing fixed-size blocks, it stores KV cache entries in a radix tree (a compressed trie) indexed by token sequences. The radix tree allows prefix matching at arbitrary token boundaries, not just at block-aligned positions.
Core data structures
The implementation lives in python/sglang/srt/mem_cache/radix_cache.py. The key components:
- RadixCache: the main cache class. Maintains the radix tree and handles insertion, matching, and eviction.
- TreeNode: each node in the radix tree contains children (keyed by token subsequences), a parent reference, the token sequence as the key, and KV cache indices as the value.
- match_prefix(): traverses the radix tree to find the longest cached prefix for a given token sequence.
- cache_finished_req(): stores a completed request’s token sequence and KV cache indices in the radix tree.
How lookup works
When a new request arrives, match_prefix() walks the radix tree from the root, following edges that match the request’s token sequence. The walk continues as long as matching edges exist. The deepest matching node determines how many prefix tokens can be reused. The remaining tokens (the suffix) require fresh prefill computation.
Prefix matching can happen at page granularity when page_size > 1, with prefixes aligned to page boundaries. The default behavior allows token-level matching.
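The walk can be illustrated with a toy compressed trie. This is a simplified sketch, not SGLang’s RadixCache: `ToyRadixTree` is hypothetical, it stores no KV indices, children are keyed by the first token of each edge, and it assumes token-level matching (page size of 1).

```python
from typing import Dict, Optional, Sequence, Tuple


class Node:
    def __init__(self, key: Tuple[int, ...] = (), parent: Optional["Node"] = None):
        self.key = key        # tokens on the edge leading into this node
        self.parent = parent
        self.children: Dict[int, "Node"] = {}  # first edge token -> child


def _common_len(a: Sequence[int], b: Sequence[int]) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class ToyRadixTree:
    def __init__(self) -> None:
        self.root = Node()

    def match_prefix(self, tokens: Sequence[int]) -> int:
        """Return how many leading tokens are already stored in the tree."""
        tokens = tuple(tokens)
        node, matched = self.root, 0
        while matched < len(tokens):
            child = node.children.get(tokens[matched])
            if child is None:
                break
            n = _common_len(child.key, tokens[matched:])
            matched += n
            if n < len(child.key):
                break  # diverged partway along this edge
            node = child
        return matched

    def insert(self, tokens: Sequence[int]) -> None:
        tokens = tuple(tokens)
        node, i = self.root, 0
        while i < len(tokens):
            child = node.children.get(tokens[i])
            if child is None:
                node.children[tokens[i]] = Node(tokens[i:], node)
                return
            n = _common_len(child.key, tokens[i:])
            if n < len(child.key):
                # Split the edge so the shared part becomes its own node.
                mid = Node(child.key[:n], node)
                node.children[tokens[i]] = mid
                child.key = child.key[n:]
                child.parent = mid
                mid.children[child.key[0]] = child
                child = mid
            i += n
            node = child


tree = ToyRadixTree()
tree.insert([1, 2, 3, 4, 5])
reusable = tree.match_prefix([1, 2, 3, 9])  # 3 tokens reusable, 1 to prefill
```

The match stops at the exact token where the sequences diverge, with no rounding to a block boundary.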
Specialized variants
SGLang extends the base radix cache for different model architectures and deployment scales:
- Sliding Window Attention cache: adapted for models with sliding window attention (e.g., Gemma 4), where only the window needs caching.
- Mamba cache: for state-space models that use recurrent state instead of KV pairs.
- HiRadix cache: hierarchical multi-tier caching with GPU memory as L1, host memory as L2, and distributed storage (e.g., Mooncake) as L3.
The scheduler (python/sglang/srt/managers/scheduler.py) integrates radix tree matching directly into scheduling decisions, enabling cache-aware request routing.
Token-level granularity. The radix tree can match prefixes at any token boundary, not just block boundaries. This means higher effective hit rates when shared prefixes have variable lengths. For multi-turn conversations where each turn appends to the prefix, the radix tree captures the exact shared portion without wasting tokens at block boundaries.
Structural overhead. Maintaining a radix tree requires dynamic allocation and pointer traversal, which introduces CPU overhead compared to a flat hash table lookup. The tree also creates memory fragmentation over time as nodes are allocated and freed. Under high concurrency, both the traversal cost and the allocation patterns can become visible, though in practice they are small relative to the GPU-side computation.
Eviction. When the cache is full, SGLang evicts leaf nodes from the radix tree in LRU order. Evicting a leaf frees its KV cache indices but preserves the parent node and any shared prefixes that other requests still reference. This means partially-shared prefixes are retained as long as at least one descendant is live. The tradeoff is that the tree structure itself consumes host memory proportional to the number of distinct prefixes, not just the number of cached KV entries.
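Leaf-first LRU eviction can be sketched with a heap ordered by last access time. This is an illustration under assumptions, not SGLang’s eviction code: the `Node` fields and the `evict` helper are hypothetical, and reference counting for in-flight requests is left out.

```python
import heapq
from typing import Dict, List, Optional


class Node:
    def __init__(self, name: str, num_tokens: int, last_access: int,
                 parent: Optional["Node"] = None):
        self.name = name
        self.num_tokens = num_tokens
        self.last_access = last_access
        self.parent = parent
        self.children: Dict[str, "Node"] = {}
        if parent is not None:
            parent.children[name] = self

    def __lt__(self, other: "Node") -> bool:
        return self.last_access < other.last_access  # oldest pops first


def collect_leaves(root: Node) -> List[Node]:
    stack, leaves = [root], []
    while stack:
        node = stack.pop()
        if node.children:
            stack.extend(node.children.values())
        elif node.parent is not None:  # never evict the root
            leaves.append(node)
    return leaves


def evict(root: Node, num_tokens: int) -> List[str]:
    """Evict least recently used leaves until num_tokens are freed."""
    heap = collect_leaves(root)
    heapq.heapify(heap)
    freed, evicted = 0, []
    while heap and freed < num_tokens:
        leaf = heapq.heappop(heap)
        parent = leaf.parent
        del parent.children[leaf.name]
        freed += leaf.num_tokens
        evicted.append(leaf.name)
        # A parent that lost its last child becomes an eviction candidate.
        if parent.parent is not None and not parent.children:
            heapq.heappush(heap, parent)
    return evicted


root = Node("root", 0, 0)
system = Node("system", 100, 5, root)
turn_a = Node("turn_a", 20, 1, system)
turn_b = Node("turn_b", 30, 3, system)
first = evict(root, 10)  # only the oldest leaf goes; "system" survives
```

The shared "system" node is only evicted after both of its descendants are gone.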
Multi-turn awareness. A recent fix (PR #16521) corrected the radix cache key computation to include generated tokens in multi-turn scenarios, using token_ids = (req.origin_input_ids + req.output_ids)[:kv_committed_len]. This ensures that the cache correctly represents the full conversation history, not just the original input.
Learning resource: mini-sglang is a simplified educational implementation of SGLang’s radix cache design. The unified radix tree refactor (issue #20415) tracks ongoing work to consolidate the cache variants.
When the difference matters
| Dimension | vLLM (APC) | SGLang (RadixAttention) | Best for |
|---|---|---|---|
| Data structure | Hash table (block hash to physical block) | Radix tree (token sequence to KV indices) | — |
| Match granularity | Block-aligned (typically 16 tokens) | Token-level (or page-aligned) | SGLang: variable-length turns |
| Lookup complexity | O(1) per block (hash table) | O(prefix length) (tree traversal) | vLLM: high-concurrency steady state |
| Eviction | LRU via doubly-linked list, O(1) | LRU within tree structure | vLLM: simpler capacity planning |
| Multi-turn support | Hash chain captures full history | Radix tree naturally extends with turns | SGLang: branching conversations |
| Variable-length prefixes | Rounded down to block boundary | Exact match at any boundary | SGLang: short/irregular prefixes |
| CPU overhead | Lower (flat hash lookup) | Higher (tree traversal, dynamic allocation) | vLLM: CPU-constrained deployments |
| Hierarchical storage | Not built-in (external via llm-d) | HiRadix: GPU / host / distributed tiers | SGLang: multi-tier caching |
| Memory layout | Fixed blocks, mmap-friendly | Dynamic tree nodes, pointer-heavy | vLLM: persistent/shared cache stores |
| Architecture variants | Single implementation | Separate caches for SWA, Mamba, hierarchical | SGLang: non-standard architectures |
Where vLLM’s approach has advantages
Templated workloads with consistent prefix lengths. If every request starts with the same 4K-token system prompt, both engines will cache it effectively. vLLM’s flat hash lookup should have lower CPU overhead per request than a tree traversal, though this has not been profiled directly. The block-aligned design is also simpler to reason about for capacity planning: each block is a known fixed size, eviction is straightforward LRU, and the memory layout maps cleanly to fixed-size GPU allocations without fragmentation concerns.
mmap-friendly memory layout. vLLM’s fixed-size block design aligns naturally to memory-mapped I/O. Each block is a contiguous, fixed-size region that can be mapped directly from persistent storage without deserialization or pointer reconstruction. This matters for KV cache offloading to NVMe or shared storage: the blocks are self-describing (identified by content hash) and can be placed at any physical address, since nothing in a block points to another block’s location. A radix tree’s pointer-heavy, dynamically-allocated structure does not share this property. For systems that want to persist, share, or transfer KV cache blocks across processes or nodes, vLLM’s flat layout is a more natural fit.
Where SGLang’s approach has advantages
Multi-turn conversations and variable-length shared context. In a multi-turn chat, each turn extends the prefix by a different number of tokens. The radix tree captures the exact shared history without boundary waste. RunPod benchmarked both engines on multi-turn scenarios with DeepSeek-R1-Distill-Llama-70B on 2xH100 and found SGLang maintained higher throughput under cache pressure, though the gap was roughly 10%, not an order of magnitude. Configuration differences (batch size, cache size, eviction tuning) make third-party benchmarks hard to generalize from. The architectural advantage of token-level matching is real but its production magnitude depends on the specific workload shape.
SGLang’s scheduler also uses radix tree matching for cache-aware request routing, directing requests to the node most likely to have a cache hit. This is an advantage for multi-node deployments where cache locality affects performance.
Where both converge
For the agentic workloads described in the first post in this series, where stable system prompts and tool definitions create long, consistent prefixes, both approaches work well. The prefix is long enough that block boundary waste is negligible (for a 4K prefix with 16-token blocks, wasted recomputation is at most 15 tokens, under 0.4%), and the reuse pattern is regular enough that hash-based matching is efficient.
At production scale, the choice between vLLM and SGLang for prefix caching is rarely the binding constraint. Model support, runtime maturity, deployment tooling, and team familiarity typically dominate the decision. The data structure difference becomes visible mainly in multi-turn workloads with high concurrency and variable turn lengths.
Key files and links
vLLM
- vllm/v1/core/kv_cache_manager.py: KVCacheManager, BlockPool, KVCacheBlock
- vllm/v1/core/kv_cache_utils.py: hash_request_tokens(), block hashing (SHA-256)
- Design documentation: prefix caching
- RFC #2614: Automatic Prefix Caching
- Issue #16016: Cache Salting RFC
SGLang
- radix_cache.py: core radix tree implementation
- swa_radix_cache.py: sliding window attention variant
- hiradix_cache.py: hierarchical multi-tier caching
- Issue #20415: Unified radix tree refactor
- PR #16521: Multi-turn cache key fix
- mini-sglang: educational implementation
Khawaja Shams

