From 33b8ca4e960f04e86c700c2609c0c01d4a41cab0 Mon Sep 17 00:00:00 2001 From: lty Date: Thu, 5 Feb 2026 10:36:52 +0800 Subject: [PATCH] [Feature]KV pool supports sparse attention (#6339) ### What this PR does / why we need it? The kv pooling feature is adapted to Sparse Attention to support models such as Deepseek V3.2. ### Does this PR introduce _any_ user-facing change? NA ### How was this patch tested? ``` vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \ --host $local_ip \ --port 8002 \ --served-model-name model \ --data-parallel-size 1 \ --tensor-parallel-size 8 \ --prefill-context-parallel-size 2 \ --decode-context-parallel-size 1 \ --cp-kv-cache-interleave-size 128 \ --block-size 128 \ --enable-expert-parallel \ --no-enable-prefix-caching \ --no-enable-chunked-prefill \ --max-num-seqs 4 \ --max-model-len 8192 \ --max-num-batched-tokens 8192 \ --gpu-memory-utilization 0.95 \ --trust-remote-code \ --enforce-eager \ --quantization ascend \ --additional_config '{"ascend_scheduler_config":{"enabled":false}}' \ --kv-transfer-config \ '{ "kv_connector": "AscendStoreConnector", "kv_role": "kv_both", "kv_connector_extra_config": { "backend": "mooncake", "lookup_rpc_port":"0", "use_layerwise": false } }' ``` - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd Signed-off-by: lty --- .../kv_pool/ascend_store/config_data.py | 32 ++++----- .../kv_pool/ascend_store/pool_worker.py | 66 ++++++++----------- 2 files changed, 40 insertions(+), 58 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py index 398cc3fb..8cc1bad1 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py @@ -92,10 +92,9 @@ class LayerPoolKey(PoolKey): class ChunkedTokenDatabase: - def __init__(self, metadata: 
KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None): + def __init__(self, metadata: KeyMetadata, block_size: int, partitions: list[int] | None): self.metadata = metadata self.block_size = block_size - self.use_mla = use_mla self.kv_caches_base_addr: list[int] = [] self.block_len: list[int] = [] self.partitions = partitions @@ -117,29 +116,24 @@ class ChunkedTokenDatabase: addr_list = [] size_list = [] block_id = block_ids[start // self.block_size] + length = len(self.block_len) for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0] - - addr = base_addr + block_id * block_len - length = int(block_len / self.block_size * (end - start)) + addr = base_addr + block_id * self.block_len[index % length] + size = int(self.block_len[index % length] / self.block_size * (end - start)) addr_list.append(addr) - size_list.append(length) + size_list.append(size) return addr_list, size_list, block_id def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int): block_id = block_ids[start // self.block_size] - if self.use_mla: - addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1] - length_k = int(self.block_len[0] / self.block_size * (end - start)) - length_v = int(self.block_len[1] / self.block_size * (end - start)) - size_list = [length_k, length_v] - else: - addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0] - length = int(self.block_len[0] / self.block_size * (end - start)) - size_list = [length, length] - addr_list = [addr_k, addr_v] + addr_list = [] + size_list = [] + length = len(self.block_len) + for i in range(length): + addr = self.kv_caches_base_addr[layer_id * length + i] + block_id * self.block_len[i] + size = 
int(self.block_len[i] / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(size) return addr_list, size_list def process_tokens( diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py index 033ea34d..832dbe3c 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py @@ -56,6 +56,7 @@ class KVPoolWorker: self.use_mla = False if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla: self.use_mla = True + self.use_sparse = hasattr(model_config.hf_text_config, "index_topk") self.use_layerwise = use_layerwize self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -127,7 +128,7 @@ class KVPoolWorker: for i in range(2, remaining_layers + 2): partitions[-i] += 1 - self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions) + self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, partitions) backend = backend_map.get(self.backend.lower()) assert backend is not None @@ -155,55 +156,42 @@ class KVPoolWorker: _, first_kv_cache_tuple = next(iter(kv_caches.items())) first_kv_cache = first_kv_cache_tuple[0] - # TODO(tms): Find a more robust way to detect and handle MLA - if self.use_mla: - # MLA case.[num_block, block_size, 1, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - block_rank = 3 # [block_size, latent_dim] - block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] - block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] - self.block_len = [ - first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe), - ] - logger.info( - "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - 
self.num_blocks, - block_shape_norm, - block_shape_pe, - ) + self.num_blocks = first_kv_cache.shape[0] + logger.info("num_blocks: %s", self.num_blocks) + block_rank = 3 + self.block_len = [] + if self.use_mla or self.use_sparse: + for i in range(len(first_kv_cache_tuple)): + block_shape = first_kv_cache_tuple[i].shape[-block_rank:] + logger.info("block_shape: %s", block_shape) + self.block_len.append(first_kv_cache_tuple[i].element_size() * math.prod(block_shape)) else: # [num_block, block_size, num_head, hidden_dim] - self.num_blocks = first_kv_cache.shape[0] - kv_elem_size = first_kv_cache.element_size() - block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) + logger.info("block_shape: %s", block_shape) + self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)] - logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape) + logger.info( + "Registering KV_Caches. 
use_mla: %s, use_sparse: %s, shape %s", + self.use_mla, + self.use_sparse, + first_kv_cache.shape, + ) self.kv_caches = kv_caches self.kv_caches_base_addr = [] ptrs = [] lengths = [] + length = len(self.block_len) for cache_or_caches in kv_caches.values(): # Normalize to always be a list of caches - if self.use_mla: - for i, cache in enumerate(cache_or_caches, 0): - base_addr = cache.data_ptr() - self.kv_caches_base_addr.append(base_addr) - region_len = self.num_blocks * self.block_len[i % 2] - ptrs.append(base_addr) - lengths.append(region_len) - else: - cache_list = [cache_or_caches] if self.use_mla else cache_or_caches - for cache in cache_list: - base_addr = cache.data_ptr() - self.kv_caches_base_addr.append(base_addr) - region_len = self.num_blocks * self.block_len[0] - ptrs.append(base_addr) - lengths.append(region_len) + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % length] + self.kv_caches_base_addr.append(base_addr) + ptrs.append(base_addr) + lengths.append(region_len) + self.m_store.register_buffer(ptrs, lengths) self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr) self.token_database.set_block_len(self.block_len)