[Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)
### What this PR does / why we need it? Adapt deepseek-v3.2 to vllm 0.11.0, removing the now-unnecessary patch. The final goal is to remove all the patches and align the code architecture with vllm, so we need to do the following work in the next PRs. TODO: - [x] remove patch on attention spec - [ ] refactor the kv-cache creation logic ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? 1. CI passed with the existing tests. 2. Tests pass with deepseek-v3.2-exp - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -309,13 +309,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
# Set up Attention
|
||||
self.attn_backend = get_attn_backend(
|
||||
0,
|
||||
self.dtype,
|
||||
None,
|
||||
self.block_size,
|
||||
use_mla=self.model_config.use_mla,
|
||||
use_sfa=self.ascend_config.use_sfa)
|
||||
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
self.attn_backend = get_attn_backend(0,
|
||||
self.dtype,
|
||||
None,
|
||||
self.block_size,
|
||||
use_mla=self.model_config.use_mla,
|
||||
use_sparse=self.use_sparse)
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
self.attn_mask_builder = AttentionMaskBuilder(
|
||||
self.scheduler_config.max_num_batched_tokens, self.dtype,
|
||||
@@ -871,7 +872,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||
# Chunk Prefill situation.
|
||||
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
|
||||
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
else:
|
||||
@@ -1507,7 +1508,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
model=self.get_model(),
|
||||
**extra_attn_metadata_args)
|
||||
|
||||
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||
if self.vllm_config.model_config.use_mla or self.use_sparse:
|
||||
attn_metadata_i.num_input_tokens = num_input_tokens
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
@@ -2655,7 +2656,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
if self.ascend_config.is_deepseek_sfa:
|
||||
if self.use_sparse:
|
||||
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
||||
kv_cache_config)
|
||||
elif self.model_config.is_deepseek_mla:
|
||||
@@ -2699,7 +2700,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
elif hasattr(
|
||||
attn_backend, "get_supported_block_size"
|
||||
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
|
||||
) and not self.model_config.is_deepseek_mla and not self.use_sparse:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
@@ -3245,7 +3246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
use_sfa = self.ascend_config.use_sfa
|
||||
use_sparse = self.use_sparse
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
@@ -3267,7 +3268,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if use_mla and not use_sfa:
|
||||
if use_mla and not use_sparse:
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
|
||||
@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
@@ -88,7 +88,11 @@ class NPUWorker(WorkerBase):
|
||||
# init ascend config and soc version
|
||||
init_ascend_config(vllm_config)
|
||||
init_ascend_soc_version()
|
||||
if get_ascend_config().use_sfa:
|
||||
use_sparse = False
|
||||
if vllm_config.model_config is not None:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
if use_sparse:
|
||||
# Direct import instead of using try_register_lib to ensure proper error handling when
|
||||
# custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
|
||||
# yapf: disable
|
||||
|
||||
Reference in New Issue
Block a user