[1/N][Refactor] Refactor code to adapt with vllm main (#3612)

### What this PR does / why we need it?
This is the step 1 of refactoring code to adapt with vllm main, and this
pr aligned with
17c540a993

1. refactor deepseek to the latest code arch as of
17c540a993
 
2. bunches of fixes due to vllm changes
- Fix `AscendScheduler` `__post_init__`, caused by
https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import,caused by
https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by
https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` improt, caused by
https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` improt, caused by
https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by
https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by
https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by
https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by
https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by
https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by
https://github.com/vllm-project/vllm/pull/25807
- Fix structured ouput break introduced by
https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing test.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
This commit is contained in:
Mengqing Cao
2025-10-24 16:55:08 +08:00
committed by GitHub
parent ec9ec78b53
commit cea0755b07
47 changed files with 1189 additions and 493 deletions

View File

@@ -536,6 +536,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
@@ -572,7 +573,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"GemmaRMSNorm": AscendGemmaRMSNorm,
"FusedMoE": AscendFusedMoE,
"SharedFusedMoE": AscendSharedFusedMoE,
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
}
if vllm_config is not None and \
@@ -580,6 +580,13 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \
not version_check():
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
mla_to_register = "MultiHeadLatentAttention" if vllm_version_is(
"0.11.0") else "MultiHeadLatentAttentionWrapper"
if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla:
AscendMLAAttentionWarrper = AscendSparseFlashAttention if hasattr(
vllm_config.model_config.hf_config,
"index_topk") else AscendMultiHeadLatentAttention
REGISTERED_ASCEND_OPS[mla_to_register] = AscendMLAAttentionWarrper
for name, op_cls in REGISTERED_ASCEND_OPS.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
@@ -771,7 +778,7 @@ def is_hierarchical_communication_enabled():
@functools.cache
def version_check():
"""check if torch_npu version >= dev20250919"""
import re
import re # noqa
torch_npu_version = torch_npu.version.__version__
date_pattern = r'dev(\d{8})'