[1/N][Refactor] Refactor code to adapt with vllm main (#3612)

### What this PR does / why we need it?
This is the step 1 of refactoring code to adapt with vllm main, and this
pr aligned with
17c540a993

1. refactor deepseek to the latest code arch as of
17c540a993
 
2. bunches of fixes due to vllm changes
- Fix `AscendScheduler` `__post_init__`, caused by
https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import,caused by
https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by
https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` improt, caused by
https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` improt, caused by
https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by
https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by
https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by
https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by
https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by
https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by
https://github.com/vllm-project/vllm/pull/25807
- Fix structured ouput break introduced by
https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing test.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
This commit is contained in:
Mengqing Cao
2025-10-24 16:55:08 +08:00
committed by GitHub
parent ec9ec78b53
commit cea0755b07
47 changed files with 1189 additions and 493 deletions

View File

@@ -24,7 +24,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
@@ -56,6 +56,12 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.config import CompilationLevel
else:
from vllm.config import CompilationMode
class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
@@ -298,10 +304,16 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
if vllm_version_is("0.11.0"):
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not vllm_config.model_config.enforce_eager)
else:
self.use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE and
not vllm_config.model_config.enforce_eager)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):

View File

@@ -23,6 +23,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -186,6 +187,7 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
return logits
@support_torch_compile
class TorchairDeepSeekMTP(DeepSeekMTP):
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have

View File

@@ -31,7 +31,7 @@ import torch
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -75,7 +75,12 @@ from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable, vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.attention import Attention
else:
from vllm.attention.layer import MLAAttention
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@@ -561,30 +566,65 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_proj=self.q_proj if self.q_lora_rank is None else None,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
if vllm_version_is("0.11.0"):
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=False,
indexer=None,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj
if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj
if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
decoder_layer=decoder_layer,
)
else:
self.mla_attn = MLAAttention(
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=False,
indexer=None,
# MLA Args
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj
if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj
if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
def forward(
self,
@@ -791,35 +831,65 @@ class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
prefix=f"{prefix}.indexer",
)
self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
indexer=self.indexer,
decoder_layer=decoder_layer,
)
if vllm_version_is("0.11.0"):
self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
indexer=self.indexer,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj
if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj
if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
decoder_layer=decoder_layer,
)
else:
self.sfa_attn = MLAAttention(
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=True,
indexer=self.indexer,
# MLA Args
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj
if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj
if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
def forward(
self,

View File

@@ -54,7 +54,8 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p,
is_hierarchical_communication_enabled)
is_hierarchical_communication_enabled,
vllm_version_is)
def torchair_fused_experts_with_mc2(
@@ -1069,8 +1070,12 @@ class TorchairAscendFusedMoE(FusedMoE):
get_compressed_expert_map(self.expert_map))
else:
# init moe.
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
if vllm_version_is("0.11.0"):
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
else:
self.local_num_experts, self.expert_map, _ = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
# dynamic eplb initializing with not expert_map_path
if self.dynamic_eplb:
self.global_redundant_expert_num = ascend_config.init_redundancy_expert

View File

@@ -350,7 +350,7 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
return output.view(num_tokens, self.hidden_size)
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
return output.view(num_tokens, self.hidden_size).fill_(0)
output = output.view(-1, self.num_heads, self.head_size)

View File

@@ -656,8 +656,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
self.qk_head_dim = kwargs['qk_head_dim']
self.v_head_dim = kwargs['v_head_dim']
self.rotary_emb = kwargs['rotary_emb']
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
'q_b_proj']
self.q_proj = kwargs['q_proj']
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
@@ -1098,7 +1097,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]

View File

@@ -57,6 +57,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.decode_token_per_req))
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
None, None, vllm_config, device)
self.use_sparse = hasattr(self.model_config.hf_config, "index_topk")
register_torchair_model()
torchair_ops_patch()

View File

@@ -839,6 +839,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()),
dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
@@ -951,6 +952,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
hidden_states = self.decoder_layer.input_layernorm(hidden_states)
decode_kq = self.q_a_proj(hidden_states) # q down
decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm
@@ -982,7 +984,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
return output.fill_(0)
if attn_metadata.prefill is not None:
assert attn_metadata.num_decodes is not None and \
@@ -993,10 +995,12 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
hidden_states_prefill = hidden_states
prefill_slot_mapping = attn_metadata.slot_mapping
prefill_kq = self.q_a_proj(hidden_states_prefill) # q down
prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm
prefill_kv_no_split = self.kv_a_proj_with_mqa(
hidden_states_prefill) # c_kv
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
prefill_kv_no_split = get_tp_group().all_gather(
prefill_kv_no_split,
@@ -1110,6 +1114,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
else:
q_len = 1
hidden_states_decode = hidden_states
decode_kq = self.q_a_proj(hidden_states_decode) # q down
decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm
decode_kv_no_split = self.kv_a_proj_with_mqa(