[Perf] move quant before allgather in Allgather EP (#3420)

### What this PR does / why we need it?
Move quantization before the allgather in Allgather EP, so that int8 activations and their per-token scales are gathered instead of bf16 activations. Relies on
https://github.com/vllm-project/vllm-ascend/pull/3334.

DeepSeek R1 W8A8 performance on A2 with
`HCCL_ALGO="level0:NA;level1:pipeline"`:

| Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR |
|----------|----------|----------|
| 4k   |  375.21  | 364.99   |
| 16k  | 1465.23   | 1421.75  |
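
The gain comes from what gets gathered: with W8A8, the hidden states can be dynamically quantized to int8 per token on each rank first, so the EP allgather moves int8 activations plus one fp32 scale per token instead of bf16 activations, and the AllGather token dispatcher can pass the gathered scale straight into `npu_moe_init_routing_v2` (`scale=pertoken_scale`, `quant_mode=-1`) instead of re-quantizing after the collective. Below is a minimal, illustrative sketch of that reordering; `per_token_dynamic_quant` and `fake_all_gather` are placeholders, not vllm-ascend APIs (the real path uses `torch_npu.npu_dynamic_quant` and `maybe_all_gather_and_maybe_unpad`, as in the diffs that follow):

```python
import torch


def per_token_dynamic_quant(x: torch.Tensor):
    """Symmetric per-token int8 quantization: one fp32 scale per token (row)."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale.squeeze(-1).to(torch.float32)


def fake_all_gather(t: torch.Tensor, world_size: int) -> torch.Tensor:
    """Stand-in for the EP allgather collective (just replicates this shard)."""
    return torch.cat([t] * world_size, dim=0)


hidden = torch.randn(4, 8, dtype=torch.bfloat16)  # this rank's tokens

# Before: gather bf16 activations (2 bytes/element), then quantize on every rank.
gathered_bf16 = fake_all_gather(hidden, world_size=2)
q_after, scale_after = per_token_dynamic_quant(gathered_bf16.float())

# After (this PR): quantize first, gather int8 activations (1 byte/element) plus
# one fp32 scale per token -- roughly half the bytes on the wire.
q_before, scale_before = per_token_dynamic_quant(hidden.float())
gathered_int8 = fake_all_gather(q_before, world_size=2)
gathered_scale = fake_all_gather(scale_before, world_size=2)
```
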
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0
- vLLM main: 83f478bb19

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Authored by realliujiaxu on 2025-11-04 16:49:58 +08:00, committed by GitHub
Parent: 44b58b8665
Commit: bedf223771
10 changed files with 160 additions and 66 deletions


@@ -189,6 +189,25 @@ def test_sp_for_qwen3_moe() -> None:
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
+def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None:
+    example_prompts = [
+        "test" * 1001,
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+    with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+                    dtype="auto",
+                    tensor_parallel_size=2,
+                    distributed_executor_backend="mp",
+                    enable_expert_parallel=True,
+                    enforce_eager=True,
+                    quantization="ascend") as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})


@@ -458,6 +458,7 @@ class TestUnifiedApplyMLP(TestBase):
                          dtype=torch.float32))
 
         hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        hidden_states_shape = hidden_states.shape
         w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
         w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
         w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
@@ -486,7 +487,7 @@ class TestUnifiedApplyMLP(TestBase):
         mock_npu_swiglu.assert_called_once()
         mock_npu_dynamic_quant.assert_called_once()
 
-        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
     @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
@@ -568,6 +569,7 @@ class TestUnifiedApplyMLP(TestBase):
                          dtype=torch.float32))
 
         hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
+        hidden_states_shape = hidden_states.shape
         w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
         w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
         w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
@@ -596,7 +598,7 @@ class TestUnifiedApplyMLP(TestBase):
         mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
         self.assertTrue(mock_forward_context.with_quant)
 
-        self.assertEqual(result.shape, hidden_states.shape)
+        self.assertEqual(result.shape, hidden_states_shape)
         self.assertEqual(result.dtype, torch.bfloat16)


@@ -289,7 +289,7 @@ class AscendFusedMoE(FusedMoE):
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
 
-        setup_moe_comm_method(self.moe_config)
+        setup_moe_comm_method(self.moe_config, self.quant_method)
 
     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -336,11 +336,17 @@ class AscendFusedMoE(FusedMoE):
             replace_allreduce=forward_context.sp_enabled,
             enable_shared_expert_dp=self.enable_shared_expert_dp)
 
+        if isinstance(hidden_states, tuple):
+            hidden_states, pertoken_scale = hidden_states
+        else:
+            pertoken_scale = None
+
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
+            pertoken_scale=pertoken_scale,
             top_k=self.top_k,
             renormalize=self.renormalize,
             use_grouped_topk=self.use_grouped_topk,


@@ -27,10 +27,14 @@ from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
 from vllm_ascend.ops.fused_moe.prepare_finalize import (
     PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
-    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast, QuantType)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2, TokenDispatcherWithMoge)
+from vllm_ascend.quantization.w4a8_dynamic import \
+    AscendW4A8DynamicFusedMoEMethod
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -40,25 +44,43 @@ def get_moe_comm_method(
     return _MoECommMethods.get(moe_comm_type, None)
 
 
-def setup_moe_comm_method(moe_config):
-    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
-    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
-    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
+def setup_moe_comm_method(moe_config, quant_method):
+    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
+        moe_config, quant_method)
+    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
+        moe_config, quant_method)
+    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
     _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
-        moe_config)
+        moe_config, quant_method)
 
 
 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""
 
-    def __init__(self, moe_config: FusedMoEConfig):
+    def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
         self.model_type = get_current_vllm_config(
         ).model_config.hf_config.model_type
         self.moe_config = moe_config
         self.token_dispatcher = self._get_token_dispatcher()
+        self.quant_type = self._get_quant_type(quant_method)
+        self.with_quant = self.quant_type != QuantType.NONE
         self.prepare_finalize = self._get_prepare_finalize()
 
+    def _get_quant_type(self, quant_method) -> QuantType:
+        if not hasattr(quant_method,
+                       "quant_method") or quant_method.quant_method is None:
+            return QuantType.NONE
+        method = quant_method.quant_method
+        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
+            return QuantType.W8A8
+        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
+            return QuantType.W4A8
+        else:
+            return QuantType.NONE
+
     def prepare(
         self,
         hidden_states: torch.Tensor,
@@ -90,8 +112,6 @@ class MoECommMethod(ABC):
         topk_ids: torch.Tensor,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
-        use_int8_w8a8: bool = False,
-        use_int4_w4a8: bool = False,
         global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
@@ -109,10 +129,11 @@ class MoECommMethod(ABC):
         global_redundant_expert_num: int = 0,
         need_trans: bool = False,
         dynamic_eplb: bool = False,
-        mc2_mask: torch.Tensor = None):
+        mc2_mask: torch.Tensor = None,
+        pertoken_scale: Optional[torch.Tensor] = None):
         # Check constraints
         assert hidden_states.dtype in [
-            torch.float32, torch.float16, torch.bfloat16
+            torch.float32, torch.float16, torch.bfloat16, torch.int8
         ]
 
         moe_comm_method = get_forward_context().moe_comm_method
@@ -130,28 +151,29 @@ class MoECommMethod(ABC):
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=use_int8_w8a8 or use_int4_w4a8,
-            dynamic_eplb=dynamic_eplb)
+            with_quant=self.with_quant,
+            dynamic_eplb=dynamic_eplb,
+            pertoken_scale=pertoken_scale)
 
         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
 
-        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
-                                       w1=w1,
-                                       w1_scale=w1_scale,
-                                       w2=w2,
-                                       w2_scale=w2_scale,
-                                       group_list=expert_tokens,
-                                       dynamic_scale=dynamic_scale,
-                                       group_list_type=group_list_type,
-                                       w1_scale_bias=w1_scale_bias,
-                                       w2_scale_bias=w2_scale_bias,
-                                       topk_scales=topk_scales,
-                                       with_quant=use_int8_w8a8
-                                       or use_int4_w4a8,
-                                       fusion=use_int8_w8a8,
-                                       need_trans=need_trans,
-                                       dynamic_eplb=dynamic_eplb)
+        mlp_output = unified_apply_mlp(
+            hidden_states=permuted_hidden_states,
+            w1=w1,
+            w1_scale=w1_scale,
+            w2=w2,
+            w2_scale=w2_scale,
+            group_list=expert_tokens,
+            dynamic_scale=dynamic_scale,
+            group_list_type=group_list_type,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            topk_scales=topk_scales,
+            with_quant=self.with_quant,
+            fusion=self.quant_type == QuantType.W8A8,
+            need_trans=need_trans,
+            dynamic_eplb=dynamic_eplb)
 
         final_hidden_states = self.token_dispatcher.token_combine(
             hidden_states=mlp_output, context_metadata=context_metadata)
@@ -204,7 +226,8 @@ class AllGatherCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAllGather(self.moe_config)
+        return PrepareAndFinalizeWithAllGather(self.moe_config,
+                                               self.quant_type)
 
 
 class MC2CommImpl(MoECommMethod):
@@ -221,7 +244,7 @@ class MC2CommImpl(MoECommMethod):
         return TokenDispatcherWithMC2()
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithMC2(self.moe_config)
+        return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
 
 
 class AlltoAllCommImpl(MoECommMethod):
@@ -241,7 +264,7 @@ class AlltoAllCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config)
+        return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
 
 
 class NaiveMulticastCommImpl(MoECommMethod):
@@ -270,4 +293,5 @@ class NaiveMulticastCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)
 
     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
+                                                    self.quant_type)


@@ -72,8 +72,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
         # Dispose the original unquantized hidden states
         # to save npu memory because they're no longer used.
         dispose_tensor(unquantized_hidden_states)
+        quantized_hidden_states = None
     else:
         pertoken_scale = dynamic_scale
+        quantized_hidden_states = hidden_states
 
     bias1, bias2 = None, None
     _output_dtype = w2_scale.dtype
@@ -92,6 +94,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_list=cumsum_group_list(group_list, group_list_type),
             weight_scale=w1_scale,
             x_scale=pertoken_scale)
+        if quantized_hidden_states is not None:
+            dispose_tensor(quantized_hidden_states)
     else:
         if w1_scale.dtype != torch.float32:
             w1_scale = w1_scale.to(torch.float32)
@@ -104,6 +108,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
             group_type=0,
             group_list=group_list,
             output_dtype=torch.int32)[0]
+        if quantized_hidden_states is not None:
+            dispose_tensor(quantized_hidden_states)
 
         # act_fn: swiglu
         hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
             x=hidden_states,
@@ -148,6 +154,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 group_list=cumsum_group_list(group_list, group_list_type),
                 weight_scale=w1_scale,
                 x_scale=pertoken_scale)
+            if quantized_hidden_states is not None:
+                dispose_tensor(quantized_hidden_states)
         else:
             # gmm1: gate_up_proj
             hidden_states = torch_npu.npu_grouped_matmul(
@@ -161,6 +169,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
                 group_type=0,
                 group_list=group_list,
                 output_dtype=_output_dtype)[0]
+            if quantized_hidden_states is not None:
+                dispose_tensor(quantized_hidden_states)
 
             # act_fn: swiglu
             hidden_states = torch_npu.npu_swiglu(hidden_states)
             hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(


@@ -15,11 +15,13 @@
 # This file is a part of the vllm-ascend project.
 
 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import Optional
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+import torch_npu
 from vllm.distributed import tensor_model_parallel_all_reduce
 from vllm.distributed.parallel_state import (
     get_dp_group, get_tensor_model_parallel_rank,
@@ -30,6 +32,12 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 from vllm_ascend.utils import enable_sp
 
 
+class QuantType(Enum):
+    NONE = 0
+    W8A8 = 1
+    W4A8 = 2
+
+
 class PrepareAndFinalize(ABC):
     """
     Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
@@ -42,8 +50,11 @@ class PrepareAndFinalize(ABC):
     sizes, ranks, and communication settings.
     """
 
-    def __init__(self, moe_config: FusedMoEConfig):
+    def __init__(self,
+                 moe_config: FusedMoEConfig,
+                 quant_type: QuantType = QuantType.NONE):
         self.moe_config = moe_config
+        self.quant_type = quant_type
 
     @abstractmethod
     def prepare(
@@ -103,8 +114,10 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
     Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
     """
 
-    def __init__(self, moe_config: FusedMoEConfig):
-        super().__init__(moe_config)
+    def __init__(self,
+                 moe_config: FusedMoEConfig,
+                 quant_type: QuantType = QuantType.NONE):
+        super().__init__(moe_config, quant_type)
         self._restore_tp_across_dp()
 
     def _restore_tp_across_dp(self):
@@ -195,8 +208,10 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
     Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
     """
 
-    def __init__(self, moe_config: FusedMoEConfig):
-        super().__init__(moe_config)
+    def __init__(self,
+                 moe_config: FusedMoEConfig,
+                 quant_type: QuantType = QuantType.NONE):
+        super().__init__(moe_config, quant_type)
         self._restore_tp_across_dp()
 
     def _restore_tp_across_dp(self):
@@ -316,11 +331,20 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         router_logits: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
+        pertoken_scale = None
+        if self.quant_type == QuantType.W8A8:
+            hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
+                hidden_states)
+            pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                pertoken_scale, True, True)
+
         hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             hidden_states, True, True)
         router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             router_logits, True, True)
+        if pertoken_scale is not None:
+            return (hidden_states, pertoken_scale), router_logits, None, None
         return hidden_states, router_logits, None, None
 
     def _prepare_with_dp_group(
def _prepare_with_dp_group( def _prepare_with_dp_group(


@@ -57,20 +57,23 @@ class MoETokenDispatcher(ABC):
         return get_ep_group().world_size
 
     @abstractmethod
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+    def token_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        expert_map: Optional[torch.Tensor] = None,
+        log2phy: Optional[torch.Tensor] = None,
+        global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
+        quantized_x_for_share: Optional[Any] = None,
+        dynamic_scale_for_share: Optional[Any] = None,
+        mc2_mask: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        with_quant: bool = False,
+        dynamic_eplb: bool = False,
+        pertoken_scale: Optional[torch.Tensor] = None,
+    ):
         raise NotImplementedError("Dispatch function not implemented.")
 
     @abstractmethod
@@ -170,7 +173,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
 
         # Apply log2phy if needed
@@ -339,7 +343,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
         self.original_shape = hidden_states.shape
@@ -370,12 +375,14 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
             torch_npu.npu_moe_init_routing_v2(
                 hidden_states,
                 topk_ids,
+                scale=pertoken_scale,
                 active_num=num_tokens * self.top_k,
                 expert_num=global_num_experts,
                 expert_tokens_num_type=1,
                 expert_tokens_num_flag=True,
                 active_expert_range=[first_expert_idx, last_expert_idx],
-                quant_mode=1 if self.with_quant else -1,
+                quant_mode=1
+                if self.with_quant and pertoken_scale is None else -1,
             ))
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
@@ -430,7 +437,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.bsz, _ = hidden_states.shape
         flatten_topk_ids = topk_ids.view(-1)
         self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -518,7 +526,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
                        mc2_mask: Optional[torch.Tensor] = None,
                        apply_router_weight_on_input: bool = False,
                        with_quant: bool = False,
-                       dynamic_eplb: bool = False):
+                       dynamic_eplb: bool = False,
+                       pertoken_scale: Optional[torch.Tensor] = None):
         self.with_quant = with_quant
         self.hidden_shape = hidden_states.shape


@@ -36,7 +36,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(
         x = tensor_model_parallel_all_gather(x, 0)
         pad_size = forward_context.pad_size
         if pad_size > 0:
-            x = x[:-pad_size, :]
+            x = x[:-pad_size]
     else:
         x = get_ep_group().all_gather(x, 0)
         # unpad
@@ -50,8 +50,7 @@ def _maybe_all_gather_and_maybe_unpad_impl(
             offset = 0
             for idx in range(dp_size):
                 num_tokens_dp = num_tokens_across_dp_cpu[idx]
-                result[offset:offset +
-                       num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
+                result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
                 offset += num_tokens_dp
             x = result


@@ -386,7 +386,6 @@ class AscendW4A8DynamicFusedMoEMethod:
             w2_scale_bias=layer.w2_scale_bias,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            use_int4_w4a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,


@@ -143,6 +143,7 @@ class AscendW8A8DynamicFusedMoEMethod:
             and not ascend_config.torchair_graph_config.enabled)
         self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
+        self.in_dtype = vllm_config.model_config.dtype
 
         try:
             device_group = get_mc2_group().device_group
@@ -218,6 +219,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         shared_experts: Optional[Any] = None,
         quantized_x_for_share: Optional[Any] = None,
         dynamic_scale_for_share: Optional[Any] = None,
+        pertoken_scale: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -242,18 +244,18 @@ class AscendW8A8DynamicFusedMoEMethod:
         if enable_force_load_balance:
             topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
 
-        topk_weights = topk_weights.to(x.dtype)
+        topk_weights = topk_weights.to(self.in_dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
 
         return moe_comm_method.fused_experts(
             hidden_states=x,
+            pertoken_scale=pertoken_scale,
             w1=layer.w13_weight,
             w1_scale=layer.w13_weight_scale_fp32,
             w2=layer.w2_weight,
             w2_scale=layer.w2_weight_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            use_int8_w8a8=True,
             expert_map=expert_map,
             log2phy=log2phy,
             global_redundant_expert_num=global_redundant_expert_num,