[Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)

### What this PR does / why we need it?
Adapt deepseek-v3.2 to vLLM 0.11.0 and remove the patch that is no longer needed.

The final goal is to remove all the patches and align the code architecture with vLLM, so the following work is planned in follow-up PRs (a sketch of the core detection change follows the TODO list).
TODO:
- [x] remove patch on attention spec
- [ ] refactor the kvcache creation logic
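
The core of the adaptation is dropping the `ascend_config.use_sfa` flag: the connectors now detect DeepSeek sparse attention from the model's HF config (or from the KV-cache layout) and rename the flags from `use_sfa` to `use_sparse`. A minimal sketch of the new config-based check, with an illustrative helper name that is not part of the codebase:

```python
# Sketch only: mirrors the check introduced in MooncakeConnectorScheduler.
def uses_sparse_attention(vllm_config) -> bool:
    # DeepSeek V3.2 sparse attention is detected from the HF config
    # ("index_topk" attribute) instead of the removed ascend_config.use_sfa flag.
    return hasattr(vllm_config.model_config.hf_config, "index_topk")
```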

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with existing tests.
2. Tests pass with deepseek-v3.2-exp.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-10-15 17:48:58 +08:00
Committed by: GitHub
Parent: 099255e933
Commit: 8abe517870
20 changed files with 143 additions and 262 deletions


@@ -501,7 +501,7 @@ class LLMDataDistCMgrConnectorWorker():
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
- self.use_sfa: bool = len(first_kv_cache_tuple) == 3
+ self.use_sparse: bool = len(first_kv_cache_tuple) == 3
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
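
A hypothetical helper distilling the shape-based detection above (illustrative only; the worker stores these flags as attributes rather than calling such a function):

```python
def classify_kv_cache(first_kv_cache_tuple) -> str:
    """Infer the attention type from the first KV-cache entry (a tuple of tensors)."""
    if len(first_kv_cache_tuple) == 3:
        # Sparse case: (k_normed, k_pe, k_idx)
        return "sparse"
    if (len(first_kv_cache_tuple) == 2 and
            first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(-1)):
        # MLA case: (k_normed, k_pe) with different last dims
        return "mla"
    # MHA case: (k, v) with identical shapes
    return "mha"
```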
@@ -549,7 +549,7 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
- elif self.use_sfa:
+ elif self.use_sparse:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
cache_k_idx_addr_list = []
@@ -887,7 +887,7 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
- elif self.use_sfa:
+ elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(


@@ -242,7 +242,7 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
- self.use_sfa = len(block_len) == 3
+ self.use_sparse = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=32)
@@ -373,7 +373,7 @@ class KVCacheRecvingThread(threading.Thread):
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
if self.use_mla:
block_len = (self.block_len[k % 2])
- elif self.use_sfa:
+ elif self.use_sparse:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
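
For context, the per-tensor block-length selection above can be summarized by this hedged sketch (a free function here; in the code these are attributes of `KVCacheRecvingThread`):

```python
def select_block_len(k: int, block_len: list, use_mla: bool, use_sparse: bool) -> int:
    # MLA caches interleave two regions (k_normed, k_pe), sparse caches three
    # (k_normed, k_pe, k_idx); plain MHA uses a single block length.
    if use_mla:
        return block_len[k % 2]
    if use_sparse:
        return block_len[k % 3]
    return block_len[0]
```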
@@ -850,7 +850,8 @@ class MooncakeConnectorScheduler:
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
- if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
+ if self.vllm_config.model_config.use_mla or hasattr(
+ self.vllm_config.model_config.hf_config, "index_topk"):
num_need_pulls = 1
else:
num_p_block_heads = max(
@@ -942,7 +943,7 @@ class MooncakeConnectorWorker:
# kv_transfer variables
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
- if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+ if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
self.num_need_pulls = 1
else:
num_d_block_heads = max(1,
@@ -995,7 +996,7 @@ class MooncakeConnectorWorker:
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
- self.use_sfa = len(first_kv_cache_tuple) == 3
+ self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -1009,7 +1010,7 @@ class MooncakeConnectorWorker:
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
- elif self.use_sfa:
+ elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
@@ -1037,8 +1038,8 @@ class MooncakeConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla, self.use_sparse, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -1050,7 +1051,7 @@ class MooncakeConnectorWorker:
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
- elif self.use_sfa:
+ elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
@@ -1059,7 +1060,7 @@ class MooncakeConnectorWorker:
else:
cache_list = [
cache_or_caches
- ] if self.use_mla or self.use_sfa else cache_or_caches
+ ] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
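
The registration loops above follow the same modulo pattern. A sketch of the sparse branch, with `register` standing in for the worker's memory-registration call:

```python
def register_sparse_caches(caches, num_blocks, block_len, register):
    # Each of the three caches (k_normed, k_pe, k_idx) is registered as its own
    # region, sized by the matching per-tensor block length.
    base_addrs = []
    for i, cache in enumerate(caches):
        base_addr = cache.data_ptr()
        register(base_addr, num_blocks * block_len[i % 3])
        base_addrs.append(base_addr)
    return base_addrs
```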
@@ -1156,9 +1157,9 @@ class MooncakeConnectorWorker:
sampled_nums = []
ori_data = np.arange(self._prefill_tp_size)
# random split prefill tp list
- if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+ if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
# use deepseek mla, num_key_value_heads == 128, but consider as 1
- if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+ if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
num_kv_head = 1
else:
num_kv_head = self.num_key_value_heads
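
The prefill-TP split above relies on the same idea: DeepSeek MLA and sparse-attention models report 128 KV heads in the HF config but transfer a single latent KV per token. A hedged sketch of the effective head count, with illustrative names:

```python
def effective_kv_heads(num_key_value_heads: int,
                       is_deepseek_mla: bool,
                       use_sparse: bool) -> int:
    # MLA / sparse attention is treated as one KV head when splitting the
    # prefill TP ranks, regardless of the advertised head count.
    return 1 if (is_deepseek_mla or use_sparse) else num_key_value_heads
```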
@@ -1279,4 +1280,4 @@ def ensure_zmq_recv(
logger.error(f"Receive failed after all retries: {e}")
raise RuntimeError(
f"Failed to receive data after {max_retries} "
f"retries: {e}")
f"retries: {e}")