[Perf] move quant before allgather in Allgather EP (#3420)

### What this PR does / why we need it?
move quant before allgather in Allgather EP, rely on
https://github.com/vllm-project/vllm-ascend/pull/3334

Deepseek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:
| Seq length | Mean TTFT (ms) main | Mean TTFT (ms)  this PR |
|----------|----------|----------|
| 4k   |  375.21  | 364.99   |
| 16k  | 1465.23   | 1421.75  |
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-11-04 16:49:58 +08:00
committed by GitHub
parent 44b58b8665
commit bedf223771
10 changed files with 160 additions and 66 deletions

View File

@@ -289,7 +289,7 @@ class AscendFusedMoE(FusedMoE):
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
setup_moe_comm_method(self.moe_config)
setup_moe_comm_method(self.moe_config, self.quant_method)
def update_expert_map(self, new_expert_map):
self.expert_map = new_expert_map
@@ -336,11 +336,17 @@ class AscendFusedMoE(FusedMoE):
replace_allreduce=forward_context.sp_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp)
if isinstance(hidden_states, tuple):
hidden_states, pertoken_scale = hidden_states
else:
pertoken_scale = None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
pertoken_scale=pertoken_scale,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,

View File

@@ -27,10 +27,14 @@ from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast, QuantType)
from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2, TokenDispatcherWithMoge)
from vllm_ascend.quantization.w4a8_dynamic import \
AscendW4A8DynamicFusedMoEMethod
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -40,25 +44,43 @@ def get_moe_comm_method(
return _MoECommMethods.get(moe_comm_type, None)
def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
def setup_moe_comm_method(moe_config, quant_method):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
moe_config, quant_method)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
moe_config, quant_method)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
moe_config)
moe_config, quant_method)
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig):
def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type
self.moe_config = moe_config
self.token_dispatcher = self._get_token_dispatcher()
self.quant_type = self._get_quant_type(quant_method)
self.with_quant = self.quant_type != QuantType.NONE
self.prepare_finalize = self._get_prepare_finalize()
def _get_quant_type(self, quant_method) -> QuantType:
if not hasattr(quant_method,
"quant_method") or quant_method.quant_method is None:
return QuantType.NONE
method = quant_method.quant_method
if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
return QuantType.W8A8
elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
return QuantType.W4A8
else:
return QuantType.NONE
def prepare(
self,
hidden_states: torch.Tensor,
@@ -90,8 +112,6 @@ class MoECommMethod(ABC):
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
@@ -109,10 +129,11 @@ class MoECommMethod(ABC):
global_redundant_expert_num: int = 0,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None):
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):
# Check constraints
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
torch.float32, torch.float16, torch.bfloat16, torch.int8
]
moe_comm_method = get_forward_context().moe_comm_method
@@ -130,28 +151,29 @@ class MoECommMethod(ABC):
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8,
dynamic_eplb=dynamic_eplb)
with_quant=self.with_quant,
dynamic_eplb=dynamic_eplb,
pertoken_scale=pertoken_scale)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=topk_scales,
with_quant=use_int8_w8a8
or use_int4_w4a8,
fusion=use_int8_w8a8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
mlp_output = unified_apply_mlp(
hidden_states=permuted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=topk_scales,
with_quant=self.with_quant,
fusion=self.quant_type == QuantType.W8A8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
final_hidden_states = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=context_metadata)
@@ -204,7 +226,8 @@ class AllGatherCommImpl(MoECommMethod):
num_local_experts=self.moe_config.num_local_experts)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAllGather(self.moe_config)
return PrepareAndFinalizeWithAllGather(self.moe_config,
self.quant_type)
class MC2CommImpl(MoECommMethod):
@@ -221,7 +244,7 @@ class MC2CommImpl(MoECommMethod):
return TokenDispatcherWithMC2()
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithMC2(self.moe_config)
return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
class AlltoAllCommImpl(MoECommMethod):
@@ -241,7 +264,7 @@ class AlltoAllCommImpl(MoECommMethod):
num_local_experts=self.moe_config.num_local_experts)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)
return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
class NaiveMulticastCommImpl(MoECommMethod):
@@ -270,4 +293,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
num_local_experts=self.moe_config.num_local_experts)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
self.quant_type)

View File

@@ -72,8 +72,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# Dispose the original unquantized hidden states
# to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
quantized_hidden_states = None
else:
pertoken_scale = dynamic_scale
quantized_hidden_states = hidden_states
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
@@ -92,6 +94,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)
@@ -104,6 +108,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
@@ -148,6 +154,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
@@ -161,6 +169,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(

View File

@@ -15,11 +15,13 @@
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_dp_group, get_tensor_model_parallel_rank,
@@ -30,6 +32,12 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import enable_sp
class QuantType(Enum):
NONE = 0
W8A8 = 1
W4A8 = 2
class PrepareAndFinalize(ABC):
"""
Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
@@ -42,8 +50,11 @@ class PrepareAndFinalize(ABC):
sizes, ranks, and communication settings.
"""
def __init__(self, moe_config: FusedMoEConfig):
def __init__(self,
moe_config: FusedMoEConfig,
quant_type: QuantType = QuantType.NONE):
self.moe_config = moe_config
self.quant_type = quant_type
@abstractmethod
def prepare(
@@ -103,8 +114,10 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
def __init__(self,
moe_config: FusedMoEConfig,
quant_type: QuantType = QuantType.NONE):
super().__init__(moe_config, quant_type)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
@@ -195,8 +208,10 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
def __init__(self, moe_config: FusedMoEConfig):
super().__init__(moe_config)
def __init__(self,
moe_config: FusedMoEConfig,
quant_type: QuantType = QuantType.NONE):
super().__init__(moe_config, quant_type)
self._restore_tp_across_dp()
def _restore_tp_across_dp(self):
@@ -316,11 +331,20 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
pertoken_scale = None
if self.quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
pertoken_scale, True, True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None
return hidden_states, router_logits, None, None
def _prepare_with_dp_group(

View File

@@ -57,20 +57,23 @@ class MoETokenDispatcher(ABC):
return get_ep_group().world_size
@abstractmethod
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None,
):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
@@ -170,7 +173,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False):
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
self.with_quant = with_quant
# Apply log2phy if needed
@@ -339,7 +343,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False):
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
@@ -370,12 +375,14 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=pertoken_scale,
active_num=num_tokens * self.top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 if self.with_quant else -1,
quant_mode=1
if self.with_quant and pertoken_scale is None else -1,
))
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
@@ -430,7 +437,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False):
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -518,7 +526,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False):
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape