[Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)
### What this PR does / why we need it?
Adapt deepseek-v3.2 to vLLM 0.11.0 and remove the patch that is no longer needed. The final goal is to remove all the patches and align the code architecture with vLLM, so the following work is planned for the next PRs.

TODO:
- [x] remove the patch on the attention spec
- [ ] refactor the kv cache creation logic

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with the existing tests.
2. Tests pass with deepseek-v3.2-exp.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
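For orientation, below is a minimal sketch (not part of the diff) of the KV cache layout detection that this PR renames from `use_sfa` to `use_sparse`: the connectors classify each layer's KV cache tuple as MLA, sparse attention (DeepSeek V3.2), or standard MHA purely from the tuple length and tensor shapes. The helper name `detect_kv_cache_layout` is hypothetical; the checks mirror the hunks below.

```python
from typing import Tuple

import torch


def detect_kv_cache_layout(first_kv_cache_tuple: Tuple[torch.Tensor, ...]):
    """Classify a per-layer KV cache tuple the way the connectors below do.

    MLA:    2 tensors (k_normed, k_pe) whose last dims differ
    sparse: 3 tensors (k_normed, k_pe, k_idx) -- DeepSeek V3.2 sparse attention
    MHA:    2 tensors (k, v) whose last dims match
    """
    use_mla = first_kv_cache_tuple[0].size(-1) != first_kv_cache_tuple[1].size(
        -1) and len(first_kv_cache_tuple) == 2
    use_sparse = len(first_kv_cache_tuple) == 3
    return use_mla, use_sparse
```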
@@ -501,7 +501,7 @@ class LLMDataDistCMgrConnectorWorker():
         self.use_mla: bool = first_kv_cache_tuple[0].size(
             -1) != first_kv_cache_tuple[1].size(-1) and len(
                 first_kv_cache_tuple) == 2
-        self.use_sfa: bool = len(first_kv_cache_tuple) == 3
+        self.use_sparse: bool = len(first_kv_cache_tuple) == 3
         # MLA case. [2 (k_normed, k_pe), num_blocks, ...]
         # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
         # MHA case. [2 (k and v), num_blocks, ...]
@@ -549,7 +549,7 @@ class LLMDataDistCMgrConnectorWorker():
             raise RuntimeError(
                 f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
             )
-        elif self.use_sfa:
+        elif self.use_sparse:
             cache_k_normed_addr_list = []
             cache_k_pe_addr_list = []
             cache_k_idx_addr_list = []
@@ -887,7 +887,7 @@ class LLMDataDistCMgrConnectorWorker():
                 raise RuntimeError(
                     "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
                 )
-        elif self.use_sfa:
+        elif self.use_sparse:
             remote_cache_key_k_normed = BlocksCacheKey(
                 cluster_id=remote_cluster_id, model_id=0)
             remote_cache_key_k_pe = BlocksCacheKey(

@@ -242,7 +242,7 @@ class KVCacheRecvingThread(threading.Thread):
         self.block_len = block_len
         # TODO(jianzs): find a better way to detect MLA.
         self.use_mla = len(block_len) == 2
-        self.use_sfa = len(block_len) == 3
+        self.use_sparse = len(block_len) == 3

         self.request_queue: queue.Queue[Any] = queue.Queue()
         self.executor = ThreadPoolExecutor(max_workers=32)
@@ -373,7 +373,7 @@ class KVCacheRecvingThread(threading.Thread):
                 zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
             if self.use_mla:
                 block_len = (self.block_len[k % 2])
-            elif self.use_sfa:
+            elif self.use_sparse:
                 block_len = (self.block_len[k % 3])
             else:
                 block_len = (self.block_len[0])
@@ -850,7 +850,8 @@ class MooncakeConnectorScheduler:
         assert "tp_size" in decode_parallel_config.keys()
         self._decode_tp_size = decode_parallel_config["tp_size"]
         num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
-        if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
+        if self.vllm_config.model_config.use_mla or hasattr(
+                self.vllm_config.model_config.hf_config, "index_topk"):
             num_need_pulls = 1
         else:
             num_p_block_heads = max(
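Note on the scheduler-side change above: with `ascend_config.use_sfa` gone, sparse attention is now inferred from the model's HF config. A hedged sketch of the equivalent predicate (the helper name `pulls_single_kv_region` is hypothetical; the attribute check comes straight from the hunk):

```python
def pulls_single_kv_region(model_config) -> bool:
    # MLA models and DeepSeek V3.2 sparse-attention models (whose HF config
    # carries an `index_topk` field) set num_need_pulls to 1 in the scheduler.
    return model_config.use_mla or hasattr(model_config.hf_config, "index_topk")
```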
@@ -942,7 +943,7 @@ class MooncakeConnectorWorker:
         # kv_transfer variables
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
-        if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+        if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
             self.num_need_pulls = 1
         else:
             num_d_block_heads = max(1,
@@ -995,7 +996,7 @@ class MooncakeConnectorWorker:
         self.use_mla = first_kv_cache_tuple[0].size(
             -1) != first_kv_cache_tuple[1].size(-1) and len(
                 first_kv_cache_tuple) == 2
-        self.use_sfa = len(first_kv_cache_tuple) == 3
+        self.use_sparse = len(first_kv_cache_tuple) == 3
         if self.use_mla:
             # MLA case.[num_block, block_size, 1, hidden_dim]
             self.num_blocks = first_kv_cache.shape[0]
@@ -1009,7 +1010,7 @@ class MooncakeConnectorWorker:
             logger.info(
                 "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                 self.num_blocks, block_shape_norm, block_shape_pe)
-        elif self.use_sfa:
+        elif self.use_sparse:
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 3  # [block_size, latent_dim]
             block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
@@ -1037,8 +1038,8 @@ class MooncakeConnectorWorker:
             logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                         block_shape)
         logger.info(
-            "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
-            self.use_mla, self.use_sfa, first_kv_cache.shape)
+            "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
+            self.use_mla, self.use_sparse, first_kv_cache.shape)

         self.kv_caches = kv_caches
         kv_caches_base_addr = []
@@ -1050,7 +1051,7 @@ class MooncakeConnectorWorker:
                 region_len = self.num_blocks * self.block_len[i % 2]
                 kv_caches_base_addr.append(base_addr)
                 self._register(base_addr, region_len)
-        elif self.use_sfa:
+        elif self.use_sparse:
             for i, cache in enumerate(cache_or_caches, 0):
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len[i % 3]
@@ -1059,7 +1060,7 @@ class MooncakeConnectorWorker:
         else:
             cache_list = [
                 cache_or_caches
-            ] if self.use_mla or self.use_sfa else cache_or_caches
+            ] if self.use_mla or self.use_sparse else cache_or_caches
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len[0]
@@ -1156,9 +1157,9 @@ class MooncakeConnectorWorker:
         sampled_nums = []
         ori_data = np.arange(self._prefill_tp_size)
         # random split prefill tp list
-        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
             # use deepseek mla, num_key_value_heads == 128, but consider as 1
-            if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+            if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
                 num_kv_head = 1
             else:
                 num_kv_head = self.num_key_value_heads
@@ -1279,4 +1280,4 @@ def ensure_zmq_recv(
             logger.error(f"Receive failed after all retries: {e}")
             raise RuntimeError(
                 f"Failed to receive data after {max_retries} "
-                f"retries: {e}")
+                f"retries: {e}")