Add DeepSeek V3.2 support (#3270)
### What this PR does / why we need it? This PR added the initial DeepSeek V3.2 support with [vLLM v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0) (not released yet). We will complete vLLM adaptation as soon as possible. This feature will be ready in recent 1-2 days. Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223 . ### Does this PR introduce _any_ user-facing change? Yes! ### How was this patch tested? CI passed and Run deepseek doc soon. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wxsIcey <1790571317@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@@ -238,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
self.block_len = block_len
|
||||
# TODO(jianzs): find a better way to detect MLA.
|
||||
self.use_mla = len(block_len) == 2
|
||||
self.use_sfa = len(block_len) == 3
|
||||
|
||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||
# TODO(jianzs): make this configurable
|
||||
@@ -349,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
src_list, dst_list, length_list = [], [], []
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
|
||||
block_len = (self.block_len[k % 2]
|
||||
if self.use_mla else self.block_len[0])
|
||||
if self.use_mla:
|
||||
block_len = (self.block_len[k % 2])
|
||||
elif self.use_sfa:
|
||||
block_len = (self.block_len[k % 3])
|
||||
else:
|
||||
block_len = (self.block_len[0])
|
||||
for i, remote_block_id in enumerate(grouped_remote_block_ids):
|
||||
local_block_ids = grouped_local_block_ids[i]
|
||||
src = src_layer_base_addr + local_block_ids[0] * block_len
|
||||
@@ -567,6 +573,7 @@ class MooncakeConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
self.vllm_config = vllm_config
|
||||
self.ascend_config = get_ascend_config()
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id = engine_id
|
||||
logger.info("Initializing Mooncake Scheduler %s", engine_id)
|
||||
@@ -726,7 +733,7 @@ class MooncakeConnectorScheduler:
|
||||
assert "tp_size" in decode_parallel_config.keys()
|
||||
self._decode_tp_size = decode_parallel_config["tp_size"]
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||
return self._decode_tp_size
|
||||
else:
|
||||
# TODO support mha and gqa
|
||||
@@ -847,7 +854,9 @@ class MooncakeConnectorWorker:
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
self.use_mla = first_kv_cache_tuple[0].size(
|
||||
-1) != first_kv_cache_tuple[1].size(-1)
|
||||
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||
first_kv_cache_tuple) == 2
|
||||
self.use_sfa = len(first_kv_cache_tuple) == 3
|
||||
if self.use_mla:
|
||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
@@ -861,6 +870,21 @@ class MooncakeConnectorWorker:
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||
elif self.use_sfa:
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 3 # [block_size, latent_dim]
|
||||
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
||||
self.block_len = [
|
||||
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||
first_kv_cache[2].element_size() * math.prod(block_shape_k)
|
||||
]
|
||||
logger.info(
|
||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
||||
self.num_blocks, block_shape_norm, block_shape_pe,
|
||||
block_shape_k)
|
||||
else:
|
||||
# [num_block, block_size, num_head, hidden_dim]
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
@@ -871,8 +895,9 @@ class MooncakeConnectorWorker:
|
||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||
block_shape)
|
||||
|
||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
||||
self.use_mla, first_kv_cache.shape)
|
||||
logger.info(
|
||||
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
|
||||
self.use_mla, self.use_sfa, first_kv_cache.shape)
|
||||
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
@@ -884,9 +909,16 @@ class MooncakeConnectorWorker:
|
||||
region_len = self.num_blocks * self.block_len[i % 2]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
elif self.use_sfa:
|
||||
for i, cache in enumerate(cache_or_caches, 0):
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[i % 3]
|
||||
kv_caches_base_addr.append(base_addr)
|
||||
self._register(base_addr, region_len)
|
||||
else:
|
||||
cache_list = [cache_or_caches
|
||||
] if self.use_mla else cache_or_caches
|
||||
cache_list = [
|
||||
cache_or_caches
|
||||
] if self.use_mla or self.use_sfa else cache_or_caches
|
||||
for cache in cache_list:
|
||||
base_addr = cache.data_ptr()
|
||||
region_len = self.num_blocks * self.block_len[0]
|
||||
|
||||
Reference in New Issue
Block a user