[3/N][Feat][Graph] Support all-to-all and quantized models with ACL Graph (#2614)

### What this PR does / why we need it?
* **Unify execution paths:** Consolidates the quantized and
non-quantized execution paths into a single `fused_experts` function,
removing duplicated logic and making the control flow clearer and easier
to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic
quantization inside the unified MoE kernel. Communication routines are
updated to correctly handle dynamic quantization scales for activations.
* **Weight pre-processing:** Pre-transpose the `w13` and `w2` weight
matrices (as implemented in PR #2025) so that quantized and
non-quantized models follow the same code path for the MoE gating,
up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective
communication pattern. For large token counts on modern hardware,
`all-to-all` is more efficient than the previous `all-gather` strategy.
However, `all-to-all` cannot actually be captured and replayed, because
it performs multiple D2H operations that trigger synchronization and
thus raise an error during graph capture. We only use `all-to-all` when
falling back to `compiled_graph_for_general_shape`.
* **Dynamic communication selection:** The model runner now selects the
optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at
runtime based on token count and the Ascend SoC version.
* **Limitation:** `all-gather` is not yet supported for quantized
models, which means there is still something left to do on A2.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
No further test cases needed.

- vLLM version: v0.10.1.1
- vLLM main:
d660c98c1b

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-30 11:00:35 +08:00
committed by GitHub
parent 91c35d765a
commit d3c93fba5c
7 changed files with 248 additions and 41 deletions

View File

@@ -18,6 +18,7 @@
from typing import Any, Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
@@ -26,12 +27,14 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
MC2CommImpl,
MoECommMethod)
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
from vllm_ascend.ops.fused_moe import fused_experts_moge
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import is_310p
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -52,7 +55,6 @@ def fused_experts(
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
moe_comm_method: Optional[MoECommMethod] = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
@@ -64,9 +66,8 @@ def fused_experts(
global_redundant_expert_num: int = 0,
) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert hidden_states.shape[1] == w1.shape[1], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
@@ -74,31 +75,79 @@ def fused_experts(
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
if (use_int8_w8a8 or use_int4_w4a8):
assert w1_scale is not None and w2_scale is not None, \
"INT8 quantization requires weight scales."
w1_scale = w1_scale.to(torch.float32)
down_scale = [w2_scale]
down_output_dtype = w2_scale.dtype
else:
down_scale = None
down_output_dtype = None
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
num_experts = w1.shape[0]
permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
mlp_output = apply_mlp(
permuted_hidden_states,
w1,
w2,
expert_tokens,
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
hidden_states, topk_ids, topk_weights, expert_map, num_experts,
use_int8_w8a8 or use_int4_w4a8)
gate_up_output = torch_npu.npu_grouped_matmul(
x=[permuted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
)
moe_comm_method.unpermute(mlp_output, hidden_states)
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32 if use_int8_w8a8 else None,
)[0]
if (use_int8_w8a8 or use_int4_w4a8):
activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
x=gate_up_output,
weight_scale=w1_scale,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)
activated_output_scale = [activated_output_scale]
else:
activated_output = torch_npu.npu_swiglu(gate_up_output)
activated_output_scale = None
down_output = torch_npu.npu_grouped_matmul(
x=[activated_output],
weight=[w2],
scale=down_scale,
per_token_scale=activated_output_scale,
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=down_output_dtype,
)[0]
moe_comm_method.unpermute(down_output, hidden_states)
return hidden_states
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
# NOTE: Currently, this self.use_aclgraph is only used in
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
@@ -156,8 +205,6 @@ def forward_oot(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
moe_comm_method = get_forward_context().moe_comm_method
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@@ -166,10 +213,26 @@ def forward_oot(
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
moe_comm_method=moe_comm_method,
)
def process_weights_after_loading(self, layer):
    """Re-lay out the MoE expert weights once they have been loaded.

    The parent implementation runs first. Afterwards ``layer.w13_weight`` and
    ``layer.w2_weight`` are each replaced by a new ``torch.nn.Parameter``
    (``requires_grad=False``) holding a contiguous copy of the weight with
    dimensions 1 and 2 swapped. Unless ``is_310p()`` returns a truthy value,
    the data of both parameters is then cast to ``ACL_FORMAT_FRACTAL_NZ``.

    Args:
        layer: Object exposing ``w13_weight`` and ``w2_weight`` attributes.
    """
    # Explicit two-argument super(): this function is defined at module level
    # and attached to UnquantizedFusedMoEMethod afterwards, so the zero-arg
    # form is not available here.
    super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)

    weight_names = ("w13_weight", "w2_weight")

    # First pass: run each weight through self._maybe_pad_weight, swap
    # dimensions 1 and 2, and store the contiguous result back on the layer.
    for attr in weight_names:
        padded = self._maybe_pad_weight(getattr(layer, attr).data)
        swapped = padded.transpose(1, 2).contiguous()
        setattr(layer, attr, torch.nn.Parameter(swapped, requires_grad=False))

    if is_310p():
        return

    # Second pass: cast both weights to the FRACTAL_NZ format.
    for attr in weight_names:
        param = getattr(layer, attr)
        param.data = torch_npu.npu_format_cast(param.data,
                                               ACL_FORMAT_FRACTAL_NZ)
class AscendFusedMoE(FusedMoE):
def __init__(
@@ -224,12 +287,17 @@ class AscendFusedMoE(FusedMoE):
has_bias,
)
setup_token_dispatchers(self.moe_config.ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_local_experts=self.local_num_experts)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
for method in {AllGatherCommImpl, MC2CommImpl}:
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
@@ -282,4 +350,5 @@ class AscendFusedMoE(FusedMoE):
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot

View File

@@ -230,7 +230,6 @@ def fused_experts_moge(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
w1 = w1.transpose(1, 2)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states],
weight=[w1],
@@ -247,7 +246,6 @@ def fused_experts_moge(
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],