2025-07-28 14:06:20 +08:00
|
|
|
import math
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from enum import Enum
|
2025-10-09 20:38:39 +08:00
|
|
|
from typing import TYPE_CHECKING, Any, Optional
|
2025-07-28 14:06:20 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-08-20 09:01:04 +08:00
|
|
|
from vllm.config import CUDAGraphMode, VllmConfig
|
2025-08-12 21:10:20 +08:00
|
|
|
from vllm.distributed import (get_dp_group, get_ep_group,
|
|
|
|
|
get_tensor_model_parallel_world_size)
|
2025-08-20 09:01:04 +08:00
|
|
|
from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
|
|
|
|
set_forward_context)
|
2025-07-28 14:06:20 +08:00
|
|
|
|
2025-08-14 09:33:39 +08:00
|
|
|
import vllm_ascend.envs as envs_ascend
|
2025-10-31 22:14:26 +08:00
|
|
|
from vllm_ascend.utils import enable_sp, has_layer_idx, is_moe_model
|
2025-07-28 14:06:20 +08:00
|
|
|
|
2025-10-09 20:38:39 +08:00
|
|
|
# `WeightPrefetchMethod` is only needed for type annotations; importing it at
# runtime would pull in the ops package (and its NPU dependencies) eagerly, so
# outside of type checking we bind the name to None as a lightweight stub.
if TYPE_CHECKING:
    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
    WeightPrefetchMethod = None
|
|
|
|
|
|
2025-07-28 14:06:20 +08:00
|
|
|
|
|
|
|
|
class FusedMoEState(Enum):
    """Communication/dispatch strategy used by the fused-MoE layers.

    The member chosen depends on expert-parallel world size, whether the
    step contains prefill tokens, and the model (see
    ``_get_fused_moe_state``).
    """
    AllGather = 0  # all-gather tokens across EP ranks, then compute locally
    All2All = 1  # all-to-all token exchange between expert-parallel ranks
    MC2 = 2  # fused dispatch/combine path (requires large EP, decode-only)
    AllGatherEP = 3  # allgather-EP fused kernel; DeepSeek v3/r1 only
    NaiveMulticast = 4  # single-rank EP prefill fallback
    All2AllSeq = 5  # sequence-sliced all-to-all variant
|
2025-07-28 14:06:20 +08:00
|
|
|
|
|
|
|
|
|
2025-09-22 19:12:58 +08:00
|
|
|
class MoECommType(Enum):
    """Requested MoE communication method, resolved to a concrete
    implementation via ``get_moe_comm_method`` inside
    ``set_ascend_forward_context``."""
    ALLGATHER = 0  # gather tokens across ranks before expert compute
    MC2 = 1  # fused dispatch/combine communication
    ALLTOALL = 2  # all-to-all token exchange
    NAIVE_MULTICAST = 3  # simple broadcast-style fallback
|
|
|
|
|
|
|
|
|
|
|
2025-07-28 14:06:20 +08:00
|
|
|
# TODO(zzzzwwjj): add soc_version to choose branch
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
                         is_deepseek_v3_r1: bool):
    """Pick the fused-MoE communication state for the current step.

    Args:
        ep_size: expert-parallel world size (1 means EP is disabled).
        with_prefill: whether this step contains prefill tokens.
        is_deepseek_v3_r1: whether the model is DeepSeek v3/r1 (256 routed
            experts), the only model supported by the AllGatherEP kernel.

    Returns:
        The ``FusedMoEState`` member to use.
    """
    # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing
    # called by allgather ep only supports deepseek v3/r1.
    use_allgather_ep = (envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP
                        and ep_size > 1 and is_deepseek_v3_r1)
    if use_allgather_ep:
        return FusedMoEState.AllGatherEP
    if ep_size == 1:
        # No expert parallelism: multicast during prefill, gather otherwise.
        return (FusedMoEState.NaiveMulticast
                if with_prefill else FusedMoEState.AllGather)
    # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    return FusedMoEState.MC2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextmanager
def set_ascend_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        with_prefill: bool = True,
        in_profile_run: bool = False,
        reserved_mc2_mask: Optional[torch.Tensor] = None,
        moe_comm_type: Optional[MoECommType] = None,
        num_actual_tokens: Optional[int] = None,
        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        prefetch_stream: torch.npu.Stream = None,
        model_instance: torch.nn.Module = None,
        weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.

    We add some additional param into forward_context on top of vLLM's
    ``set_forward_context``: MoE communication method, sequence-parallel
    padding info, MC2 masks, weight-prefetch state, and fusion flags.

    Args:
        attn_metadata: attention metadata for this step (may be None).
        vllm_config: the full vLLM config; model/parallel/quant sub-configs
            are read below.
        virtual_engine: virtual engine index forwarded to vLLM.
        num_tokens: scheduled token count for this step; if None it is
            recovered from ``attn_metadata.num_actual_tokens`` later.
        num_tokens_across_dp: per-DP-rank token counts forwarded to vLLM.
        with_prefill: whether this step contains prefill tokens.
        in_profile_run: True during the memory-profiling dummy run.
        reserved_mc2_mask: preallocated bool tensor sliced to build the
            per-step MC2 token mask.
        moe_comm_type: requested MoE communication method.
        num_actual_tokens: real (unpadded) token count; defaults to
            ``num_tokens``.
        aclgraph_runtime_mode: ACL graph capture/replay mode, mapped onto
            vLLM's ``cudagraph_runtime_mode``.
        batch_descriptor: batch descriptor forwarded to vLLM.
        prefetch_stream: NPU stream used for MLP weight prefetching.
        model_instance: the loaded model module (used for layer-idx and
            fusion bookkeeping).
        weight_prefetch_method: weight-prefetch strategy object, stored on
            the context for layers to consult.
    """
    with set_forward_context(
            attn_metadata,
            vllm_config,
            virtual_engine=virtual_engine,
            num_tokens=num_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=aclgraph_runtime_mode,
            batch_descriptor=batch_descriptor,
    ):
        forward_context = get_forward_context()

        # Imported lazily to avoid a circular import with the ops package.
        from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        forward_context.with_prefill = with_prefill
        tp_world_size = get_tensor_model_parallel_world_size()
        # EP size collapses to 1 when expert parallelism is disabled so the
        # fused-MoE state selection below takes the single-rank branches.
        ep_size = (get_ep_group().world_size if
                   vllm_config.parallel_config.enable_expert_parallel else 1)

        # DeepSeek v3/r1 is identified by its 256 routed experts; it is the
        # only model the AllGatherEP fused kernel supports.
        is_deepseek_v3_r1 = hasattr(
            vllm_config.model_config.hf_config, 'n_routed_experts'
        ) and vllm_config.model_config.hf_config.n_routed_experts == 256
        fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
                                               is_deepseek_v3_r1)
        forward_context.fused_moe_state = fused_moe_state
        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        # set for sequence parallelism, 1000 is the batch size concurrency threshold for enabling the flashcomm_v1 or sequence_parallelism feature.
        # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
        # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
        # the performance may degrade due to the switching of communication methods.
        mmrs_fusion = True
        if is_moe_model(vllm_config):
            # MoE models: no token-count threshold for SP, and the
            # matmul-reduce-scatter fusion is disabled.
            sp_enabled = enable_sp(vllm_config) and \
                tp_world_size > 1 and num_tokens is not None
            mmrs_fusion = False
        else:
            sp_enabled = enable_sp(vllm_config) and \
                tp_world_size > 1 and \
                num_tokens is not None and num_tokens > 1000
        forward_context.mmrs_fusion = mmrs_fusion

        if sp_enabled:
            # sp_enabled guarantees num_tokens is not None here. Pad up to
            # a multiple of TP size so the sequence splits evenly.
            pad_size = (tp_world_size -
                        (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size
        forward_context.sp_enabled = sp_enabled
        forward_context.num_tokens = num_tokens

        # set this for rope forward_oot using
        forward_context.is_first_layer = True

        # set layer_idx to enable optimization features that depend on this information.
        # This is only applicable to models that contain these necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        # TODO(rjg-lyh): refactor mlp weight prefetch method
        # set for mlp weight prefetch
        # NOTE(review): 500 looks like an empirical small-batch cutoff for
        # when prefetching pays off — confirm against the prefetch op docs.
        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
            forward_context.layer_idx is not None and \
            num_tokens is not None and num_tokens < 500
        if prefetch_mlp_enabled:
            forward_context.prefetch_stream = prefetch_stream
            forward_context.model_instance = model_instance
            forward_context.prefetch_mlp_gate_up_proj = False
            forward_context.prefetch_mlp_down_proj = False
        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
        forward_context.model_instance = model_instance
        forward_context.weight_prefetch_method = weight_prefetch_method

        # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
        # It will be improved later by implementing operator fusion through the FX graph.
        #
        # set for addrmsnorm+quant fusion.
        # this optim now just support dense models due to the specific operators used.
        # Once the necessary conditions are met, support for MOE models will also be added.
        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"]
        addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
            vllm_config.model_config.hf_config.model_type in model_type_scope and \
            forward_context.layer_idx is not None
        if addrmsnorm_quant_fusion_enabled:
            forward_context.model_instance = model_instance
            forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
            # fusion_linear tracks which linear comes next in the fusion
            # pattern; it starts at gate_up/qkv depending on layer position.
            forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
            if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
                forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
        forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled

        # Recover the token count from attention metadata when the caller
        # did not pass it explicitly.
        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            # Under data parallelism all ranks must pad to the largest
            # per-rank token count so collectives line up.
            max_tokens_across_dp = \
                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if sp_enabled:
                # Round the DP-wide maximum up to a multiple of TP size and
                # recompute pad_size against it (overrides the SP-only
                # pad_size set earlier).
                padded_length = (max_tokens_across_dp + tp_world_size -
                                 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: token num which need to pad to when mc2
            forward_context.padded_num_tokens = math.ceil(
                max_tokens_across_dp / tp_world_size) * tp_world_size

            if reserved_mc2_mask is not None:
                # Mark real tokens True and padding False in the reusable
                # MC2 mask buffer (mutates reserved_mc2_mask in place).
                mc2_mask = reserved_mc2_mask[:forward_context.
                                             padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask

        try:
            yield
        finally:
            pass
|