diff --git a/tests/e2e/multicard/moe/test_moe_comm.py b/tests/e2e/multicard/moe/test_moe_comm.py index 2b09f57..d9ace12 100644 --- a/tests/e2e/multicard/moe/test_moe_comm.py +++ b/tests/e2e/multicard/moe/test_moe_comm.py @@ -33,6 +33,7 @@ from vllm_ascend.distributed.moe_comm_method import ( # isort: skip @pytest.mark.parametrize("top_k_num", [2, 4]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("ep_rank", [0, 1]) +@pytest.mark.parametrize("apply_a8_quantization", [False]) def test_all_gather_comm_impl( num_tokens, hidden_size, @@ -41,6 +42,7 @@ def test_all_gather_comm_impl( top_k_num, dtype, ep_rank, + apply_a8_quantization, mocker, ): """ @@ -118,8 +120,9 @@ def test_all_gather_comm_impl( native_permuted_hidden, native_expert_tokens, _, + _, ) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map, - num_experts) + num_experts, apply_a8_quantization) # Simulate MLP output native_mlp_output = torch.randn_like(native_permuted_hidden) native_impl.unpermute(native_mlp_output, native_hidden_states_out) @@ -130,8 +133,9 @@ def test_all_gather_comm_impl( all_gather_permuted_hidden, all_gather_expert_tokens, _, + _, ) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights, - expert_map, num_experts) + expert_map, num_experts, apply_a8_quantization) # Use the same simulated MLP output for a fair comparison all_gather_mlp_output = native_mlp_output.clone() diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py index a17de55..45f1b6e 100644 --- a/tests/e2e/multicard/test_qwen3_moe.py +++ b/tests/e2e/multicard/test_qwen3_moe.py @@ -107,4 +107,4 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH(): tensor_parallel_size=2, enforce_eager=False, ) as vllm_model: - vllm_model.generate_greedy(example_prompts, max_tokens) \ No newline at end of file + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm_ascend/distributed/moe_comm_method.py b/vllm_ascend/distributed/moe_comm_method.py index ea32495..aa9bae8 100644 --- a/vllm_ascend/distributed/moe_comm_method.py +++ b/vllm_ascend/distributed/moe_comm_method.py @@ -54,7 +54,8 @@ class MoECommMethod(ABC): topk_weights: torch.Tensor, expert_map: torch.Tensor, num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: + apply_a8_quantization: bool, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: """Pre-process before MLP. Args: @@ -64,6 +65,7 @@ class MoECommMethod(ABC): expert_map (torch.Tensor): Tensor of shape (global_num_experts, ) Mapping from global expert IDs to local expert IDs. num_experts (int): Number of local experts (experts on this device). + apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8). Returns: tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing: @@ -72,6 +74,8 @@ class MoECommMethod(ABC): hidden_states based on topk_ids. - expert_tokens (torch.Tensor): Tensor of shape (num_experts, ) Number of tokens assigned to each expert. + - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, ) + Dynamic scale for each expert, used for quantization. - group_list_type (int): Type of group list, 0 for `cumsum` and 1 for `count`. This is mainly for `npu_grouped_matmul` to determine how to handle the output. @@ -159,7 +163,8 @@ class AllGatherCommImpl(MoECommMethod): topk_weights: torch.Tensor, expert_map: torch.Tensor, # noqa: F841 num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: + apply_a8_quantization: bool, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: num_tokens = hidden_states.shape[0] self.topk_weights = topk_weights @@ -194,7 +199,7 @@ class AllGatherCommImpl(MoECommMethod): group_list_type = 1 # `count` mode - return permuted_hidden_states, expert_tokens, group_list_type + return permuted_hidden_states, expert_tokens, None, group_list_type def unpermute(self, mlp_output: torch.Tensor, hidden_states: torch.Tensor) -> None: @@ -219,7 +224,8 @@ class NativeAllGatherCommImpl(AllGatherCommImpl): topk_weights: torch.Tensor, expert_map: torch.Tensor, num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: + apply_a8_quantization: bool, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: num_tokens = hidden_states.shape[0] # Generate token indices and flatten @@ -269,7 +275,7 @@ class NativeAllGatherCommImpl(AllGatherCommImpl): group_list_type = 1 # `count` mode - return permuted_hidden_states, expert_tokens, group_list_type + return permuted_hidden_states, expert_tokens, None, group_list_type def unpermute(self, mlp_output: torch.Tensor, hidden_states: torch.Tensor) -> None: @@ -375,7 +381,8 @@ class MC2CommImpl(MoECommMethod): topk_weights: torch.Tensor, expert_map: torch.Tensor, num_experts: int, - ) -> tuple[torch.Tensor, torch.Tensor, int]: + apply_a8_quantization: bool, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: # Store tensors needed for post_process self.topk_ids = topk_ids self.topk_weights = topk_weights.to(torch.float32) @@ -388,7 +395,7 @@ class MC2CommImpl(MoECommMethod): "moe_expert_num": self.moe_config.num_experts, "global_bs": 0, "scales": None, - "quant_mode": 0, + "quant_mode": 2 if apply_a8_quantization else 0, "group_ep": self.mc2_comm_name, "ep_world_size": self.moe_config.ep_size, "ep_rank_id": self.moe_config.ep_rank, @@ -409,7 +416,7 @@ class MC2CommImpl(MoECommMethod): ( permuted_hidden_states, - _, # dynamic_scale is not used + dynamic_scale, self.assist_info_for_combine, expert_tokens, self.ep_recv_counts, @@ -418,7 +425,7 @@ class MC2CommImpl(MoECommMethod): group_list_type = 1 - return permuted_hidden_states, expert_tokens, group_list_type + return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type def unpermute(self, mlp_output: torch.Tensor, hidden_states: torch.Tensor) -> None: @@ -457,3 +464,93 @@ class MC2CommImpl(MoECommMethod): combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine hidden_states[:] = combine(**combine_kwargs) + + +class AlltoAllCommImpl(MoECommMethod): + """This implementation is for the scenarios listed below: + 1. `enable_expert_parallel=True`. + 2. `npu_grouped_matmul` is available. + + This implementation uses all-to-all communication to exchange tokens + between data parallel ranks before and after the MLP computation. It should + have better performance than AllGatherCommImpl when DP size > 1. + """ + + def __init__(self, moe_config: Optional[FusedMoEConfig]): + super().__init__(moe_config) + from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ + get_token_dispatcher + self.token_dispatcher = get_token_dispatcher( + "TokenDispatcherWithAll2AllV") + self._restore_tp_across_dp() + + def _restore_tp_across_dp(self): + # NOTE: Since vLLM flatten tp across dp, we need to restore the original + # tp_size and tp_rank. + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + def prepare( + self, hidden_states: torch.Tensor, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + self.num_tokens, _ = hidden_states.shape + pad_size = self.tp_size - self.num_tokens + + if pad_size > 0: + hidden_states = nn.functional.pad(hidden_states, + (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, + (0, 0, 0, pad_size)) + + if self.tp_size > 1: + split_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + split_router_logits = torch.tensor_split(router_logits, + self.tp_size, + dim=0) + self.split_hidden_states = split_hidden_states + + hidden_states = split_hidden_states[self.tp_rank] + router_logits = split_router_logits[self.tp_rank] + + return hidden_states, router_logits + + def finalize(self, hidden_states: torch.Tensor, + reduce_results: bool) -> torch.Tensor: + """If TP size > 1, all-gather the hidden states to get the final output. + + Also, unpad the hidden states if needed. + """ + if self.tp_size > 1: + dist.all_gather(list(self.split_hidden_states), hidden_states, + self.moe_config.tp_group.device_group) + hidden_states = torch.cat(self.split_hidden_states, dim=0) + + if self.num_tokens < hidden_states.shape[0]: + hidden_states = hidden_states[:self.num_tokens] + + return hidden_states + + def permute( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_map: torch.Tensor, + num_experts: int, + apply_a8_quantization: bool, + ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: + results = self.token_dispatcher.token_dispatch( + hidden_states, + topk_weights, + topk_ids, + None, + log2phy=None, + with_quant=apply_a8_quantization) + return results["hidden_states"], results["group_list"], results[ + "dynamic_scale"], results["group_list_type"] + + def unpermute(self, mlp_output: torch.Tensor, + hidden_states: torch.Tensor) -> None: + hidden_states[:] = self.token_dispatcher.token_combine(mlp_output) diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index dd38c23..a44ab68 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Optional import torch +import torch_npu from vllm.config import CompilationLevel, get_current_vllm_config from vllm.distributed import get_dp_group, get_ep_group, get_tp_group from vllm.forward_context import get_forward_context @@ -26,12 +27,14 @@ from vllm.model_executor.layers.fused_moe.layer import ( from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl, - MC2CommImpl, - MoECommMethod) + AlltoAllCommImpl, + MC2CommImpl) from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge +from vllm_ascend.ops.fused_moe import fused_experts_moge from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.utils import is_310p +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ + setup_token_dispatchers +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ @@ -52,7 +55,6 @@ def fused_experts( w2_scale: Optional[torch.Tensor] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, - moe_comm_method: Optional[MoECommMethod] = None, # For TorchAir graph is_torchair: bool = False, # For Cube/Vector parallel @@ -64,9 +66,8 @@ def fused_experts( global_redundant_expert_num: int = 0, ) -> torch.Tensor: # Check constraints - assert hidden_states.shape[1] == w1.shape[2], ( - f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}") - + assert hidden_states.shape[1] == w1.shape[1], ( + f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}") assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" @@ -74,31 +75,79 @@ def fused_experts( assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16 ] + if (use_int8_w8a8 or use_int4_w4a8): + assert w1_scale is not None and w2_scale is not None, \ + "INT8 quantization requires weight scales." + + w1_scale = w1_scale.to(torch.float32) + down_scale = [w2_scale] + down_output_dtype = w2_scale.dtype + else: + down_scale = None + down_output_dtype = None + + moe_comm_method = get_forward_context().moe_comm_method assert moe_comm_method is not None, "Missing communication context" num_experts = w1.shape[0] - permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute( - hidden_states, topk_ids, topk_weights, expert_map, num_experts) - mlp_output = apply_mlp( - permuted_hidden_states, - w1, - w2, - expert_tokens, + permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute( + hidden_states, topk_ids, topk_weights, expert_map, num_experts, + use_int8_w8a8 or use_int4_w4a8) + + gate_up_output = torch_npu.npu_grouped_matmul( + x=[permuted_hidden_states], + weight=[w1], + split_item=2, group_list_type=group_list_type, - ) - moe_comm_method.unpermute(mlp_output, hidden_states) + group_type=0, + group_list=expert_tokens, + output_dtype=torch.int32 if use_int8_w8a8 else None, + )[0] + + if (use_int8_w8a8 or use_int4_w4a8): + activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant( + x=gate_up_output, + weight_scale=w1_scale, + activation_scale=dynamic_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_tokens, + activate_left=True, + quant_mode=1, + ) + activated_output_scale = [activated_output_scale] + else: + activated_output = torch_npu.npu_swiglu(gate_up_output) + activated_output_scale = None + + down_output = torch_npu.npu_grouped_matmul( + x=[activated_output], + weight=[w2], + scale=down_scale, + per_token_scale=activated_output_scale, + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=expert_tokens, + output_dtype=down_output_dtype, + )[0] + + moe_comm_method.unpermute(down_output, hidden_states) return hidden_states def unquantized_fused_moe_init_func(self, *args, **kwargs): original_unquantized_fused_moe_init_func(self, *args, **kwargs) + + # NOTE: Currently, this self.use_aclgraph is only used in + # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in + # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue. + # Once torch.randint_like is supported or removed, this flag can be removed. vllm_config = get_current_vllm_config() - self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens - ascend_config = get_ascend_config() - if ascend_config.torchair_graph_config.enabled: self.use_aclgraph = False else: @@ -156,8 +205,6 @@ def forward_oot( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - moe_comm_method = get_forward_context().moe_comm_method - return fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -166,10 +213,26 @@ def forward_oot( topk_ids=topk_ids, global_num_experts=global_num_experts, expert_map=expert_map, - moe_comm_method=moe_comm_method, ) +def process_weights_after_loading(self, layer): + super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) + w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( + 1, 2).contiguous() + layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) + + w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( + 1, 2).contiguous() + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + if not is_310p(): + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + + class AscendFusedMoE(FusedMoE): def __init__( @@ -224,12 +287,17 @@ class AscendFusedMoE(FusedMoE): has_bias, ) + setup_token_dispatchers(self.moe_config.ep_size, + top_k=self.top_k, + num_experts=self.global_num_experts, + num_local_experts=self.local_num_experts) + self.moe_config.tp_group = get_tp_group() self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() - for method in {AllGatherCommImpl, MC2CommImpl}: + for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}: setattr( self, method.__name__.lower(), method(moe_config=self.moe_config)) # type: ignore[abstract] @@ -282,4 +350,5 @@ class AscendFusedMoE(FusedMoE): UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func +UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 5f85b36..a84c104 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -230,7 +230,6 @@ def fused_experts_moge( 0, sorted_topk_ids).unsqueeze(-1) group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - w1 = w1.transpose(1, 2) gate_up_out = torch_npu.npu_grouped_matmul( x=[sorted_hidden_states], weight=[w1], @@ -247,7 +246,6 @@ def fused_experts_moge( gate_up_out = torch_npu.npu_swiglu(gate_up_out) gate_up_out *= topk_scales - w2 = w2.transpose(1, 2) down_out_list = torch_npu.npu_grouped_matmul( x=[gate_up_out], weight=[w2], diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 8438b33..20c68be 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -19,12 +19,16 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch_npu +from vllm.config import CompilationLevel, get_current_vllm_config from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.common_fused_moe import \ + fused_experts as unified_fused_experts from vllm_ascend.ops.fused_moe import unified_fused_experts_eager from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor @@ -283,6 +287,13 @@ class AscendW8A8DynamicFusedMoEMethod: self.ep_group = get_ep_group() + vllm_config = get_current_vllm_config() + ascend_config = get_ascend_config() + self.use_aclgraph = ( + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and not vllm_config.model_config.enforce_eager + and not ascend_config.torchair_graph_config.enabled) + try: device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group @@ -375,6 +386,19 @@ class AscendW8A8DynamicFusedMoEMethod: e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) + if self.use_aclgraph: + return unified_fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + expert_map=expert_map, + ) + fused_moe_state = get_forward_context().fused_moe_state shared_gate_up, shared_dequant_scale = None, None if shared_experts is not None and fused_moe_state == FusedMoEState.MC2: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 007a1c5..7a9fe1b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -89,7 +89,8 @@ from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, - ProfileExecuteDuration, is_310p, + AscendSocVersion, ProfileExecuteDuration, + get_ascend_soc_version, is_310p, lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer @@ -1620,8 +1621,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) def _select_moe_comm_method(self, num_tokens: int) -> str: - return ("mc2" - if num_tokens <= self.mc2_tokens_capacity else "allgather") + soc_version = get_ascend_soc_version() + + if num_tokens <= self.mc2_tokens_capacity: + moe_comm_method = "mc2" + elif soc_version in {AscendSocVersion.A2}: + moe_comm_method = "allgather" + elif soc_version in {AscendSocVersion.A3}: + moe_comm_method = "alltoall" + else: + raise ValueError(f"Unsupported soc_version: {soc_version}") + + if is_global_first_rank(): + logger.debug(f"num_tokens: {num_tokens}, " + f"moe_comm_method: {moe_comm_method}") + + return moe_comm_method @torch.inference_mode() def execute_model(