[Feature]KV pool supports sparse attention (#6339)

### What this PR does / why we need it?
The KV pooling feature is adapted to sparse attention to support models
such as DeepSeek V3.2.

### Does this PR introduce _any_ user-facing change?
No user-facing changes.

### How was this patch tested?
```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
  --host $local_ip \
  --port 8002 \
  --served-model-name model \
  --data-parallel-size 1 \
  --tensor-parallel-size 8 \
  --prefill-context-parallel-size 2 \
  --decode-context-parallel-size 1 \
  --cp-kv-cache-interleave-size 128 \
  --block-size 128 \
  --enable-expert-parallel \
  --no-enable-prefix-caching \
  --no-enable-chunked-prefill \
  --max-num-seqs 4 \
  --max-model-len 8192 \
  --max-num-batched-tokens 8192 \
  --gpu-memory-utilization 0.95 \
  --trust-remote-code \
  --enforce-eager \
  --quantization ascend \
  --additional_config '{"ascend_scheduler_config":{"enabled":false}}' \
  --kv-transfer-config \
    '{
            "kv_connector": "AscendStoreConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {
	            "backend": "mooncake",
              "lookup_rpc_port":"0",
              "use_layerwise": false
            }
    }'
```

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
lty
2026-02-05 10:36:52 +08:00
committed by GitHub
parent 13c4a9c78b
commit 33b8ca4e96
2 changed files with 40 additions and 58 deletions

View File

@@ -92,10 +92,9 @@ class LayerPoolKey(PoolKey):
class ChunkedTokenDatabase: class ChunkedTokenDatabase:
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None): def __init__(self, metadata: KeyMetadata, block_size: int, partitions: list[int] | None):
self.metadata = metadata self.metadata = metadata
self.block_size = block_size self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = [] self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = [] self.block_len: list[int] = []
self.partitions = partitions self.partitions = partitions
@@ -117,29 +116,24 @@ class ChunkedTokenDatabase:
addr_list = [] addr_list = []
size_list = [] size_list = []
block_id = block_ids[start // self.block_size] block_id = block_ids[start // self.block_size]
length = len(self.block_len)
for index, base_addr in enumerate(self.kv_caches_base_addr): for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0] addr = base_addr + block_id * self.block_len[index % length]
size = int(self.block_len[index % length] / self.block_size * (end - start))
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr) addr_list.append(addr)
size_list.append(length) size_list.append(size)
return addr_list, size_list, block_id return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int): def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
block_id = block_ids[start // self.block_size] block_id = block_ids[start // self.block_size]
if self.use_mla: addr_list = []
addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] size_list = []
addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1] length = len(self.block_len)
length_k = int(self.block_len[0] / self.block_size * (end - start)) for i in range(length):
length_v = int(self.block_len[1] / self.block_size * (end - start)) addr = self.kv_caches_base_addr[layer_id * length] + block_id * self.block_len[i]
size_list = [length_k, length_v] size = int(self.block_len[i] / self.block_size * (end - start))
else: addr_list.append(addr)
addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] size_list.append(size)
addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list return addr_list, size_list
def process_tokens( def process_tokens(

View File

@@ -56,6 +56,7 @@ class KVPoolWorker:
self.use_mla = False self.use_mla = False
if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla: if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
self.use_mla = True self.use_mla = True
self.use_sparse = hasattr(model_config.hf_text_config, "index_topk")
self.use_layerwise = use_layerwize self.use_layerwise = use_layerwize
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
@@ -127,7 +128,7 @@ class KVPoolWorker:
for i in range(2, remaining_layers + 2): for i in range(2, remaining_layers + 2):
partitions[-i] += 1 partitions[-i] += 1
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions) self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, partitions)
backend = backend_map.get(self.backend.lower()) backend = backend_map.get(self.backend.lower())
assert backend is not None assert backend is not None
@@ -155,55 +156,42 @@ class KVPoolWorker:
_, first_kv_cache_tuple = next(iter(kv_caches.items())) _, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0] first_kv_cache = first_kv_cache_tuple[0]
# TODO(tms): Find a more robust way to detect and handle MLA
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0] self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim] logger.info("num_blocks: %s", self.num_blocks)
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] block_rank = 3
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] self.block_len = []
self.block_len = [ if self.use_mla or self.use_sparse:
first_kv_cache[0].element_size() * math.prod(block_shape_norm), for i in range(len(first_kv_cache_tuple)):
first_kv_cache[1].element_size() * math.prod(block_shape_pe), block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
] logger.info("block_shape: %s", block_shape)
logger.info( self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks,
block_shape_norm,
block_shape_pe,
)
else: else:
# [num_block, block_size, num_head, hidden_dim] # [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:] block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)] logger.info("block_shape: %s", block_shape)
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]
logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape) logger.info(
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla,
self.use_sparse,
first_kv_cache.shape,
)
self.kv_caches = kv_caches self.kv_caches = kv_caches
self.kv_caches_base_addr = [] self.kv_caches_base_addr = []
ptrs = [] ptrs = []
lengths = [] lengths = []
length = len(self.block_len)
for cache_or_caches in kv_caches.values(): for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches # Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0): for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr() base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % length]
self.kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[i % 2]
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
region_len = self.num_blocks * self.block_len[0]
ptrs.append(base_addr) ptrs.append(base_addr)
lengths.append(region_len) lengths.append(region_len)
self.m_store.register_buffer(ptrs, lengths) self.m_store.register_buffer(ptrs, lengths)
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr) self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
self.token_database.set_block_len(self.block_len) self.token_database.set_block_len(self.block_len)