【Feature】refactor npu_modelrunner for profile_run (#4993)

### What this PR does / why we need it?
1. Refactor `npu_model_runner` for `profile_run`.
2. Move `_select_moe_comm_method` to `ascend_forward_context`.
3. Delete `_init_model_kwargs` in `npu_model_runner`.
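
In rough terms, the call-site change looks like the sketch below. It is illustrative only: argument lists are abridged, the old `_select_moe_comm_method` call shape is a guess, and the point is just that after this PR the comm method is derived inside `set_ascend_forward_context` from `num_tokens` and `vllm_config` instead of being passed in.

```python
# Before (abridged, hypothetical call shape): npu_model_runner picked the MoE
# comm method itself and passed it into the forward context.
moe_comm_type = self._select_moe_comm_method(num_tokens)  # old private helper
with set_ascend_forward_context(attn_metadata, vllm_config,
                                num_tokens=num_tokens,
                                moe_comm_type=moe_comm_type):
    ...

# After (abridged): the forward context calls select_moe_comm_method() itself,
# so profile_run and normal runs share the same code path.
with set_ascend_forward_context(attn_metadata, vllm_config,
                                num_tokens=num_tokens,
                                in_profile_run=True):
    ...
```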

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
N/A
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
parent af64087732
commit 4ed2951400
6 changed files with 127 additions and 205 deletions


@@ -5,12 +5,15 @@ from typing import TYPE_CHECKING, Any, Optional
 import torch
 from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                   set_forward_context)
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
+                               get_ascend_device_type, has_layer_idx,
+                               is_moe_model)
 if TYPE_CHECKING:
@@ -31,11 +34,10 @@ def set_ascend_forward_context(
         attn_metadata: Any,
         vllm_config: VllmConfig,
         virtual_engine: int = 0,
-        num_tokens: Optional[int] = None,
+        num_tokens: int = 0,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         with_prefill: bool = True,
         in_profile_run: bool = False,
-        moe_comm_type: Optional[MoECommType] = None,
         num_actual_tokens: Optional[int] = None,
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
@@ -60,6 +62,11 @@ def set_ascend_forward_context(
         from vllm_ascend.ops.fused_moe.moe_comm_method import \
             get_moe_comm_method
+        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
+        # TODO: remove this after moe_comm_type selection logic is finalized
+        if in_profile_run and is_mtp_model:
+            moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
+                             == MoECommType.FUSED_ALLTOALL else moe_comm_type)
         forward_context.moe_comm_type = moe_comm_type
         forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
@@ -231,3 +238,69 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
 def get_cos_and_sin():
     return _cos, _sin
+
+
+def select_moe_comm_method(num_tokens: int,
+                           vllm_config: VllmConfig) -> Optional[MoECommType]:
+    """Select the MoE communication method for the current batch.
+
+    1. If expert parallel is not enabled, we use all-gather, since MC2 and
+       all-to-all are designed for expert parallelism.
+    2. If expert parallel is enabled, we need to consider the SoC version and
+       the number of tokens. This is based on the observation that all-gather
+       is more efficient than all-to-all when running on A2.
+       a. For A2, we choose between MC2 and all-gather.
+       b. For A3, we choose between MC2 and all-to-all.
+       In both cases, we use MC2 when the number of tokens is smaller than
+       its capacity threshold.
+
+    Args:
+        num_tokens (int): The number of tokens in the current batch.
+        vllm_config (VllmConfig): The vLLM config of the current run.
+
+    Raises:
+        ValueError: If the SoC version is unsupported.
+
+    Returns:
+        MoECommType: The selected MoE communication method.
+    """
+    if not is_moe_model(vllm_config):
+        return None
+    mc2_tokens_capacity = get_mc2_tokens_capacity()
+    soc_version = get_ascend_device_type()
+    quant_type = getattr(
+        vllm_config.model_config.hf_config, 'moe_quantize',
+        getattr(vllm_config.model_config.hf_config, 'quantize', None))
+    model_type = vllm_config.model_config.hf_config.model_type
+    if not vllm_config.parallel_config.enable_expert_parallel:
+        moe_comm_type = MoECommType.ALLGATHER
+    elif soc_version in {AscendDeviceType._910B}:
+        if (num_tokens <= mc2_tokens_capacity
+                and vllm_config.parallel_config.world_size_across_dp /
+                vllm_config.parallel_config.pipeline_parallel_size >= 16):
+            moe_comm_type = MoECommType.MC2
+        else:
+            # Currently, w4a8_dynamic does not support allgatherep
+            if quant_type == "w4a8_dynamic":
+                moe_comm_type = MoECommType.ALLTOALL
+            else:
+                moe_comm_type = MoECommType.ALLGATHER
+    elif soc_version in {AscendDeviceType._910_93}:
+        ascend_config = get_ascend_config()
+        dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
+        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
+        fused_all2all_enable = (quant_type == "w8a8_dynamic"
+                                and get_ep_group().world_size <= 16
+                                and not dynamic_eplb)
+        moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
+                         else MoECommType.FUSED_ALLTOALL
+                         if fused_all2all_enable else MoECommType.ALLTOALL)
+    else:
+        raise ValueError(f"Unsupported soc_version: {soc_version}")
+    moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
+                     == MoECommType.FUSED_ALLTOALL else moe_comm_type)
+    # PanguProMoE only supports allgather
+    if model_type == "PanguProMoE":
+        moe_comm_type = MoECommType.ALLGATHER
+    return moe_comm_type
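
For intuition, the sketch below mirrors the decision table this function implements. It is illustrative only: the helper name and the string SoC labels are made up, the quantization checks, the DP world-size guard on A2, dynamic EPLB, and the PanguProMoE override are omitted, and the real threshold comes from `get_mc2_tokens_capacity()`.

```python
# Simplified mirror of select_moe_comm_method, for illustration only.
def sketch_moe_comm_choice(enable_ep: bool, soc: str, num_tokens: int,
                           mc2_capacity: int) -> str:
    if not enable_ep:
        return "ALLGATHER"  # MC2 / all-to-all assume expert parallelism
    if soc == "A2":  # AscendDeviceType._910B
        return "MC2" if num_tokens <= mc2_capacity else "ALLGATHER"
    if soc == "A3":  # AscendDeviceType._910_93
        return "MC2" if num_tokens <= mc2_capacity else "ALLTOALL"
    raise ValueError(f"Unsupported soc_version: {soc}")

# e.g. with a capacity of 512 tokens:
#   sketch_moe_comm_choice(True, "A2", 128, 512)  -> "MC2"
#   sketch_moe_comm_choice(True, "A3", 4096, 512) -> "ALLTOALL"
```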