Add DeepSeek V3.2 support (#3270)
### What this PR does / why we need it? This PR added the initial DeepSeek V3.2 support with [vLLM v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0) (not released yet). We will complete vLLM adaptation as soon as possible. This feature will be ready in recent 1-2 days. Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223 . ### Does this PR introduce _any_ user-facing change? Yes! ### How was this patch tested? CI passed and Run deepseek doc soon. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wxsIcey <1790571317@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -493,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
assert self.local_agent_metadata is not None
|
||||
kv_cache_dtype = first_kv_cache.dtype
|
||||
self.use_mla: bool = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1)
|
||||
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||
first_kv_cache_tuple) == 2
|
||||
self.use_sfa: bool = len(first_kv_cache_tuple) == 3
|
||||
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
||||
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
|
||||
# MHA case. [2 (k and v), num_blocks, ...]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
@@ -540,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
elif self.use_sfa:
|
||||
cache_k_normed_addr_list = []
|
||||
cache_k_pe_addr_list = []
|
||||
cache_k_idx_addr_list = []
|
||||
k_normed = None
|
||||
k_pe = None
|
||||
k_idx = None
|
||||
for cache_or_caches in kv_caches.values():
|
||||
assert len(cache_or_caches) > 1
|
||||
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
|
||||
1], cache_or_caches[2]
|
||||
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||
cache_k_idx_addr_list.append(k_idx.data_ptr())
|
||||
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
|
||||
cache_k_idx_addr_list)
|
||||
|
||||
cache_desc_k_normed = CacheDesc(
|
||||
len(self.cache_addr[0]), [*k_normed.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_pe = CacheDesc(
|
||||
len(self.cache_addr[1]), [*k_pe.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_desc_k_idx = CacheDesc(
|
||||
len(self.cache_addr[2]), [*k_idx.shape],
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=0)
|
||||
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=1)
|
||||
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
|
||||
self.local_agent_metadata.cluster_id),
|
||||
model_id=2)
|
||||
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
|
||||
cache_desc_k_idx)
|
||||
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
|
||||
cache_key_k_idx)
|
||||
try:
|
||||
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||
cache_k_idx = self.cache_manager.register_blocks_cache(
|
||||
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
|
||||
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
|
||||
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||
)
|
||||
else:
|
||||
for cache_or_caches in kv_caches.values():
|
||||
for cache in cache_or_caches:
|
||||
@@ -826,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
elif self.use_sfa:
|
||||
remote_cache_key_k_normed = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=0)
|
||||
remote_cache_key_k_pe = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=1)
|
||||
remote_cache_key_k_idx = BlocksCacheKey(
|
||||
cluster_id=remote_cluster_id, model_id=2)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
try:
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_normed,
|
||||
self.cache[0], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_pe,
|
||||
self.cache[1], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
self.cache_manager.pull_blocks(
|
||||
remote_cache_key_k_idx,
|
||||
self.cache[2], # type: ignore[has-type]
|
||||
remote_block_ids,
|
||||
local_block_ids)
|
||||
except (TypeError, ValueError):
|
||||
raise RuntimeError(
|
||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||
)
|
||||
except LLMException:
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
else:
|
||||
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
||||
logger.info("Try pull blocks from remote server")
|
||||
|
||||
Reference in New Issue
Block a user