### What this PR does / why we need it?
Fix the multi-DP padding logic for eager mode. The previous logic caused a rank-0 load imbalance for kimi-k2.5-w4a8 because all of the padding tokens were routed to rank 0. The fix also applies to other models running with multiple DP ranks.
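To illustrate the padding behavior involved, here is a minimal sketch of the arithmetic used in `set_ascend_forward_context` (full file below); the helper name `compute_dp_padding` is only for illustration and does not exist in the code:

```python
def compute_dp_padding(num_tokens: int, max_tokens_across_dp: int, tp_world_size: int) -> tuple[int, int]:
    # Every DP rank pads to the same padded_length: the DP-wide max token count,
    # rounded up to a multiple of tp_world_size.
    padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
    pad_size = padded_length - num_tokens  # this rank's local padding
    return padded_length, pad_size


# e.g. tp_world_size=4 and DP ranks holding 10 and 7 tokens:
# both ranks pad to padded_length=12, with pad_size 2 and 5 respectively.
```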
- Before
HBM usage:
<img width="2229" height="733" alt="image"
src="https://github.com/user-attachments/assets/50479b6d-cfd0-4206-8e80-974024652997"
/>
Performance:
```shell
Concurrency NumPrompts QPS TTFT_Avg TTFT_P50 TPOT_Avg TPOT_P50 TPOT_P90
============ ============ ============ ============ ============ ============ ============ ============
1 15 0.0179 1667.7803 1673.3437 35.2973 35.2775 35.3784
32 480 0.4725 2764.8027 1905.2137 40.8030 40.6978 41.0179
64 960 0.7820 4123.7096 3485.6153 48.0461 48.1598 48.2971
100 1500 1.0852 6216.7988 5714.0082 52.9323 53.0613 54.6304
108 1620 1.1040 6277.4892 5798.7425 56.3862 56.9224 57.2901
116 1740 1.1680 6563.3293 6039.5659 56.9894 57.4027 57.5786
128 1920 1.2555 7822.5551 7604.1662 57.7660 58.1768 58.2717
192 2880 1.4314 9212.1953 9131.3461 58.9905 59.1683 59.2791
256 3840 1.4480 9028.0812 8913.7937 59.0092 59.2385 59.3516
```
- After
HBM usage:
<img width="2246" height="1005" alt="image"
src="https://github.com/user-attachments/assets/d0936481-5a58-4bc5-a6f1-b92735d47885"
/>
Performance:
```shell
Concurrency NumPrompts QPS TTFT_Avg TTFT_P50 TPOT_Avg TPOT_P50 TPOT_P90
============ ============ ============ ============ ============ ============ ============ ============
1 15 0.0181 601.4171 600.9774 35.6270 35.6254 35.6480
32 480 0.4455 720.8782 724.2889 45.4250 45.4755 45.6318
64 960 0.8445 729.6209 728.2149 47.0464 47.0896 47.1985
100 1500 1.2601 723.4834 724.6673 48.3108 48.3844 48.5355
108 1620 1.3409 727.1509 720.6772 48.8962 48.9409 49.0489
116 1740 1.4080 679.9799 677.6119 49.1253 49.1983 49.3087
128 1920 1.4155 680.6284 674.9436 49.2193 49.2450 49.3763
192 2880 1.4422 684.6577 676.7833 49.2059 49.2264 49.3229
256 3840 1.4558 685.2462 678.1709 49.2191 49.2351 49.3419
```
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
Co-authored-by: fny-coder <985619145@qq.com>
Updated `vllm_ascend/ascend_forward_context.py` (341 lines, 14 KiB):
```python
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any

import torch
import vllm.envs as envs_vllm
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parallel_world_size
from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context

import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import (
    AscendDeviceType,
    enable_sp,
    flashcomm2_enable,
    get_ascend_device_type,
    has_layer_idx,
    is_drafter_moe_model,
    is_moe_model,
    speculative_enable_dispatch_gmm_combine_decode,
    vllm_version_is,
)


class MoECommType(Enum):
    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    FUSED_MC2 = 3


@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: int = 0,
    num_tokens_across_dp: torch.Tensor | None = None,
    in_profile_run: bool = False,
    num_actual_tokens: int | None = None,
    aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    model_instance: torch.nn.Module = None,
    is_draft_model=False,
    skip_compiled: bool = False,
    max_tokens_across_pcp: int = 0,
    draft_attn_metadatas=None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    We add some additional param into forward_context.
    """
    forward_context_kwargs = {
        "attn_metadata": attn_metadata,
        "vllm_config": vllm_config,
        "num_tokens": num_tokens,
        "num_tokens_across_dp": num_tokens_across_dp,
        "cudagraph_runtime_mode": aclgraph_runtime_mode,
        "batch_descriptor": batch_descriptor,
        "skip_compiled": skip_compiled,
    }
    if vllm_version_is("0.18.0"):
        forward_context_kwargs["virtual_engine"] = virtual_engine

    with set_forward_context(**forward_context_kwargs):
        forward_context = get_forward_context()
        forward_context.draft_attn_metadatas = draft_attn_metadatas

        from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

        max_num_tokens = int(num_tokens_across_dp.max().item()) if num_tokens_across_dp is not None else num_tokens
        moe_comm_type = select_moe_comm_method(max_num_tokens, vllm_config, is_draft_model)

        forward_context.moe_comm_type = moe_comm_type
        forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)

        tp_world_size = get_tensor_model_parallel_world_size()

        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        # TODO: remove it when torch_npu.npu_mm_reduce_scatter_base supports tp_size >= 16.
        mmrs_fusion = tp_world_size <= 8

        # set for sequence parallelism, 1000 is the batch size concurrency threshold
        # for enabling the flashcomm_v1 or sequence_parallelism feature.
        # Currently, it is an empirical value. In normal scenarios, if the concurrency
        # exceeds this threshold, the performance benefits can be maximized.
        # Conversely, if the concurrency is below the threshold,
        # the performance may degrade due to the switching of communication methods.

        # main model and drafter model may have different architecture
        is_context_moe_model = is_drafter_moe_model(vllm_config) if is_draft_model else is_moe_model(vllm_config)
        if is_context_moe_model:
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None
            mmrs_fusion = False
        elif is_draft_model:
            # TODO: for dense drafter, `sp` is redundant and is not compatible with `dp` and `graph`.
            # Disable it to avoid more problems.
            flash_comm_v1_enabled = False
        else:
            flash_comm_v1_enabled = enable_sp(vllm_config) and num_tokens is not None and num_tokens > 1000
        forward_context.mmrs_fusion = mmrs_fusion
        forward_context.num_tokens = num_tokens
        forward_context.flash_comm_v1_enabled = flash_comm_v1_enabled
        # TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
        forward_context.flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None

        forward_context.pad_size = 0
        if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
            pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
            forward_context.pad_size = pad_size

        # set this for rope forward_oot using
        forward_context.is_first_layer = True

        # set layer_idx to enable optimization features that depend on this information.
        # This is only applicable to models that contain these necessary attributes.
        forward_context.layer_idx = None
        if has_layer_idx(model_instance):
            forward_context.layer_idx = model_instance.model.start_layer

        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
        forward_context.model_instance = model_instance
        forward_context.is_draft_model = is_draft_model

        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens

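        # Multi-DP padding (the focus of this PR): every DP rank pads to the same
        # padded_length, i.e. the DP-wide max token count rounded up to a multiple
        # of tp_world_size.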
        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
            if forward_context.flash_comm_v1_enabled or forward_context.flashcomm_v2_enabled:
                padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size
                pad_size = padded_length - num_tokens
                forward_context.padded_length = padded_length
                forward_context.pad_size = pad_size
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp
        forward_context.max_tokens_across_pcp = max_tokens_across_pcp

        if num_tokens is not None:
            if num_actual_tokens is None:
                num_actual_tokens = num_tokens
            # NOTE: token num which need to pad to when mc2
            forward_context.padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
            reserved_mc2_mask = get_mc2_mask()
            if reserved_mc2_mask is not None:
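                # mark the real tokens True and the padded tail False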
                mc2_mask = reserved_mc2_mask[: forward_context.padded_num_tokens]
                mc2_mask[:num_actual_tokens] = True
                mc2_mask[num_actual_tokens:] = False
                forward_context.mc2_mask = mc2_mask
        try:
            yield
        finally:
            pass


_mc2_tokens_capacity: int | None = None
_reserved_mc2_mask: torch.Tensor | None = None


def set_mc2_tokens_capacity(vllm_config, max_num_reqs, uniform_decode_query_len):
    global _mc2_tokens_capacity
    if _mc2_tokens_capacity is not None:
        return
    if vllm_config.compilation_config.cudagraph_capture_sizes:
        max_num_tokens = vllm_config.compilation_config.max_cudagraph_capture_size
    else:
        # NOTE: To save memory, we cap the max number of tokens to 512.
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    tp_size = vllm_config.parallel_config.tensor_parallel_size
    # Use integer arithmetic for ceiling division.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    _mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
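    # e.g. max_num_tokens=500, tp_size=8 -> num_tokens_per_tp_rank=63 -> capacity=504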


def get_mc2_tokens_capacity():
    return _mc2_tokens_capacity


def set_mc2_mask(vllm_config, device):
    global _reserved_mc2_mask
    if _reserved_mc2_mask is not None:
        return
    if is_moe_model(vllm_config):
        _reserved_mc2_mask = torch.zeros(get_mc2_tokens_capacity(), dtype=torch.bool, device=device)
    else:
        _reserved_mc2_mask = None


def get_mc2_mask():
    return _reserved_mc2_mask


def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_model=False) -> MoECommType | None:
    """Select the MoE communication method according to parallel settings,
    device generation, token count, and quantization.

    1. Non-MoE models return `None`.
    2. Without expert parallel, fall back to all-gather.
    3. On A2 with expert parallel, pick MC2 when tokens fit the MC2 capacity
       and the DP size is large enough; otherwise use all-gather.
    4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
       quantization with small EP size, no dynamic_eplb, and not in MTP
       mode; otherwise use MC2 within capacity or all-to-all.
    5. On 310P, always use all-gather.

    Args:
        num_tokens (int): The number of tokens in the current batch.
        vllm_config (VllmConfig): Runtime configuration for the model.
        is_draft_model (bool): Whether the model runs in MTP mode (disables fused MC2).

    Raises:
        ValueError: If the soc version is unsupported.

    Returns:
        MoECommType | None: The selected MoE communication method.
    """
    if not is_moe_model(vllm_config):
        return None
    mc2_tokens_capacity = get_mc2_tokens_capacity()
    soc_version = get_ascend_device_type()
    quant_type = getattr(
        vllm_config.model_config.hf_text_config,
        "moe_quantize",
        getattr(vllm_config.model_config.hf_text_config, "quantize", None),
    )

    if not vllm_config.parallel_config.enable_expert_parallel or get_ep_group().world_size == 1:
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A2}:
        num_experts = vllm_config.model_config.get_num_experts()
        ep_world_size = (
            vllm_config.parallel_config.world_size_across_dp // vllm_config.parallel_config.pipeline_parallel_size
        )
        num_experts_per_device = num_experts // ep_world_size
        if num_experts_per_device <= 24 and ep_world_size >= 16 and num_tokens <= mc2_tokens_capacity:
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLGATHER

    elif soc_version in {AscendDeviceType.A3}:
        # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
        # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
        fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2
        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
        if num_tokens <= mc2_tokens_capacity:
            fused_decode_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_decode_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_decode_enable = (
                    fused_mc2_enable
                    and speculative_enable_dispatch_gmm_combine_decode(vllm_config)
                    and quant_type == "w8a8_dynamic"
                )
            moe_comm_type = MoECommType.FUSED_MC2 if fused_decode_enable else MoECommType.MC2
        else:
            fused_prefill_enable = fused_mc2_enable
            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
                fused_prefill_enable = fused_mc2_enable and dispatch_ffn_combine_enable
            elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
                fused_prefill_enable = False
            moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
    elif soc_version in {AscendDeviceType._310P}:
        moe_comm_type = MoECommType.ALLGATHER
    elif soc_version in {AscendDeviceType.A5}:
        if num_tokens <= mc2_tokens_capacity and vllm_config.parallel_config.world_size_across_dp > 1:
            moe_comm_type = MoECommType.MC2
        else:
            moe_comm_type = MoECommType.ALLTOALL
    else:
        raise ValueError(f"Unsupported soc_version: {soc_version}")
    return moe_comm_type


class _ExtraForwardContextProxy:
    """Unified forward-context access for v1/v2 model runners."""

    extra_attrs = (
        "capturing",
        "moe_comm_type",
        "moe_comm_method",
        "mmrs_fusion",
        "num_tokens",
        "flash_comm_v1_enabled",
        "flashcomm_v2_enabled",
        "pad_size",
        "padded_length",
        "num_tokens_across_dp",
        "mc2_mask",
        "is_draft_model",
        "prefetch_mlp_gate_up_proj",
        "prefetch_mlp_down_proj",
        "model_instance",
        "layer_idx",
        "max_tokens_across_dp",
        "max_tokens_across_pcp",
        "num_accept_tokens",
        "in_profile_run",
        "padded_num_tokens",
    )

    def check_extra_attr(self, name: str):
        if name not in self.extra_attrs:
            raise AttributeError(
                f"{name} is not extra forward context attribute, "
                "please get/set it from vllm's _forward_context directly."
            )

    @staticmethod
    def _ctx():
        return get_forward_context()

    def __getattr__(self, name: str) -> Any:
        self.check_extra_attr(name)
        ctx = self._ctx()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            return ctx.additional_kwargs[name]
        return getattr(ctx, name)

    def __setattr__(self, name: str, value: Any) -> None:
        self.check_extra_attr(name)
        ctx = self._ctx()
        if envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
            ctx.additional_kwargs[name] = value
        else:
            setattr(ctx, name, value)


# usage: from vllm_ascend.ascend_forward_context import _EXTRA_CTX
_EXTRA_CTX = _ExtraForwardContextProxy()
```
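For reference, a minimal usage sketch of the `_EXTRA_CTX` proxy (it assumes a forward context has already been set via `set_ascend_forward_context`, and only attribute names listed in `extra_attrs` are allowed):

```python
from vllm_ascend.ascend_forward_context import _EXTRA_CTX

# Inside an active `with set_ascend_forward_context(...)` block, reads and writes
# are routed either to vLLM's forward context directly (v1 model runner) or to its
# additional_kwargs dict (v2 model runner).
_EXTRA_CTX.prefetch_mlp_gate_up_proj = True
if _EXTRA_CTX.flash_comm_v1_enabled:
    pad_size = _EXTRA_CTX.pad_size
```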