【Feature】refactor npu_modelrunner for profile_run (#4993)
### What this PR does / why we need it?
(1)refactor npu_model_runner for profile_run
(2) move _select_moe_comm_method to ascend_forward_context
(3) delete _init_model_kwargs in npu_model_runner
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Na
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
This commit is contained in:
@@ -5,12 +5,15 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import (enable_sp, flashcomm2_enable, has_layer_idx,
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
|
||||
get_ascend_device_type, has_layer_idx,
|
||||
is_moe_model)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -31,11 +34,10 @@ def set_ascend_forward_context(
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
virtual_engine: int = 0,
|
||||
num_tokens: Optional[int] = None,
|
||||
num_tokens: int = 0,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
with_prefill: bool = True,
|
||||
in_profile_run: bool = False,
|
||||
moe_comm_type: Optional[MoECommType] = None,
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: Optional[BatchDescriptor] = None,
|
||||
@@ -60,6 +62,11 @@ def set_ascend_forward_context(
|
||||
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import \
|
||||
get_moe_comm_method
|
||||
moe_comm_type = select_moe_comm_method(num_tokens, vllm_config)
|
||||
# TODO: remove this after moe_comm_type selection logic is finalized
|
||||
if in_profile_run and is_mtp_model:
|
||||
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
|
||||
== MoECommType.FUSED_ALLTOALL else moe_comm_type)
|
||||
forward_context.moe_comm_type = moe_comm_type
|
||||
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
|
||||
|
||||
@@ -231,3 +238,69 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
|
||||
|
||||
def get_cos_and_sin():
|
||||
return _cos, _sin
|
||||
|
||||
|
||||
def select_moe_comm_method(num_tokens: int,
|
||||
vllm_config: VllmConfig) -> Optional[MoECommType]:
|
||||
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
|
||||
are designed for expert parallelism.
|
||||
2. If expert parallel is enabled, we need to consider the soc version and the
|
||||
number of tokens. This is based on the observation that all-gather is more
|
||||
efficient than all-to-all when running on A2.
|
||||
|
||||
a. For A2, we choose from MC2 and all-gather.
|
||||
|
||||
b. For A3, we choose from MC2 and all-to-all.
|
||||
|
||||
In both cases, we use MC2 when the number of tokens is smaller than
|
||||
a its capacity threshold.
|
||||
|
||||
Args:
|
||||
num_tokens (int): The number of tokens in the current batch.
|
||||
|
||||
Raises:
|
||||
ValueError: If the soc version is unsupported.
|
||||
|
||||
Returns:
|
||||
MoECommType: The selected MoE communication method.
|
||||
"""
|
||||
if not is_moe_model(vllm_config):
|
||||
return None
|
||||
mc2_tokens_capacity = get_mc2_tokens_capacity()
|
||||
soc_version = get_ascend_device_type()
|
||||
quant_type = getattr(
|
||||
vllm_config.model_config.hf_config, 'moe_quantize',
|
||||
getattr(vllm_config.model_config.hf_config, 'quantize', None))
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if not vllm_config.parallel_config.enable_expert_parallel:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
elif soc_version in {AscendDeviceType._910B}:
|
||||
if (num_tokens <= mc2_tokens_capacity
|
||||
and vllm_config.parallel_config.world_size_across_dp /
|
||||
vllm_config.parallel_config.pipeline_parallel_size >= 16):
|
||||
moe_comm_type = MoECommType.MC2
|
||||
else:
|
||||
# Currently, w4a8_dynamic does not support allgatherep
|
||||
if quant_type == "w4a8_dynamic":
|
||||
moe_comm_type = MoECommType.ALLTOALL
|
||||
else:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
elif soc_version in {AscendDeviceType._910_93}:
|
||||
ascend_config = get_ascend_config()
|
||||
dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
|
||||
# TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
|
||||
fused_all2all_enable = quant_type == "w8a8_dynamic" and get_ep_group(
|
||||
).world_size <= 16 and (not dynamic_eplb)
|
||||
moe_comm_type = (MoECommType.MC2 if num_tokens <= mc2_tokens_capacity
|
||||
else MoECommType.FUSED_ALLTOALL
|
||||
if fused_all2all_enable else MoECommType.ALLTOALL)
|
||||
else:
|
||||
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||
moe_comm_type = (MoECommType.ALLTOALL if moe_comm_type
|
||||
== MoECommType.FUSED_ALLTOALL else moe_comm_type)
|
||||
# PanguProMoE only supports allgather
|
||||
if model_type == "PanguProMoE":
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
return moe_comm_type
|
||||
|
||||
Reference in New Issue
Block a user