2025-07-28 14:06:20 +08:00
|
|
|
import math
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from enum import Enum
|
2025-12-23 08:49:52 +08:00
|
|
|
from typing import Any, Optional
|
2025-07-28 14:06:20 +08:00
|
|
|
|
|
|
|
|
import torch
|
2025-08-20 09:01:04 +08:00
|
|
|
from vllm.config import CUDAGraphMode, VllmConfig
|
2025-12-16 17:44:04 +08:00
|
|
|
from vllm.distributed import (get_dp_group, get_ep_group,
|
|
|
|
|
get_tensor_model_parallel_world_size)
|
2025-08-20 09:01:04 +08:00
|
|
|
from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
|
|
|
|
set_forward_context)
|
2025-07-28 14:06:20 +08:00
|
|
|
|
2025-08-14 09:33:39 +08:00
|
|
|
import vllm_ascend.envs as envs_ascend
|
2025-12-16 17:44:04 +08:00
|
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
|
|
|
from vllm_ascend.utils import (AscendDeviceType, enable_sp, flashcomm2_enable,
|
|
|
|
|
get_ascend_device_type, has_layer_idx,
|
2025-11-10 11:01:45 +08:00
|
|
|
is_moe_model)
|
2025-07-28 14:06:20 +08:00
|
|
|
|
|
|
|
|
|
2025-11-24 17:32:37 +08:00
|
|
|
class MoECommType(Enum):
    """Communication strategy used by the fused-MoE layers.

    One member is chosen per forward pass by ``select_moe_comm_method`` and
    stored on the forward context as ``moe_comm_type``.
    """

    # Chosen when expert parallel is disabled, and as the A2 fallback.
    ALLGATHER = 0
    # Chosen when the token count fits the MC2 token capacity.
    MC2 = 1
    # A3 fallback when the token count exceeds the MC2 token capacity.
    ALLTOALL = 2
    # A3 only; gated by VLLM_ASCEND_ENABLE_FUSED_MC2 and quantization.
    FUSED_MC2 = 3
|
2025-11-24 17:32:37 +08:00
|
|
|
|
|
|
|
|
|
2025-07-28 14:06:20 +08:00
|
|
|
@contextmanager
def set_ascend_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: int = 0,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        in_profile_run: bool = False,
        num_actual_tokens: Optional[int] = None,
        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        model_instance: Optional[torch.nn.Module] = None,
        is_mtp_model: bool = False):
    """Enter vLLM's forward context and attach Ascend-specific fields to it.

    Wraps ``vllm.forward_context.set_forward_context`` with the core
    arguments, then populates extra attributes on the resulting forward
    context: the MoE communication type/method, sequence-parallel and
    flashcomm-v2 flags with their padding, layer index and MLP-prefetch
    flags, the DP-wide max token count, and the MC2 mask.

    Args:
        attn_metadata: Attention metadata, forwarded to
            ``set_forward_context``.
        vllm_config: Runtime configuration.
        virtual_engine: Forwarded unchanged.
        num_tokens: Number of tokens in this step. Drives MoE comm-method
            selection, sequence-parallel enabling and padding.
        num_tokens_across_dp: Forwarded unchanged.
        in_profile_run: Stored as ``forward_context.in_profile_run``.
        num_actual_tokens: Number of real (non-padding) tokens; marks the
            ``True`` prefix of the MC2 mask. Defaults to ``num_tokens``.
        aclgraph_runtime_mode: Forwarded as ``cudagraph_runtime_mode``.
        batch_descriptor: Forwarded unchanged.
        model_instance: The running model. When ``has_layer_idx`` accepts
            it, ``model_instance.model.start_layer`` seeds
            ``forward_context.layer_idx``.
        is_mtp_model: Whether this pass runs the MTP model. Passed to
            ``select_moe_comm_method`` and stored on the context.

    Yields:
        None. Consumers read the fields via ``get_forward_context()``.
    """
    with set_forward_context(
            attn_metadata,
            vllm_config,
            virtual_engine=virtual_engine,
            num_tokens=num_tokens,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=aclgraph_runtime_mode,
            batch_descriptor=batch_descriptor,
    ):
        forward_context = get_forward_context()

        # --- MoE communication method ---------------------------------
        # Function-local import kept from the original (presumably to avoid
        # an import cycle -- confirm before hoisting).
        from vllm_ascend.ops.fused_moe.moe_comm_method import (
            get_moe_comm_method)
        moe_comm_type = select_moe_comm_method(num_tokens, vllm_config,
                                               is_mtp_model)
        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        tp_world_size = get_tensor_model_parallel_world_size()

        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        # --- Sequence parallelism / flashcomm ---------------------------
        # For dense models, 1000 is the batch-size concurrency threshold for
        # enabling flashcomm_v1 / sequence parallelism. It is an empirical
        # value: above it the performance benefit is maximized, below it
        # performance may degrade due to the switch of communication method.
        if is_moe_model(vllm_config):
            sp_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        else:
            sp_enabled = enable_sp(vllm_config) and \
                num_tokens is not None and num_tokens > 1000
            mmrs_fusion = True
        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.sp_enabled = sp_enabled
        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
        forward_context.flashcomm_v2_enabled = flashcomm2_enable(
        ) and tp_world_size > 1 and num_tokens is not None

        if (forward_context.sp_enabled
                or forward_context.flashcomm_v2_enabled):
            # Tokens needed to round num_tokens up to a multiple of TP size.
            # May be overwritten below by the DP-wide value.
            pad_size = (tp_world_size -
                        (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        # set this for rope forward_oot using
        forward_context.is_first_layer = True

        # --- Layer index and MLP weight prefetch ------------------------
        # layer_idx enables optimizations that depend on it; only models
        # exposing the required attributes get a non-None value.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        # TODO(rjg-lyh): refactor mlp weight prefetch method
        # set for mlp weight prefetch (small batches only: < 500 tokens)
        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
            forward_context.layer_idx is not None and \
            num_tokens is not None and num_tokens < 500
        if prefetch_mlp_enabled:
            forward_context.prefetch_mlp_gate_up_proj = False
            forward_context.prefetch_mlp_down_proj = False
        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
        forward_context.model_instance = model_instance
        forward_context.is_mtp_model = is_mtp_model

        # NOTE(review): this fallback runs after forward_context.num_tokens,
        # sp_enabled and the MoE comm method were derived from the caller's
        # num_tokens, so those still see None here -- confirm intended.
        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

        # --- Data-parallel padding --------------------------------------
        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = \
                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if (forward_context.sp_enabled
                    or forward_context.flashcomm_v2_enabled):
                # Pad every DP rank to the same TP-aligned length.
                padded_length = (max_tokens_across_dp + tp_world_size -
                                 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp

        # --- MC2 mask ---------------------------------------------------
        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: token num which need to pad to when mc2
            # (math.ceil always yields an int, which the slice below needs).
            forward_context.padded_num_tokens = math.ceil(
                max_tokens_across_dp / tp_world_size) * tp_world_size
            reserved_mc2_mask = get_mc2_mask()
            if reserved_mc2_mask is not None:
                # Slicing gives a view, so this fills the shared reserved
                # buffer in place: True for real tokens, False for padding.
                mc2_mask = reserved_mc2_mask[:forward_context.
                                             padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask

        yield
|
2025-12-12 17:27:09 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level state, populated lazily by the setter functions in this
# module and read back through their getters.

# Maximum token count handled by MC2; written once by
# set_mc2_tokens_capacity.
_mc2_tokens_capacity: Optional[int] = None
# Preallocated boolean MC2 mask (MoE models only); written once by
# set_mc2_mask.
_reserved_mc2_mask: Optional[torch.Tensor] = None
# NOTE(review): _sin/_cos are not read or written in this part of the file;
# presumably cached rotary-embedding tables used elsewhere -- confirm.
_sin: Optional[torch.Tensor] = None
_cos: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_mc2_tokens_capacity(vllm_config, max_num_reqs,
                            uniform_decode_query_len):
    """Compute and cache the MC2 token capacity; the first call wins.

    The token budget is the largest graph-capture size when capture sizes
    are configured. Otherwise it is ``max_num_reqs * uniform_decode_query_len``
    capped at 512. The budget is rounded up to a multiple of the
    tensor-parallel size.

    Args:
        vllm_config: Runtime config; reads ``compilation_config`` and
            ``parallel_config.tensor_parallel_size``.
        max_num_reqs: Maximum number of concurrent requests.
        uniform_decode_query_len: Tokens per request in a uniform decode step.
    """
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is None:
        compilation_config = vllm_config.compilation_config
        if not compilation_config.cudagraph_capture_sizes:
            # NOTE: To save memory, we cap the max number of tokens to 512.
            token_budget = min(max_num_reqs * uniform_decode_query_len, 512)
        else:
            token_budget = compilation_config.max_cudagraph_capture_size
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        # Integer ceiling division, then scale back to a TP-size multiple.
        tokens_per_rank = (token_budget + tp_size - 1) // tp_size
        _mc2_tokens_capacity = tokens_per_rank * tp_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_mc2_tokens_capacity() -> Optional[int]:
    """Return the cached MC2 token capacity.

    The value is ``None`` until ``set_mc2_tokens_capacity`` has run.
    """
    capacity = _mc2_tokens_capacity
    return capacity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_mc2_mask(vllm_config, device):
    """Allocate the reserved MC2 mask once, for MoE models only.

    MoE models get an all-False boolean tensor of length
    ``get_mc2_tokens_capacity()`` on ``device``. Other models leave the mask
    as ``None``, so a later call re-evaluates. Once a mask exists, further
    calls are no-ops.

    Args:
        vllm_config: Runtime config, passed to ``is_moe_model``.
        device: Device on which to allocate the mask.
    """
    global _reserved_mc2_mask
    if _reserved_mc2_mask is not None:
        return
    _reserved_mc2_mask = (torch.zeros(get_mc2_tokens_capacity(),
                                      dtype=torch.bool,
                                      device=device)
                          if is_moe_model(vllm_config) else None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_mc2_mask() -> Optional[torch.Tensor]:
    """Return the reserved MC2 mask tensor.

    The value is ``None`` for non-MoE models, or before ``set_mc2_mask`` has
    run.
    """
    mask = _reserved_mc2_mask
    return mask
|
|
|
|
|
|
|
|
|
|
|
2025-12-16 17:44:04 +08:00
|
|
|
def select_moe_comm_method(num_tokens: int,
                           vllm_config: VllmConfig,
                           is_mtp_model=False) -> Optional[MoECommType]:
    """Select the MoE communication method for the current forward pass.

    Decision order:

    1. Non-MoE models return ``None``.
    2. Without expert parallel, use ``ALLGATHER`` (regardless of device).
    3. A2 with expert parallel: use ``MC2`` when ``num_tokens`` fits the MC2
       token capacity and ``world_size_across_dp / pipeline_parallel_size``
       is at least 16; otherwise use ``ALLGATHER``.
    4. A3 with expert parallel: fused MC2 is only eligible when
       ``VLLM_ASCEND_ENABLE_FUSED_MC2`` is truthy, the MoE quantization is
       ``w8a8_dynamic``, and neither dynamic EPLB nor an expert-map record
       path is configured. On top of that:

       * mode ``1`` additionally requires EP world size <= 16 and a non-MTP
         model, whether or not ``num_tokens`` fits the capacity;
       * mode ``2`` uses fused MC2 only when ``num_tokens`` fits the MC2
         capacity;
       * any other truthy value adds no extra guard.

       Without fused MC2, use ``MC2`` when ``num_tokens`` fits the capacity
       and ``ALLTOALL`` otherwise.

    The MC2 capacity comes from ``get_mc2_tokens_capacity()``. If it is
    still ``None``, the ``<=`` comparisons on the A2/A3 paths raise
    ``TypeError``.

    Args:
        num_tokens (int): The number of tokens in the current batch.
        vllm_config (VllmConfig): Runtime configuration for the model.
        is_mtp_model (bool): Whether the MTP model is running. Only
            consulted in fused-MC2 mode 1.

    Raises:
        ValueError: If expert parallel is enabled and the device is neither
            A2 nor A3.

    Returns:
        MoECommType | None: The selected MoE communication method.
    """
    if not is_moe_model(vllm_config):
        return None

    mc2_tokens_capacity = get_mc2_tokens_capacity()
    soc_version = get_ascend_device_type()
    hf_config = vllm_config.model_config.hf_config
    # Prefer the MoE-specific quantization tag, then the model-wide one.
    quant_type = getattr(hf_config, 'moe_quantize',
                         getattr(hf_config, 'quantize', None))
    parallel_config = vllm_config.parallel_config

    if not parallel_config.enable_expert_parallel:
        return MoECommType.ALLGATHER

    if soc_version == AscendDeviceType.A2:
        # Short-circuit: the ratio is only computed when tokens fit.
        use_mc2 = (num_tokens <= mc2_tokens_capacity
                   and parallel_config.world_size_across_dp /
                   parallel_config.pipeline_parallel_size >= 16)
        return MoECommType.MC2 if use_mc2 else MoECommType.ALLGATHER

    if soc_version == AscendDeviceType.A3:
        ascend_config = get_ascend_config()
        dynamic_eplb = (ascend_config.dynamic_eplb
                        or ascend_config.expert_map_record_path)
        fused_mc2_mode = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
        # TODO: drop dynamic_eplb guard when dispatch_gmm_combine_decode supports tensor list inputs
        # TODO: add guard for dispatch_gmm_combine_decode when mtp uses float while moe uses w8a8
        fused_eligible = (fused_mc2_mode and quant_type == "w8a8_dynamic"
                          and (not dynamic_eplb))
        fits_mc2 = num_tokens <= mc2_tokens_capacity

        if fused_mc2_mode == 1:
            use_fused = (fused_eligible
                         and get_ep_group().world_size <= 16
                         and (not is_mtp_model))
        elif fused_mc2_mode == 2 and not fits_mc2:
            use_fused = False
        else:
            use_fused = fused_eligible

        if use_fused:
            return MoECommType.FUSED_MC2
        return MoECommType.MC2 if fits_mc2 else MoECommType.ALLTOALL

    raise ValueError(f"Unsupported soc_version: {soc_version}")
|