Add DeepSeek V3.2 support (#3270)

### What this PR does / why we need it?

This PR added the initial DeepSeek V3.2 support with [vLLM
v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0)
(not released yet). We will complete vLLM adaptation as soon as
possible. This feature will be ready in recent 1-2 days.

Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223 .

### Does this PR introduce _any_ user-facing change?
Yes!

### How was this patch tested?
CI passed and Run deepseek doc soon.


- vLLM version: v0.11.0rc3
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wxsIcey <1790571317@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
wangxiyuan
2025-09-30 03:25:58 +08:00
committed by GitHub
parent 5503a3142f
commit 81bd6e4c99
27 changed files with 4354 additions and 70 deletions

View File

@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -238,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sfa = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
@@ -349,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
block_len = (self.block_len[k % 2]
if self.use_mla else self.block_len[0])
if self.use_mla:
block_len = (self.block_len[k % 2])
elif self.use_sfa:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
for i, remote_block_id in enumerate(grouped_remote_block_ids):
local_block_ids = grouped_local_block_ids[i]
src = src_layer_base_addr + local_block_ids[0] * block_len
@@ -567,6 +573,7 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.ascend_config = get_ascend_config()
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
logger.info("Initializing Mooncake Scheduler %s", engine_id)
@@ -726,7 +733,7 @@ class MooncakeConnectorScheduler:
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
return self._decode_tp_size
else:
# TODO support mha and gqa
@@ -847,7 +854,9 @@ class MooncakeConnectorWorker:
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -861,6 +870,21 @@ class MooncakeConnectorWorker:
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
elif self.use_sfa:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -871,8 +895,9 @@ class MooncakeConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -884,9 +909,16 @@ class MooncakeConnectorWorker:
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
elif self.use_sfa:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
cache_list = [
cache_or_caches
] if self.use_mla or self.use_sfa else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]