[Feature]KV pool supports sparse attention (#6339)
### What this PR does / why we need it?
The KV pooling feature is adapted to sparse attention to support models
such as DeepSeek V3.2.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
```
vllm serve /mnt/weight/DeepSeek-V3.2-Exp-W8A8 \
--host $local_ip \
--port 8002 \
--served-model-name model \
--data-parallel-size 1 \
--tensor-parallel-size 8 \
--prefill-context-parallel-size 2 \
--decode-context-parallel-size 1 \
--cp-kv-cache-interleave-size 128 \
--block-size 128 \
--enable-expert-parallel \
--no-enable-prefix-caching \
--no-enable-chunked-prefill \
--max-num-seqs 4 \
--max-model-len 8192 \
--max-num-batched-tokens 8192 \
--gpu-memory-utilization 0.95 \
--trust-remote-code \
--enforce-eager \
--quantization ascend \
--additional_config '{"ascend_scheduler_config":{"enabled":false}}' \
--kv-transfer-config \
'{
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"backend": "mooncake",
"lookup_rpc_port":"0",
"use_layerwise": false
}
}'
```
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
@@ -92,10 +92,9 @@ class LayerPoolKey(PoolKey):
|
||||
|
||||
|
||||
class ChunkedTokenDatabase:
|
||||
def __init__(self, metadata: KeyMetadata, block_size: int, partitions: list[int] | None):
    """Database mapping chunks of tokens to KV-cache memory regions.

    Args:
        metadata: key metadata used to build pool keys.
        block_size: number of tokens held by one KV-cache block.
        partitions: presumably per-partition layer counts for layerwise
            transfer, or None when partitioning is disabled — confirm
            against the KVPoolWorker caller.
    """
    self.metadata = metadata
    self.block_size = block_size
    self.partitions = partitions
    # Populated later through set_kv_caches_base_addr()/set_block_len()
    # once the worker has registered the KV caches.
    self.kv_caches_base_addr: list[int] = []
    self.block_len: list[int] = []
|
||||
@@ -117,29 +116,24 @@ class ChunkedTokenDatabase:
|
||||
addr_list = []
|
||||
size_list = []
|
||||
block_id = block_ids[start // self.block_size]
|
||||
length = len(self.block_len)
|
||||
for index, base_addr in enumerate(self.kv_caches_base_addr):
|
||||
block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0]
|
||||
|
||||
addr = base_addr + block_id * block_len
|
||||
length = int(block_len / self.block_size * (end - start))
|
||||
addr = base_addr + block_id * self.block_len[index % length]
|
||||
size = int(self.block_len[index % length] / self.block_size * (end - start))
|
||||
addr_list.append(addr)
|
||||
size_list.append(length)
|
||||
size_list.append(size)
|
||||
return addr_list, size_list, block_id
|
||||
|
||||
def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int):
    """Compute the (address, size) pairs for one layer's KV caches covering
    tokens [start, end).

    Args:
        start: index of the first token of the chunk (block-aligned).
        end: one past the last token of the chunk.
        block_ids: physical block ids; the chunk lives in block_ids[start // block_size].
        layer_id: index of the attention layer whose caches are addressed.

    Returns:
        (addr_list, size_list): one entry per cache tensor of the layer
        (e.g. nope/rope for MLA, plus the sparse-indexer cache when present).
    """
    block_id = block_ids[start // self.block_size]
    addr_list = []
    size_list = []
    length = len(self.block_len)
    for i in range(length):
        # kv_caches_base_addr holds `length` consecutive entries per layer
        # (see register_kv_caches), so cache i of this layer sits at
        # layer_id * length + i.  The previous code indexed
        # [layer_id * length] for every i, aliasing all of the layer's
        # caches onto the first tensor's base address.
        addr = self.kv_caches_base_addr[layer_id * length + i] + block_id * self.block_len[i]
        size = int(self.block_len[i] / self.block_size * (end - start))
        addr_list.append(addr)
        size_list.append(size)
    return addr_list, size_list
|
||||
|
||||
def process_tokens(
|
||||
|
||||
@@ -56,6 +56,7 @@ class KVPoolWorker:
|
||||
self.use_mla = False
|
||||
if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla:
|
||||
self.use_mla = True
|
||||
self.use_sparse = hasattr(model_config.hf_text_config, "index_topk")
|
||||
self.use_layerwise = use_layerwize
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -127,7 +128,7 @@ class KVPoolWorker:
|
||||
for i in range(2, remaining_layers + 2):
|
||||
partitions[-i] += 1
|
||||
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions)
|
||||
self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, partitions)
|
||||
|
||||
backend = backend_map.get(self.backend.lower())
|
||||
assert backend is not None
|
||||
@@ -155,55 +156,42 @@ class KVPoolWorker:
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks,
|
||||
block_shape_norm,
|
||||
block_shape_pe,
|
||||
)
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
logger.info("num_blocks: %s", self.num_blocks)
|
||||
block_rank = 3
|
||||
self.block_len = []
|
||||
if self.use_mla or self.use_sparse:
|
||||
for i in range(len(first_kv_cache_tuple)):
|
||||
block_shape = first_kv_cache_tuple[i].shape[-block_rank:]
|
||||
logger.info("block_shape: %s", block_shape)
|
||||
self.block_len.append(first_kv_cache[i].element_size() * math.prod(block_shape))
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
kv_elem_size = first_kv_cache.element_size()
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
self.block_len = [kv_elem_size * math.prod(block_shape)]
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape)
|
||||
logger.info("block_shape: %s", block_shape)
|
||||
self.block_len = [first_kv_cache.element_size() * math.prod(block_shape)]
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape)
|
||||
logger.info(
|
||||
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
||||
self.use_mla,
|
||||
self.use_sparse,
|
||||
first_kv_cache.shape,
|
||||
)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
self.kv_caches_base_addr = []
|
||||
ptrs = []
|
||||
lengths = []
|
||||
length = len(self.block_len)
|
||||
for cache_or_caches in kv_caches.values():
|
||||
# Normalize to always be a list of caches
|
||||
if self.use_mla:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % length]
|
||||
self.kv_caches_base_addr.append(base_addr)
|
||||
ptrs.append(base_addr)
|
||||
lengths.append(region_len)
|
||||
|
||||
self.m_store.register_buffer(ptrs, lengths)
|
||||
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
|
||||
self.token_database.set_block_len(self.block_len)
|
||||
|
||||
Reference in New Issue
Block a user