[1/N][Refactor] Refactor code to adapt with vllm main (#3612)

### What this PR does / why we need it?
This is the step 1 of refactoring code to adapt with vllm main, and this
pr aligned with
17c540a993

1. refactor deepseek to the latest code arch as of
17c540a993
 
2. bunches of fixes due to vllm changes
- Fix `AscendScheduler` `__post_init__`, caused by
https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import,caused by
https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by
https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` improt, caused by
https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` improt, caused by
https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by
https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by
https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by
https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by
https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by
https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by
https://github.com/vllm-project/vllm/pull/25807
- Fix structured ouput break introduced by
https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing test.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
This commit is contained in:
Mengqing Cao
2025-10-24 16:55:08 +08:00
committed by GitHub
parent ec9ec78b53
commit cea0755b07
47 changed files with 1189 additions and 493 deletions

View File

@@ -42,6 +42,14 @@ else:
from vllm.attention.layer import MLAAttention
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
if vllm_version_is("0.11.0"):
from vllm.attention import Attention
from vllm.model_executor.layers.mla import \
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
else:
from vllm.attention.layer import MLAAttention
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
# TODO(whx): adapt v0.11.0 and DSA
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
@@ -107,22 +115,20 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
)
else:
self.mla_attn = MLAAttention(
num_heads=self.num_heads,
num_heads=num_heads,
scale=scale,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
kv_b_proj=mla_modules.kv_b_proj,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
kv_b_proj=mla_modules.kv_b_proj,
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
# extra args
qk_head_dim=self.qk_head_dim,
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,

View File

@@ -24,18 +24,29 @@ from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.attention import Attention
from vllm.model_executor.layers.mla import \
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
else:
from vllm.attention.layer import MLAAttention
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
@dataclass
class AscendSFAModules:
q_a_proj: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
@@ -44,73 +55,103 @@ class AscendSFAModules:
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module
is_sparse: bool
fused_qkv_a_proj: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
topk_indices_buffer: Optional[torch.Tensor]
class AscendSparseFlashAttention(MultiHeadLatentAttention):
class AscendSparseFlashAttention(MultiHeadLatentAttentionWrapper):
def __init__(
self,
hidden_size: int,
enable_shared_expert_dp: bool,
debug_layer_idx: int,
first_k_dense_replace: int,
tp_size: int,
sfa_modules: AscendSFAModules,
num_local_heads: int,
scaling: float,
layers: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
q_lora_rank: Optional[int],
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.enable_shared_expert_dp = enable_shared_expert_dp
self.debug_layer_idx = debug_layer_idx
self.first_k_dense_replace = first_k_dense_replace
self.tp_size = tp_size
self.num_local_heads = num_local_heads
self.layers = layers
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_head_dim
self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.scaling = scale
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
hf_config = get_current_vllm_config().model_config.hf_config
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=sfa_modules.rotary_emb,
q_a_proj=sfa_modules.q_a_proj,
q_a_layernorm=sfa_modules.q_a_layernorm,
q_proj=sfa_modules.q_proj,
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
kv_a_layernorm=sfa_modules.kv_a_layernorm,
kv_b_proj=sfa_modules.kv_b_proj,
o_proj=sfa_modules.o_proj,
indexer=sfa_modules.indexer)
if vllm_version_is("0.11.0"):
self.sfa_attn = Attention(
num_heads=num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
indexer=self.indexer,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
qk_head_dim=self.qk_head_dim,
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
kv_b_proj=mla_modules.kv_b_proj,
o_proj=mla_modules.o_proj,
)
else:
self.sfa_attn = MLAAttention(
num_heads=num_heads,
scale=scale,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
kv_b_proj=mla_modules.kv_b_proj,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context: