From bedf22377139cd0bceaa089afc8b16ceaddee0a3 Mon Sep 17 00:00:00 2001 From: realliujiaxu Date: Tue, 4 Nov 2025 16:49:58 +0800 Subject: [PATCH] [Perf] move quant before allgather in Allgather EP (#3420) ### What this PR does / why we need it? move quant before allgather in Allgather EP, rely on https://github.com/vllm-project/vllm-ascend/pull/3334 Deepseek R1 W8A8 performance on A2 with `HCCL_ALGO="level0:NA;level1:pipeline"`: | Seq length | Mean TTFT (ms) main | Mean TTFT (ms) this PR | |----------|----------|----------| | 4k | 375.21 | 364.99 | | 16k | 1465.23 | 1421.75 | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac --------- Signed-off-by: realliujiaxu --- .../test_offline_inference_distributed.py | 19 ++++ tests/ut/ops/test_fused_moe.py | 6 +- vllm_ascend/ops/fused_moe/fused_moe.py | 8 +- vllm_ascend/ops/fused_moe/moe_comm_method.py | 90 ++++++++++++------- vllm_ascend/ops/fused_moe/moe_mlp.py | 10 +++ vllm_ascend/ops/fused_moe/prepare_finalize.py | 34 +++++-- vllm_ascend/ops/fused_moe/token_dispatcher.py | 47 ++++++---- vllm_ascend/ops/register_custom_ops.py | 5 +- vllm_ascend/quantization/w4a8_dynamic.py | 1 - vllm_ascend/quantization/w8a8_dynamic.py | 6 +- 10 files changed, 160 insertions(+), 66 deletions(-) diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 5959a3a6..a8102ec7 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -189,6 +189,25 @@ def test_sp_for_qwen3_moe() -> None: vllm_model.generate(example_prompts, sampling_params) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) +def test_models_distributed_deepseek_v2_lite_with_flashcomm_v1() -> None: + example_prompts = [ + "test" * 1001, + ] + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"), + dtype="auto", + tensor_parallel_size=2, + distributed_executor_backend="mp", + enable_expert_parallel=True, + enforce_eager=True, + quantization="ascend") as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + @pytest.mark.parametrize("model", QWEN_DENSE_MODELS) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"}) @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py index 8cd2961b..3a27e44e 100644 --- a/tests/ut/ops/test_fused_moe.py +++ b/tests/ut/ops/test_fused_moe.py @@ -458,6 +458,7 @@ class TestUnifiedApplyMLP(TestBase): dtype=torch.float32)) hidden_states = torch.randn(10, 20, dtype=torch.bfloat16) + hidden_states_shape = hidden_states.shape w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16) w1_scale = torch.randn(5, 40, dtype=torch.bfloat16) w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16) @@ -486,7 +487,7 @@ class TestUnifiedApplyMLP(TestBase): mock_npu_swiglu.assert_called_once() mock_npu_dynamic_quant.assert_called_once() - self.assertEqual(result.shape, hidden_states.shape) + self.assertEqual(result.shape, hidden_states_shape) self.assertEqual(result.dtype, torch.bfloat16) @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p') @@ -568,6 +569,7 @@ class TestUnifiedApplyMLP(TestBase): dtype=torch.float32)) hidden_states = torch.randn(10, 20, 
dtype=torch.bfloat16) + hidden_states_shape = hidden_states.shape w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16) w1_scale = torch.randn(5, 40, dtype=torch.bfloat16) w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16) @@ -596,7 +598,7 @@ class TestUnifiedApplyMLP(TestBase): mock_npu_grouped_matmul_swiglu_quant.assert_called_once() self.assertTrue(mock_forward_context.with_quant) - self.assertEqual(result.shape, hidden_states.shape) + self.assertEqual(result.shape, hidden_states_shape) self.assertEqual(result.dtype, torch.bfloat16) diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index cc19b925..14b615ba 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -289,7 +289,7 @@ class AscendFusedMoE(FusedMoE): self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - setup_moe_comm_method(self.moe_config) + setup_moe_comm_method(self.moe_config, self.quant_method) def update_expert_map(self, new_expert_map): self.expert_map = new_expert_map @@ -336,11 +336,17 @@ class AscendFusedMoE(FusedMoE): replace_allreduce=forward_context.sp_enabled, enable_shared_expert_dp=self.enable_shared_expert_dp) + if isinstance(hidden_states, tuple): + hidden_states, pertoken_scale = hidden_states + else: + pertoken_scale = None + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, + pertoken_scale=pertoken_scale, top_k=self.top_k, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index ca1e71c7..6094aafb 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -27,10 +27,14 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.prepare_finalize import ( PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather, - PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast) + PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast, QuantType) from vllm_ascend.ops.fused_moe.token_dispatcher import ( TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, TokenDispatcherWithMC2, TokenDispatcherWithMoge) +from vllm_ascend.quantization.w4a8_dynamic import \ + AscendW4A8DynamicFusedMoEMethod +from vllm_ascend.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} @@ -40,25 +44,43 @@ def get_moe_comm_method( return _MoECommMethods.get(moe_comm_type, None) -def setup_moe_comm_method(moe_config): - _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config) - _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config) - _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config) +def setup_moe_comm_method(moe_config, quant_method): + _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl( + moe_config, quant_method) + _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl( + moe_config, quant_method) + _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method) _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl( - moe_config) + moe_config, quant_method) class MoECommMethod(ABC): """Base class for MoE communication methods.""" - def __init__(self, moe_config: FusedMoEConfig): + def __init__(self, moe_config: FusedMoEConfig, 
quant_method=None): self.model_type = get_current_vllm_config( ).model_config.hf_config.model_type self.moe_config = moe_config self.token_dispatcher = self._get_token_dispatcher() + self.quant_type = self._get_quant_type(quant_method) + self.with_quant = self.quant_type != QuantType.NONE self.prepare_finalize = self._get_prepare_finalize() + def _get_quant_type(self, quant_method) -> QuantType: + if not hasattr(quant_method, + "quant_method") or quant_method.quant_method is None: + return QuantType.NONE + + method = quant_method.quant_method + + if isinstance(method, AscendW8A8DynamicFusedMoEMethod): + return QuantType.W8A8 + elif isinstance(method, AscendW4A8DynamicFusedMoEMethod): + return QuantType.W4A8 + else: + return QuantType.NONE + def prepare( self, hidden_states: torch.Tensor, @@ -90,8 +112,6 @@ class MoECommMethod(ABC): topk_ids: torch.Tensor, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, global_num_experts: Optional[int] = None, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, @@ -109,10 +129,11 @@ class MoECommMethod(ABC): global_redundant_expert_num: int = 0, need_trans: bool = False, dynamic_eplb: bool = False, - mc2_mask: torch.Tensor = None): + mc2_mask: torch.Tensor = None, + pertoken_scale: Optional[torch.Tensor] = None): # Check constraints assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 + torch.float32, torch.float16, torch.bfloat16, torch.int8 ] moe_comm_method = get_forward_context().moe_comm_method @@ -130,28 +151,29 @@ class MoECommMethod(ABC): dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, apply_router_weight_on_input=apply_router_weight_on_input, - with_quant=use_int8_w8a8 or use_int4_w4a8, - dynamic_eplb=dynamic_eplb) + with_quant=self.with_quant, + dynamic_eplb=dynamic_eplb, + pertoken_scale=pertoken_scale) permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \ results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata") - mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=expert_tokens, - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - topk_scales=topk_scales, - with_quant=use_int8_w8a8 - or use_int4_w4a8, - fusion=use_int8_w8a8, - need_trans=need_trans, - dynamic_eplb=dynamic_eplb) + mlp_output = unified_apply_mlp( + hidden_states=permuted_hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=expert_tokens, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=topk_scales, + with_quant=self.with_quant, + fusion=self.quant_type == QuantType.W8A8, + need_trans=need_trans, + dynamic_eplb=dynamic_eplb) final_hidden_states = self.token_dispatcher.token_combine( hidden_states=mlp_output, context_metadata=context_metadata) @@ -204,7 +226,8 @@ class AllGatherCommImpl(MoECommMethod): num_local_experts=self.moe_config.num_local_experts) def _get_prepare_finalize(self): - return PrepareAndFinalizeWithAllGather(self.moe_config) + return PrepareAndFinalizeWithAllGather(self.moe_config, + self.quant_type) class MC2CommImpl(MoECommMethod): @@ -221,7 +244,7 @@ 
class MC2CommImpl(MoECommMethod): return TokenDispatcherWithMC2() def _get_prepare_finalize(self): - return PrepareAndFinalizeWithMC2(self.moe_config) + return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type) class AlltoAllCommImpl(MoECommMethod): @@ -241,7 +264,7 @@ class AlltoAllCommImpl(MoECommMethod): num_local_experts=self.moe_config.num_local_experts) def _get_prepare_finalize(self): - return PrepareAndFinalizeWithAll2All(self.moe_config) + return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type) class NaiveMulticastCommImpl(MoECommMethod): @@ -270,4 +293,5 @@ class NaiveMulticastCommImpl(MoECommMethod): num_local_experts=self.moe_config.num_local_experts) def _get_prepare_finalize(self): - return PrepareAndFinalizeWithNaiveMulticast(self.moe_config) + return PrepareAndFinalizeWithNaiveMulticast(self.moe_config, + self.quant_type) diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 5ee7d70d..0e2b81fb 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -72,8 +72,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor, # Dispose the original unquantized hidden states # to save npu memory because they're no longer used. dispose_tensor(unquantized_hidden_states) + quantized_hidden_states = None else: pertoken_scale = dynamic_scale + quantized_hidden_states = hidden_states bias1, bias2 = None, None _output_dtype = w2_scale.dtype @@ -92,6 +94,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list=cumsum_group_list(group_list, group_list_type), weight_scale=w1_scale, x_scale=pertoken_scale) + if quantized_hidden_states is not None: + dispose_tensor(quantized_hidden_states) else: if w1_scale.dtype != torch.float32: w1_scale = w1_scale.to(torch.float32) @@ -104,6 +108,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, output_dtype=torch.int32)[0] + if quantized_hidden_states is not None: + dispose_tensor(quantized_hidden_states) # act_fn: swiglu hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( x=hidden_states, @@ -148,6 +154,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list=cumsum_group_list(group_list, group_list_type), weight_scale=w1_scale, x_scale=pertoken_scale) + if quantized_hidden_states is not None: + dispose_tensor(quantized_hidden_states) else: # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( @@ -161,6 +169,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, output_dtype=_output_dtype)[0] + if quantized_hidden_states is not None: + dispose_tensor(quantized_hidden_states) # act_fn: swiglu hidden_states = torch_npu.npu_swiglu(hidden_states) hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index 1e06d841..f158a4bf 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -15,11 +15,13 @@ # This file is a part of the vllm-ascend project. 
from abc import ABC, abstractmethod +from enum import Enum from typing import Optional import torch import torch.distributed as dist import torch.nn as nn +import torch_npu from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( get_dp_group, get_tensor_model_parallel_rank, @@ -30,6 +32,12 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.utils import enable_sp +class QuantType(Enum): + NONE = 0 + W8A8 = 1 + W4A8 = 2 + + class PrepareAndFinalize(ABC): """ Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization @@ -42,8 +50,11 @@ class PrepareAndFinalize(ABC): sizes, ranks, and communication settings. """ - def __init__(self, moe_config: FusedMoEConfig): + def __init__(self, + moe_config: FusedMoEConfig, + quant_type: QuantType = QuantType.NONE): self.moe_config = moe_config + self.quant_type = quant_type @abstractmethod def prepare( @@ -103,8 +114,10 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): Will be used when num_tokens exceed mc2's limitation (512 tokens/rank). """ - def __init__(self, moe_config: FusedMoEConfig): - super().__init__(moe_config) + def __init__(self, + moe_config: FusedMoEConfig, + quant_type: QuantType = QuantType.NONE): + super().__init__(moe_config, quant_type) self._restore_tp_across_dp() def _restore_tp_across_dp(self): @@ -195,8 +208,10 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment. """ - def __init__(self, moe_config: FusedMoEConfig): - super().__init__(moe_config) + def __init__(self, + moe_config: FusedMoEConfig, + quant_type: QuantType = QuantType.NONE): + super().__init__(moe_config, quant_type) self._restore_tp_across_dp() def _restore_tp_across_dp(self): @@ -316,11 +331,20 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + pertoken_scale = None + if self.quant_type == QuantType.W8A8: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + pertoken_scale, True, True) hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states, True, True) router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( router_logits, True, True) + if pertoken_scale is not None: + return (hidden_states, pertoken_scale), router_logits, None, None + return hidden_states, router_logits, None, None def _prepare_with_dp_group( diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index b3245013..077163c5 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -57,20 +57,23 @@ class MoETokenDispatcher(ABC): return get_ep_group().world_size @abstractmethod - def token_dispatch(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False): + def token_dispatch( + self, + 
hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False, + pertoken_scale: Optional[torch.Tensor] = None, + ): raise NotImplementedError("Dispatch function not implemented.") @abstractmethod @@ -170,7 +173,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False, - dynamic_eplb: bool = False): + dynamic_eplb: bool = False, + pertoken_scale: Optional[torch.Tensor] = None): self.with_quant = with_quant # Apply log2phy if needed @@ -339,7 +343,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False, - dynamic_eplb: bool = False): + dynamic_eplb: bool = False, + pertoken_scale: Optional[torch.Tensor] = None): self.with_quant = with_quant self.original_shape = hidden_states.shape @@ -370,12 +375,14 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): torch_npu.npu_moe_init_routing_v2( hidden_states, topk_ids, + scale=pertoken_scale, active_num=num_tokens * self.top_k, expert_num=global_num_experts, expert_tokens_num_type=1, expert_tokens_num_flag=True, active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=1 if self.with_quant else -1, + quant_mode=1 + if self.with_quant and pertoken_scale is None else -1, )) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode @@ -430,7 +437,8 @@ class TokenDispatcherWithMoge(MoETokenDispatcher): mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False, - dynamic_eplb: bool = False): + dynamic_eplb: bool = False, + pertoken_scale: Optional[torch.Tensor] = None): self.bsz, _ = hidden_states.shape flatten_topk_ids = topk_ids.view(-1) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) @@ -518,7 +526,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False, - dynamic_eplb: bool = False): + dynamic_eplb: bool = False, + pertoken_scale: Optional[torch.Tensor] = None): self.with_quant = with_quant self.hidden_shape = hidden_states.shape diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py index c4b410d4..6a3057d9 100644 --- a/vllm_ascend/ops/register_custom_ops.py +++ b/vllm_ascend/ops/register_custom_ops.py @@ -36,7 +36,7 @@ def _maybe_all_gather_and_maybe_unpad_impl( x = tensor_model_parallel_all_gather(x, 0) pad_size = forward_context.pad_size if pad_size > 0: - x = x[:-pad_size, :] + x = x[:-pad_size] else: x = get_ep_group().all_gather(x, 0) # unpad @@ -50,8 +50,7 @@ def _maybe_all_gather_and_maybe_unpad_impl( offset = 0 for idx in range(dp_size): num_tokens_dp = num_tokens_across_dp_cpu[idx] - result[offset:offset + - num_tokens_dp, :] = x[idx, :num_tokens_dp, :] + result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp] offset += num_tokens_dp x = result diff --git a/vllm_ascend/quantization/w4a8_dynamic.py 
b/vllm_ascend/quantization/w4a8_dynamic.py index 77f0f4b2..cd889a04 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -386,7 +386,6 @@ class AscendW4A8DynamicFusedMoEMethod: w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, topk_ids=topk_ids, - use_int4_w4a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 1a9d0b5c..b5f10c4d 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -143,6 +143,7 @@ class AscendW8A8DynamicFusedMoEMethod: and not ascend_config.torchair_graph_config.enabled) self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + self.in_dtype = vllm_config.model_config.dtype try: device_group = get_mc2_group().device_group @@ -218,6 +219,7 @@ class AscendW8A8DynamicFusedMoEMethod: shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + pertoken_scale: Optional[Any] = None, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -242,18 +244,18 @@ class AscendW8A8DynamicFusedMoEMethod: if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - topk_weights = topk_weights.to(x.dtype) + topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method return moe_comm_method.fused_experts( hidden_states=x, + pertoken_scale=pertoken_scale, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale_fp32, w2=layer.w2_weight, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, - use_int8_w8a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num,
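
The key change above is in `PrepareAndFinalizeWithAllGather.prepare`: for W8A8, the local tokens are quantized with `torch_npu.npu_dynamic_quant` first, and only the int8 activations plus their per-token fp32 scales are all-gathered; the scale then flows through `pertoken_scale` into `npu_moe_init_routing_v2` and the grouped matmul instead of being dequantized eagerly. The snippet below is a minimal CPU sketch of that idea, not the vllm-ascend implementation: `per_token_dynamic_quant` is a stand-in helper approximating what `npu_dynamic_quant` produces, and the byte counts only illustrate why gathering int8 + scales moves roughly half the data of gathering bf16.

```python
# Rough sketch of the idea behind this PR, not the vllm-ascend implementation.
# torch_npu.npu_dynamic_quant is emulated with plain torch ops so this runs on CPU;
# the real path keeps the scale and passes it to npu_moe_init_routing_v2 / the
# grouped matmul as pertoken_scale rather than dequantizing as done at the end here.
import torch


def per_token_dynamic_quant(x: torch.Tensor):
    """Symmetric per-token int8 quantization (stand-in for npu_dynamic_quant)."""
    # One scale per token (row): scale = max(|x|) / 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_int8, scale.squeeze(-1).to(torch.float32)


num_tokens, hidden = 4, 8
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16)

# Before this PR: bf16 activations are all-gathered first and quantized afterwards
# on every rank, so 2 bytes/element cross the link.
bytes_bf16 = hidden_states.numel() * hidden_states.element_size()

# After this PR: quantize first, then all-gather the int8 tensor plus the
# per-token fp32 scales, roughly halving the communication volume.
x_int8, pertoken_scale = per_token_dynamic_quant(hidden_states.float())
bytes_int8 = (x_int8.numel() * x_int8.element_size()
              + pertoken_scale.numel() * pertoken_scale.element_size())

print(f"all-gather payload: bf16={bytes_bf16} B, int8+scale={bytes_int8} B")

# Sanity check: dequantizing recovers the activations up to quantization error.
recovered = x_int8.float() * pertoken_scale.unsqueeze(-1)
print("max abs error:", (recovered - hidden_states.float()).abs().max().item())
```

Because the quantization is per-token and the scales are gathered alongside the int8 data, each rank can still dequantize (or feed the scales directly into the quantized grouped matmul) after the all-gather, which is what the `pertoken_scale` plumbing added in this PR enables.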