### What this PR does / why we need it?
* **Unify execution paths:** Consolidates the quantized and
non-quantized execution paths into a single `fused_experts` function,
removing duplicated logic and making the control flow clearer and easier
to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic
quantization inside the unified MoE kernel. Communication routines are
updated to correctly handle dynamic quantization scales for activations.
* **Weight pre-processing:** Pre-transposes the `w13` and `w2` weight
matrices (as implemented in PR #2025) so that quantized and
non-quantized models follow the same code path for the MoE gating,
up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective
communication pattern. For large token counts on modern hardware,
`all-to-all` is more efficient than the previous `all-gather` strategy.
However, `all-to-all` cannot be captured and replayed in graph mode: it
issues multiple D2H operations that force synchronization and raise
errors during graph capture, so it is only used on the fallback
`compiled_graph_for_general_shape` path.
* **Dynamic communication selection:** The model runner now selects the
optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at
runtime based on the token count and the Ascend SoC version (see the
sketch after this list).
* **Limitation:** `all-gather` is not yet supported for quantized
models, so further work is still needed for A2.
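
To make the runtime selection concrete, here is a minimal sketch of how a runner could pick the communication method from the token count and SoC generation. The helper name `select_moe_comm_method`, the `soc_is_a3` flag, and the `512`-token threshold are illustrative assumptions, not the exact code introduced by this PR:

```python
def select_moe_comm_method(num_tokens: int, soc_is_a3: bool,
                           capturable_graph: bool) -> str:
    """Illustrative sketch; names and the threshold are placeholders."""
    # MC2 is preferred on newer (A3-class) SoCs for moderate batch sizes.
    if soc_is_a3 and num_tokens <= 512:
        return "mc2"
    # all-to-all performs D2H copies that force synchronization, so it is
    # only used on the non-captured (general-shape) execution path.
    if not capturable_graph:
        return "alltoall"
    # all-gather remains the default when graphs are captured and replayed.
    return "allgather"
```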
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No further test cases needed.
- vLLM version: v0.10.1.1
- vLLM main:
d660c98c1b
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Optional

import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, UnquantizedFusedMoEMethod)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
                                                      AlltoAllCommImpl,
                                                      MC2CommImpl)
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import fused_experts_moge
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
    setup_token_dispatchers
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p

original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__


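# `fused_experts` is the unified expert computation path: it covers both the
# unquantized (fp16/bf16) case and the W8A8/W4A8 quantized case. The active
# communication method (taken from the forward context) permutes tokens to
# their experts, the gate/up and down projections run as grouped matmuls on
# the pre-transposed w13/w2 weights, and for quantized weights the SwiGLU step
# dequantizes and re-quantizes activations with per-token dynamic scales.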
def fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        use_int8_w8a8: bool = False,
        use_int4_w4a8: bool = False,
        global_num_experts: Optional[int] = None,
        expert_map: Optional[torch.Tensor] = None,
        w1_scale: Optional[torch.Tensor] = None,
        w2_scale: Optional[torch.Tensor] = None,
        w1_scale_bias: Optional[torch.Tensor] = None,
        w2_scale_bias: Optional[torch.Tensor] = None,
        # For TorchAir graph
        is_torchair: bool = False,
        # For Cube/Vector parallel
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        # For load balance
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
) -> torch.Tensor:
    # Check constraints
    assert hidden_states.shape[1] == w1.shape[1], (
        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    if use_int8_w8a8 or use_int4_w4a8:
        assert w1_scale is not None and w2_scale is not None, \
            "INT8/INT4 quantization requires weight scales."

        w1_scale = w1_scale.to(torch.float32)
        down_scale = [w2_scale]
        down_output_dtype = w2_scale.dtype
    else:
        down_scale = None
        down_output_dtype = None

    moe_comm_method = get_forward_context().moe_comm_method
    assert moe_comm_method is not None, "Missing communication context"

    num_experts = w1.shape[0]

    permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts,
        use_int8_w8a8 or use_int4_w4a8)

    gate_up_output = torch_npu.npu_grouped_matmul(
        x=[permuted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=torch.int32 if use_int8_w8a8 else None,
    )[0]

    if use_int8_w8a8 or use_int4_w4a8:
        # Dequantize the int32 gate/up output, apply SwiGLU, and re-quantize
        # the activation per token for the quantized down projection.
        activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
            x=gate_up_output,
            weight_scale=w1_scale,
            activation_scale=dynamic_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=expert_tokens,
            activate_left=True,
            quant_mode=1,
        )
        activated_output_scale = [activated_output_scale]
    else:
        activated_output = torch_npu.npu_swiglu(gate_up_output)
        activated_output_scale = None

    down_output = torch_npu.npu_grouped_matmul(
        x=[activated_output],
        weight=[w2],
        scale=down_scale,
        per_token_scale=activated_output_scale,
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=down_output_dtype,
    )[0]

    moe_comm_method.unpermute(down_output, hidden_states)

    return hidden_states


def unquantized_fused_moe_init_func(self, *args, **kwargs):
    original_unquantized_fused_moe_init_func(self, *args, **kwargs)

    # NOTE: Currently, self.use_aclgraph is only read in
    # UnquantizedFusedMoEMethod.forward_oot (see ops/fused_moe.py:568) to work
    # around torch.randint_like not being supported. Once torch.randint_like
    # is supported or that code path is removed, this flag can be removed.
    vllm_config = get_current_vllm_config()
    ascend_config = get_ascend_config()
    if ascend_config.torchair_graph_config.enabled:
        self.use_aclgraph = False
    else:
        self.use_aclgraph = (vllm_config.compilation_config.level
                             == CompilationLevel.PIECEWISE
                             and not vllm_config.model_config.enforce_eager)


def forward_oot(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

    topk_weights, topk_ids, _ = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts)

    if topk_ids.shape[1] < top_k or is_310p():
        assert global_num_experts is not None
        return fused_experts_moge(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            moe_parallel_config=self.moe.moe_parallel_config,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)

    return fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
    )


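# Weight post-processing for the unquantized path: w13 and w2 are padded if
# needed and pre-transposed so the grouped matmuls in `fused_experts` can
# consume them directly; on non-310P SoCs they are also cast to the
# FRACTAL_NZ ACL format.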
def process_weights_after_loading(self, layer):
    super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
    w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
        1, 2).contiguous()
    layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

    w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
        1, 2).contiguous()
    layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

    if not is_310p():
        layer.w13_weight.data = torch_npu.npu_format_cast(
            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w2_weight.data = torch_npu.npu_format_cast(
            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)


class AscendFusedMoE(FusedMoE):

    def __init__(
        self,
        num_experts,
        top_k,
        hidden_size,
        intermediate_size,
        params_dtype=None,
        reduce_results=False,
        renormalize=True,
        use_grouped_topk=False,
        num_expert_group=None,
        topk_group=None,
        quant_config=None,
        tp_size=None,
        ep_size=None,
        dp_size=None,
        prefix="",
        custom_routing_function=None,
        scoring_func="softmax",
        e_score_correction_bias=None,
        apply_router_weight_on_input=False,
        activation="silu",
        enable_eplb=False,
        num_redundant_experts=0,
        has_bias=False,
    ):
        super().__init__(
            num_experts,
            top_k,
            hidden_size,
            intermediate_size,
            params_dtype,
            reduce_results,
            renormalize,
            use_grouped_topk,
            num_expert_group,
            topk_group,
            quant_config,
            tp_size,
            ep_size,
            dp_size,
            prefix,
            custom_routing_function,
            scoring_func,
            e_score_correction_bias,
            apply_router_weight_on_input,
            activation,
            enable_eplb,
            num_redundant_experts,
            has_bias,
        )

        setup_token_dispatchers(self.moe_config.ep_size,
                                top_k=self.top_k,
                                num_experts=self.global_num_experts,
                                num_local_experts=self.local_num_experts)

        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()

        for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
            setattr(
                self, method.__name__.lower(),
                method(moe_config=self.moe_config))  # type: ignore[abstract]

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        assert self.quant_method is not None

        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name

        # TODO: Can we refactor this logic into the model runner?
        # TODO: Adjust this logic to differentiate between A2 and A3; we check
        # ep_size here since MC2 only supports ep_size >= 16 on A3 for now.
        if self.moe_config.ep_size < 16:
            moe_comm_method_name = "allgathercommimpl"

        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

        hidden_states, router_logits = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states, router_logits=router_logits)

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            enable_eplb=self.enable_eplb,
            expert_load_view=self.expert_load_view,
            logical_to_physical_map=self.logical_to_physical_map,
            logical_replica_count=self.logical_replica_count,
        )

        final_hidden_states = forward_context.moe_comm_method.finalize(
            hidden_states=final_hidden_states,
            reduce_results=self.reduce_results)

        return final_hidden_states


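# Replace the upstream UnquantizedFusedMoEMethod entry points with the
# NPU-specific implementations defined above.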
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot