[Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)

### What this PR does / why we need it?
Adapt deepseek-v3.2 to vllm 0.11.0, removing the useless patch.

The final goal is to remove all the patches and align the code arch to
vllm, thus we need to do the following work in next prs.
TODO:
- [x] remove patch on attention spec
- [ ] refactor the kvcache creation logic

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with existing test.
2. Test pass with deepseek-v3.2-exp


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-10-15 17:48:58 +08:00
committed by GitHub
parent 099255e933
commit 8abe517870
20 changed files with 143 additions and 262 deletions

View File

@@ -31,6 +31,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
@@ -47,7 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
@@ -56,10 +58,11 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (PPMissingLayer,
is_pp_missing_parameter,
maybe_prefix)
from vllm.model_executor.models.utils import (
PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
@@ -69,6 +72,53 @@ from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
@support_torch_compile
class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Rewrite this init func mainly for removing cuda-hard code
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device=current_platform.device_type)
else:
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def __init__(
@@ -270,6 +320,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.indexer = None
mla_modules = AscendMLAModules(
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
@@ -281,6 +332,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=None,
is_sparse=hasattr(config, "index_topk"),
)
self.mla_attn = MultiHeadLatentAttention(
@@ -499,7 +552,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
ascend_config = get_ascend_config()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
@@ -515,7 +567,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
if ascend_config.use_sfa:
if hasattr(model_config.hf_config, "index_topk"):
attn_cls = CustomDeepseekV2SFAAttention
else:
attn_cls = CustomDeepseekV2MLAAttention
@@ -590,8 +642,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
"kv_a_proj_with_mqa",
]
self.model = DeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,