From 65b7f716e6cf2d9c9d9634af2c10926f6af98ce7 Mon Sep 17 00:00:00 2001 From: SILONG ZENG <2609716663@qq.com> Date: Fri, 6 Feb 2026 15:28:49 +0800 Subject: [PATCH] [Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #11) (#6176) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/fused_moe/comm_utils.py` | | `vllm_ascend/ops/fused_moe/experts_selector.py` | | `vllm_ascend/ops/fused_moe/fused_moe.py` | | `vllm_ascend/ops/fused_moe/moe_comm_method.py` | | `vllm_ascend/ops/fused_moe/moe_mlp.py` | | `vllm_ascend/ops/fused_moe/prepare_finalize.py` | | `vllm_ascend/ops/fused_moe/token_dispatcher.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com> Signed-off-by: SILONG ZENG <2609716663@qq.com> --- pyproject.toml | 3 - vllm_ascend/ops/fused_moe/comm_utils.py | 43 +- vllm_ascend/ops/fused_moe/experts_selector.py | 162 ++++---- vllm_ascend/ops/fused_moe/fused_moe.py | 292 +++++++------- vllm_ascend/ops/fused_moe/moe_comm_method.py | 193 ++++----- vllm_ascend/ops/fused_moe/moe_mlp.py | 242 ++++++------ vllm_ascend/ops/fused_moe/prepare_finalize.py | 173 +++----- vllm_ascend/ops/fused_moe/token_dispatcher.py | 370 +++++++++--------- 8 files changed, 694 insertions(+), 784 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b44570a8..9cef9b2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,9 +62,6 @@ exclude = [ "vllm_ascend/worker/v2/**", "vllm_ascend/worker/npu_input_batch.py", "vllm_ascend/ops/rotary_embedding.py", - - # (11) - "vllm_ascend/ops/fused_moe/**", ] [tool.ruff.lint] diff --git a/vllm_ascend/ops/fused_moe/comm_utils.py b/vllm_ascend/ops/fused_moe/comm_utils.py index b8952a95..cd279cf0 100644 --- a/vllm_ascend/ops/fused_moe/comm_utils.py +++ b/vllm_ascend/ops/fused_moe/comm_utils.py @@ -23,11 +23,7 @@ import torch_npu COMM_STREAM = None -def async_all_to_all(input_, - output_split_sizes, - input_split_sizes, - group, - event=None): +def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None): if output_split_sizes is None: # Equal split (all2all) a2a_out = torch.empty_like(input_) @@ -43,8 +39,7 @@ def async_all_to_all(input_, # multi stream wait event global COMM_STREAM if COMM_STREAM is None: - COMM_STREAM = torch_npu.npu.Stream( - device=torch.npu.current_device()) + COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device()) with torch_npu.npu.stream(COMM_STREAM): event.wait() handle = dist.all_to_all_single( @@ -53,14 +48,17 @@ def async_all_to_all(input_, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group, - async_op=True) + async_op=True, + ) else: - handle = dist.all_to_all_single(a2a_out, - input_.contiguous(), - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=True) + handle = dist.all_to_all_single( + a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True, + ) return input_, a2a_out, handle @@ -86,19 +84,12 @@ def _gather_along_first_dim(input_, group, output_split_sizes=None): if output_split_sizes is None: dim_size[0] = dim_size[0] * world_size - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - torch.distributed.all_gather_into_tensor(output, - input_.contiguous(), - group=group) + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, input_.contiguous(), group=group) else: dim_size[0] = sum(output_split_sizes) - output = torch.empty(dim_size, - dtype=input_.dtype, - device=torch.npu.current_device()) - output_tensor_list = list( - torch.split(output, output_split_sizes, dim=0)) + output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device()) + output_tensor_list = list(torch.split(output, output_split_sizes, dim=0)) torch.distributed.all_gather(output_tensor_list, input_, group=group) return output @@ -110,4 +101,4 @@ def gather_from_sequence_parallel_region( output_split_sizes=None, ): """Wrapper for autograd function: forward: AG, backward: RS """ - return _gather_along_first_dim(input_, group, output_split_sizes) \ No newline at end of file + return _gather_along_first_dim(input_, group, output_split_sizes) diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index 07b611e7..d9775fbe 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -14,26 +14,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Callable, Optional +from collections.abc import Callable import torch from vllm_ascend.utils import get_weight_prefetch_method -def select_experts(hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor=1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - indices_type: Optional[torch.dtype] = None, - global_num_experts: int = -1): +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor=1.0, + e_score_correction_bias: torch.Tensor | None = None, + indices_type: torch.dtype | None = None, + global_num_experts: int = -1, +): """ Fused experts with select experts. @@ -58,8 +60,7 @@ def select_experts(hidden_states: torch.Tensor, # prefetch w1_w3_proj.weight preprocess weight_prefetch_method = get_weight_prefetch_method() if weight_prefetch_method: - weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( - hidden_states, "gate_up") + weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up") is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k( hidden_states=hidden_states, top_k=top_k, @@ -67,7 +68,8 @@ def select_experts(hidden_states: torch.Tensor, topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=scoring_func, - custom_routing_function=custom_routing_function) + custom_routing_function=custom_routing_function, + ) if is_support_npu_moe_gating_top_k: topk_weights, topk_ids = _select_experts_with_fusion_ops( @@ -81,7 +83,8 @@ def select_experts(hidden_states: torch.Tensor, num_expert_group=num_expert_group, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, - global_num_experts=global_num_experts) + global_num_experts=global_num_experts, + ) else: topk_weights, topk_ids = _native_select_experts( hidden_states=hidden_states, @@ -100,14 +103,15 @@ def select_experts(hidden_states: torch.Tensor, def check_npu_moe_gating_top_k( - hidden_states: torch.Tensor, - top_k: int, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - scoring_func: str = "softmax", - custom_routing_function: Optional[Callable] = None): - if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch + hidden_states: torch.Tensor, + top_k: int, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + scoring_func: str = "softmax", + custom_routing_function: Callable | None = None, +): + if scoring_func == "sigmoid" and not renormalize: # sigmoid + renorm=0 is not supported in current branch return False if custom_routing_function is not None: return False @@ -115,39 +119,39 @@ def check_npu_moe_gating_top_k( return False topk_group = topk_group if topk_group is not None else 1 num_expert_group = num_expert_group if num_expert_group is not None else 1 - if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group - == 0 and hidden_states.shape[-1] // num_expert_group > 2): + if not ( + num_expert_group > 0 + and hidden_states.shape[-1] % num_expert_group == 0 + and hidden_states.shape[-1] // num_expert_group > 2 + ): return False if topk_group < 1 or topk_group > num_expert_group: return False - if top_k < 1 or \ - top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)): + if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)): return False - if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: + if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: # noqa: SIM103 return False return True def _native_grouped_topk( topk_weights: torch.Tensor, - num_expert_group: Optional[int], - topk_group: Optional[int], + num_expert_group: int | None, + topk_group: int | None, ): topk_group = 0 if topk_group is None else topk_group num_expert_group = 0 if num_expert_group is None else num_expert_group num_token = topk_weights.shape[0] - grouped_weights = topk_weights.view(num_token, num_expert_group, - -1).max(dim=-1).values - topk_group_indices = torch.topk(grouped_weights.to(torch.float32), - k=topk_group, - dim=-1, - sorted=False)[1] + grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1] topk_group_mask = torch.zeros_like(grouped_weights) topk_group_mask.scatter_(1, topk_group_indices, 1) - topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weight_mask = ( + topk_group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) return topk_weights @@ -163,9 +167,13 @@ def _renormalize_topk_weights( def _select_expert_use_group_topk( - topk_weights: torch.Tensor, topk_group: Optional[int], - renormalize: bool, top_k: int, num_expert_group: Optional[int], - e_score_correction_bias: Optional[torch.Tensor]): + topk_weights: torch.Tensor, + topk_group: int | None, + renormalize: bool, + top_k: int, + num_expert_group: int | None, + e_score_correction_bias: torch.Tensor | None, +): assert topk_group is not None assert num_expert_group is not None @@ -177,47 +185,38 @@ def _select_expert_use_group_topk( # TODO: Change to npu_group_topk when the latest CANN and NNAL is available # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) - topk_weights = _native_grouped_topk(topk_weights, num_expert_group, - topk_group) + topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group) # TODO bfloat16 is not supported in torch.topk with ge graph. if e_score_correction_bias is not None: - topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False)[1] + topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1] # Use original unbiased scores for the routing weights topk_weights = original_weights.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False) topk_ids = topk_ids.to(torch.int32) topk_weights = _renormalize_topk_weights(topk_weights, renormalize) return topk_weights, topk_ids def _select_experts_with_fusion_ops( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - e_score_correction_bias: Optional[torch.Tensor], - topk_group: Optional[int], - num_expert_group: Optional[int], - scoring_func: str = "softmax", - routed_scaling_factor=1.0, - global_num_experts: int = -1): - + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + e_score_correction_bias: torch.Tensor | None, + topk_group: int | None, + num_expert_group: int | None, + scoring_func: str = "softmax", + routed_scaling_factor=1.0, + global_num_experts: int = -1, +): topk_group = topk_group if topk_group is not None else 1 num_expert_group = num_expert_group if num_expert_group is not None else 1 renorm = int(renormalize) norm_type = 0 if scoring_func == "softmax" else 1 - if e_score_correction_bias is not None and \ - e_score_correction_bias.dtype != router_logits.dtype: - e_score_correction_bias = e_score_correction_bias.to( - router_logits.dtype) + if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype: + e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype) topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k( router_logits, k=top_k, @@ -228,7 +227,7 @@ def _select_experts_with_fusion_ops( norm_type=norm_type, # 0: softmax; 1: sigmoid out_flag=False, routed_scaling_factor=routed_scaling_factor, - eps=float(1e-20), + eps=1e-20, bias_opt=e_score_correction_bias, ) @@ -241,12 +240,12 @@ def _native_select_experts( top_k: int, use_grouped_topk: bool, renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: Optional[torch.Tensor] = None + e_score_correction_bias: torch.Tensor | None = None, + global_num_experts: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Select top-k experts based on router logits. @@ -285,7 +284,8 @@ def _native_select_experts( renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + ) if custom_routing_function is not None: topk_weights, topk_ids = custom_routing_function( @@ -293,7 +293,8 @@ def _native_select_experts( gating_output=router_logits, topk=top_k, renormalize=renormalize, - global_num_experts=global_num_experts) + global_num_experts=global_num_experts, + ) # Required by npu_moe_init_routing topk_ids = topk_ids.to(torch.int32) return topk_weights, topk_ids @@ -318,8 +319,7 @@ def zero_experts_compute( if zero_expert_type == "identity": zero_expert_mask = expert_indices < num_experts zero_expert_scales = expert_scales.clone() - zero_expert_scales = torch.where(zero_expert_mask, 0.0, - zero_expert_scales) + zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales) hidden_states = hidden_states.unsqueeze(1) zero_expert_scales = zero_expert_scales.unsqueeze(2) diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index f67cc1a3..a53a368e 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -14,40 +14,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from collections.abc import Callable from dataclasses import dataclass, field from functools import wraps -from typing import Callable, Optional import torch import torch.nn.functional as F from vllm.config import get_current_vllm_config -from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, - tensor_model_parallel_all_reduce) +from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce from vllm.forward_context import get_forward_context from vllm.logger import logger from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map) -from vllm.model_executor.layers.fused_moe.shared_fused_moe import \ - SharedFusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.eplb.core.eplb_utils import init_eplb_config -from vllm_ascend.flash_common3_context import (get_flash_common3_context, - set_flash_common3_context) -from vllm_ascend.ops.fused_moe.experts_selector import (select_experts, - zero_experts_compute) -from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, - FusedExpertsResult, - setup_moe_comm_method) +from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context +from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute +from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType -from vllm_ascend.utils import (AscendDeviceType, enable_sp, - get_ascend_device_type, maybe_trans_nz, - npu_stream_switch, shared_expert_dp_enabled, - shared_experts_calculation_stream, - vllm_version_is) +from vllm_ascend.utils import ( + enable_sp, + maybe_trans_nz, + npu_stream_switch, + shared_expert_dp_enabled, + shared_experts_calculation_stream, + vllm_version_is, +) + @dataclass class FusedMoEResult: @@ -64,46 +61,43 @@ class FusedMoEEvents: class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): - def __init__(self, moe: FusedMoEConfig = None): - super().__init__(moe=moe) self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb def process_weights_after_loading(self, layer): - super(UnquantizedFusedMoEMethod, - self).process_weights_after_loading(layer) + super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) - w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( - 1, 2).contiguous() + w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous() layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) - w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( - 1, 2).contiguous() + w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous() layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - enable_force_load_balance: bool = False, - log2phy: torch.Tensor = None, - **kwargs) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) topk_weights, topk_ids = select_experts( @@ -118,7 +112,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) + global_num_experts=global_num_experts, + ) if zero_expert_num > 0 and zero_expert_type is not None: topk_ids, topk_weights, zero_expert_result = zero_experts_compute( @@ -134,11 +129,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. if enable_force_load_balance: - random_matrix = torch.rand(topk_ids.size(0), - global_num_experts, - device=topk_ids.device) - topk_ids = torch.argsort( - random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype) + random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device) + topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) moe_comm_method = get_forward_context().moe_comm_method final_hidden_states = moe_comm_method.fused_experts( @@ -151,7 +143,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): apply_router_weight_on_input=apply_router_weight_on_input, dynamic_eplb=self.dynamic_eplb, log2phy=log2phy, - mc2_mask=kwargs.get("mc2_mask", None)) + mc2_mask=kwargs.get("mc2_mask"), + ) if zero_expert_num > 0 and zero_expert_type is not None: final_hidden_states += zero_expert_result return final_hidden_states @@ -159,7 +152,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): class AscendFusedMoE(FusedMoE): moe_counter = -1 - gate_stream: Optional[torch.npu.Stream] = None + gate_stream: torch.npu.Stream | None = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -174,11 +167,9 @@ class AscendFusedMoE(FusedMoE): self.log2phy = None if self.quant_config is None: - self.quant_method = AscendUnquantizedFusedMoEMethod( - self.moe_config) + self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe_config) else: - self.quant_method = self.quant_config.get_quant_method( - self, self.layer_name) + self.quant_method = self.quant_config.get_quant_method(self, self.layer_name) assert self.quant_method is not None @@ -195,28 +186,32 @@ class AscendFusedMoE(FusedMoE): if self.custom_routing_function is None and self.e_score_correction_bias is not None: vllm_config = get_current_vllm_config() self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( - dtype=vllm_config.model_config.dtype) + dtype=vllm_config.model_config.dtype + ) # init moe eplb_config = ascend_config.eplb_config self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config( - eplb_config, self.moe_instance_id, self.moe_config) + eplb_config, self.moe_instance_id, self.moe_config + ) self.global_num_experts = num_experts + self.global_redundant_expert_num - self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy - is not None) - self.local_num_experts = (torch.sum( - self._expert_map != -1).item() if self._expert_map is not None else - self.global_num_experts) + self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy is not None) + self.local_num_experts = ( + torch.sum(self._expert_map != -1).item() if self._expert_map is not None else self.global_num_experts + ) if self._expert_map is not None: logger.info_once( "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" " number of experts: %s/%s. Experts local to global index map:" - " %s.", self.ep_rank, self.ep_size, self.local_num_experts, + " %s.", + self.ep_rank, + self.ep_size, + self.local_num_experts, self.global_num_experts, - get_compressed_expert_map(self._expert_map)) + get_compressed_expert_map(self._expert_map), + ) if self.dynamic_eplb: - self.moe_load = torch.zeros(self.local_num_experts, - dtype=torch.int64).npu() + self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu() self.moe_config.num_experts = self.global_num_experts self.moe_config.num_local_experts = self.local_num_experts @@ -225,14 +220,12 @@ class AscendFusedMoE(FusedMoE): moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": self.hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, + "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": self.params_dtype, "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + if self.quant_method.__class__.__name__ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod"): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) @@ -243,15 +236,14 @@ class AscendFusedMoE(FusedMoE): def _get_quant_type(self) -> QuantType: quant_method = self.quant_method - if not hasattr(quant_method, - "quant_method") or quant_method.quant_method is None: + if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None: return QuantType.NONE method = quant_method.quant_method if hasattr(method, "quant_type"): - from vllm_ascend.quantization.methods.base import \ - QuantType as SchemeQuantType + from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType + scheme_quant_type = method.quant_type if scheme_quant_type == SchemeQuantType.W8A8: return QuantType.W8A8 @@ -270,22 +262,18 @@ class AscendFusedMoE(FusedMoE): if self.moe_load is not None: self.moe_load.zero_() - def maybe_all_reduce_tensor_model_parallel( - self, final_hidden_states: torch.Tensor): + def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`, and `alltoallcommimpl`, we do not need to all-reduce the final outputs since the outputs are already aggregated across tensor parallel ranks in the `finalize` function. In `allgathercommimpl`, we still need to all-reduce the outputs since each rank only has partial outputs. """ - return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( - final_hidden_states) + return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states) def forward_impl( # type: ignore[override] - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - return_with_event: bool = False) -> torch.Tensor | FusedMoEResult: + self, hidden_states: torch.Tensor, router_logits: torch.Tensor, return_with_event: bool = False + ) -> torch.Tensor | FusedMoEResult: assert self.quant_method is not None forward_context = get_forward_context() @@ -301,15 +289,16 @@ class AscendFusedMoE(FusedMoE): fc3_context = get_flash_common3_context() assert fc3_context is not None AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream()) - with npu_stream_switch(AscendFusedMoE.gate_stream, - enabled=self.multistream_overlap_gate): + with npu_stream_switch(AscendFusedMoE.gate_stream, enabled=self.multistream_overlap_gate): # share_expert assert fc3_context.shared_experts is not None shared_out = fc3_context.shared_experts(hidden_states) # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ - and not shared_expert_dp_enabled(): + if ( + moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} + and not shared_expert_dp_enabled() + ): shared_out = tensor_model_parallel_all_reduce(shared_out) set_flash_common3_context(shared_out=shared_out) @@ -325,24 +314,22 @@ class AscendFusedMoE(FusedMoE): scoring_func=self.scoring_func, routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, - global_num_experts=self.global_num_experts) + global_num_experts=self.global_num_experts, + ) - if isinstance(forward_context.moe_comm_method, - AllGatherCommImpl): - topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - topk_weights, True, True) - topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - topk_ids, True, True) + if isinstance(forward_context.moe_comm_method, AllGatherCommImpl): + topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True) + topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True) - set_flash_common3_context(topk_weights=topk_weights, - topk_ids=topk_ids) + set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids) hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, replace_allreduce=forward_context.sp_enabled, enable_shared_expert_dp=self.enable_shared_expert_dp, - quant_type=self.quant_type) + quant_type=self.quant_type, + ) # Make sure the default stream waits for the gate stream to finish. if self.multistream_overlap_gate: @@ -375,39 +362,45 @@ class AscendFusedMoE(FusedMoE): enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, - mc2_mask=mc2_mask) + mc2_mask=mc2_mask, + ) if self.dynamic_eplb: expert_tokens = fused_experts_results.expert_tokens group_list_type = fused_experts_results.group_list_type - assert expert_tokens is not None and group_list_type is not None, \ + assert expert_tokens is not None and group_list_type is not None, ( "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled." - local_load = expert_tokens if group_list_type == 1 else \ - torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + ) + local_load = ( + expert_tokens + if group_list_type == 1 + else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + ) self.moe_load.add_(local_load) routed_out = forward_context.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, - context_metadata=context_metadata) + context_metadata=context_metadata, + ) if return_with_event: return FusedMoEResult( routed_out=routed_out, before_dispatch_evt=fused_experts_results.before_dispatch_evt, - before_combine_evt=fused_experts_results.before_combine_evt) + before_combine_evt=fused_experts_results.before_combine_evt, + ) else: # The vLLM FusedMoE forward_impl does not return events. return routed_out class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): - def __init__( self, shared_experts: torch.nn.Module, - gate: Optional[torch.nn.Module] = None, + gate: torch.nn.Module | None = None, use_overlapped: bool = True, - routed_input_transform: Optional[torch.nn.Module] = None, + routed_input_transform: torch.nn.Module | None = None, **kwargs, ): AscendFusedMoE.__init__(self, **kwargs) @@ -418,16 +411,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): self.use_overlapped = use_overlapped self.shared_expert_stream = None ascend_config = get_ascend_config() - self.multistream_overlap_shared_expert = \ - ascend_config.multistream_overlap_shared_expert and \ - self._shared_experts is not None - self.multistream_overlap_gate = \ - ascend_config.multistream_overlap_gate and \ - self._shared_experts is not None + self.multistream_overlap_shared_expert = ( + ascend_config.multistream_overlap_shared_expert and self._shared_experts is not None + ) + self.multistream_overlap_gate = ascend_config.multistream_overlap_gate and self._shared_experts is not None if enable_sp(): - logger.info_once( - "Sequence parallelism is enabled, shared experts are replicated for best performance." - ) + logger.info_once("Sequence parallelism is enabled, shared experts are replicated for best performance.") self._gate = gate @@ -447,20 +436,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore def _shared_experts_part1(self, hidden_states: torch.Tensor): - shared_gate_up, _ = self._shared_experts.gate_up_proj( - hidden_states) # type: ignore + shared_gate_up, _ = self._shared_experts.gate_up_proj(hidden_states) # type: ignore return shared_gate_up - def _shared_experts_part2(self, hidden_states: torch.Tensor, - shared_gate_up: torch.Tensor): - shared_act = self._shared_experts.act_fn( - shared_gate_up) # type: ignore - shared_out, _ = self._shared_experts.down_proj( - shared_act) # type: ignore + def _shared_experts_part2(self, hidden_states: torch.Tensor, shared_gate_up: torch.Tensor): + shared_act = self._shared_experts.act_fn(shared_gate_up) # type: ignore + shared_out, _ = self._shared_experts.down_proj(shared_act) # type: ignore # Qwen3-Next specific gating mechanism - if hasattr(self._shared_experts, "expert_gate") and \ - self._shared_experts.expert_gate is not None: + if hasattr(self._shared_experts, "expert_gate") and self._shared_experts.expert_gate is not None: gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore shared_out = F.sigmoid(gate_out) * shared_out return shared_out @@ -468,9 +452,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): def _validate_shared_expert_consistency(self): """Validate that split shared expert computation matches integrated computation.""" - test_input = torch.rand( - 10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype - ) * 2 - 1 # Random input for testing, scoped to [-1, 1] + test_input = ( + torch.rand(10, self.hidden_size, device="npu", dtype=self.moe_config.in_dtype) * 2 - 1 + ) # Random input for testing, scoped to [-1, 1] integrated_out = self._shared_experts(test_input) part1_out = self._shared_experts_part1(test_input) @@ -478,25 +462,19 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): if not torch.allclose(integrated_out, split_out): diff = (integrated_out - split_out).abs() - logger.error( - "SharedFusedMoE shared experts split computation does not " - "match the integrated computation.") + logger.error("SharedFusedMoE shared experts split computation does not match the integrated computation.") logger.error(f"Max absolute difference: {diff.max().item()}") - logger.error("Integrated output - sum: %s, norm: %s", - integrated_out.sum().item(), - integrated_out.norm().item()) - logger.error("Split output - sum: %s, norm: %s", - split_out.sum().item(), - split_out.norm().item()) + logger.error( + "Integrated output - sum: %s, norm: %s", integrated_out.sum().item(), integrated_out.norm().item() + ) + logger.error("Split output - sum: %s, norm: %s", split_out.sum().item(), split_out.norm().item()) raise ValueError( - "SharedFusedMoE shared experts split computation does not " - "match the integrated computation.") - logger.info_once( - "SharedFusedMoE shared experts split computation matches the " - "integrated computation.") + "SharedFusedMoE shared experts split computation does not match the integrated computation." + ) + logger.info_once("SharedFusedMoE shared experts split computation matches the integrated computation.") @property - def gate(self) -> Optional[torch.nn.Module]: + def gate(self) -> torch.nn.Module | None: return self._gate if self.use_overlapped else None @property @@ -530,8 +508,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): ) return shared_out, fused_out - def _forward_shared_experts(self, hidden_states: torch.Tensor, - fused_moe_evts: FusedMoEEvents): + def _forward_shared_experts(self, hidden_states: torch.Tensor, fused_moe_evts: FusedMoEEvents): if self._shared_experts is None: return None @@ -539,11 +516,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): if evt is not None: torch.npu.current_stream().wait_event(evt) - with npu_stream_switch(shared_experts_calculation_stream(), - enabled=self.multistream_overlap_shared_expert): + with npu_stream_switch(shared_experts_calculation_stream(), enabled=self.multistream_overlap_shared_expert): # Ensure the shared experts wait for hidden_states to be ready. - torch.npu.current_stream().wait_event( - fused_moe_evts.before_routed_experts) + torch.npu.current_stream().wait_event(fused_moe_evts.before_routed_experts) # Execute the gate projection and activation concurrently with the # dispatch communication. maybe_wait_event(fused_moe_evts.before_dispatch) @@ -556,20 +531,22 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): # Make sure the default stream waits for the shared experts stream to # finish. if self.multistream_overlap_shared_expert: - torch.npu.current_stream().wait_stream( - shared_experts_calculation_stream()) + torch.npu.current_stream().wait_stream(shared_experts_calculation_stream()) # NOTE: This is exactly the opposite of # `maybe_all_reduce_tensor_model_parallel` forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ - and not shared_expert_dp_enabled(): + if ( + moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} + and not shared_expert_dp_enabled() + ): shared_out = tensor_model_parallel_all_reduce(shared_out) return shared_out def forward_impl( # type: ignore[override] - self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + self, hidden_states: torch.Tensor, router_logits: torch.Tensor + ): if self.multistream_overlap_gate: set_flash_common3_context(shared_experts=self._shared_experts) @@ -596,6 +573,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): before_routed_experts=before_routed_experts, before_dispatch=fused_moe_results.before_dispatch_evt, before_combine=fused_moe_results.before_combine_evt, - )) + ), + ) return shared_out, routed_out diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index d135968c..e1c31520 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -17,7 +17,6 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Optional import torch from vllm.forward_context import get_forward_context @@ -27,18 +26,24 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.prepare_finalize import ( - PrepareAndFinalize, PrepareAndFinalizeWithAll2All, - PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType) + PrepareAndFinalize, + PrepareAndFinalizeWithAll2All, + PrepareAndFinalizeWithAllGather, + PrepareAndFinalizeWithMC2, + QuantType, +) from vllm_ascend.ops.fused_moe.token_dispatcher import ( - MoETokenDispatcher, TokenDispatcherWithAll2AllV, - TokenDispatcherWithAllGather, TokenDispatcherWithMC2) + MoETokenDispatcher, + TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, + TokenDispatcherWithMC2, +) -_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} +_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {} -def get_moe_comm_method( - moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]: - return _MoECommMethods.get(moe_comm_type, None) +def get_moe_comm_method(moe_comm_type: MoECommType | None) -> MoECommMethod | None: + return _MoECommMethods.get(moe_comm_type) def setup_moe_comm_method(moe_config): @@ -50,6 +55,7 @@ def setup_moe_comm_method(moe_config): def set_gmmswigluquant_method(): from vllm_ascend.ascend_config import get_ascend_config + ascend_config = get_ascend_config() return ascend_config.ascend_fusion_config.fusion_ops_gmmswigluquant @@ -84,51 +90,46 @@ class MoECommMethod(ABC): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type: QuantType = QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare( - hidden_states, router_logits, enable_shared_expert_dp, - replace_allreduce, quant_type) + hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type + ) return hidden_states, router_logits, mc2_mask, context_metadata - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: - hidden_states = self.prepare_finalize.finalize(hidden_states, - reduce_results, - context_metadata) + def finalize( + self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + ) -> torch.Tensor: + hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata) return hidden_states def fused_experts( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor | list[torch.Tensor], - w2: torch.Tensor | list[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - use_int4_w4a16: bool = False, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[list[torch.Tensor]] = None, - w2_scale: Optional[list[torch.Tensor]] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: Optional[torch.Tensor] = None, - w2_offset: Optional[torch.Tensor] = None, - # For load balance - log2phy: torch.Tensor = None, - need_trans: bool = False, - dynamic_eplb: bool = False, - mc2_mask: torch.Tensor = None, - pertoken_scale: Optional[torch.Tensor] = None): + self, + hidden_states: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + use_int4_w4a16: bool = False, + expert_map: torch.Tensor | None = None, + w1_scale: list[torch.Tensor] | None = None, + w2_scale: list[torch.Tensor] | None = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + w1_offset: torch.Tensor | None = None, + w2_offset: torch.Tensor | None = None, + # For load balance + log2phy: torch.Tensor = None, + need_trans: bool = False, + dynamic_eplb: bool = False, + mc2_mask: torch.Tensor = None, + pertoken_scale: torch.Tensor | None = None, + ): # Check constraints - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16, torch.int8 - ] + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] moe_comm_method = get_forward_context().moe_comm_method assert moe_comm_method is not None, "Missing communication context" @@ -143,13 +144,13 @@ class MoECommMethod(ABC): topk_weights=topk_weights, topk_ids=topk_ids, expert_map=expert_map, - global_redundant_expert_num=self.moe_config. - global_redundant_expert_num, + global_redundant_expert_num=self.moe_config.global_redundant_expert_num, mc2_mask=mc2_mask, apply_router_weight_on_input=apply_router_weight_on_input, with_quant=use_int8_w8a8 or use_int4_w4a8, dynamic_eplb=dynamic_eplb, - pertoken_scale=pertoken_scale) + pertoken_scale=pertoken_scale, + ) mlp_output = unified_apply_mlp( hidden_states=dispatch_results.hidden_states, @@ -168,29 +169,29 @@ class MoECommMethod(ABC): with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16, fusion=use_int8_w8a8 and self.use_fusion_ops, need_trans=need_trans, - dynamic_eplb=dynamic_eplb) + dynamic_eplb=dynamic_eplb, + ) before_combine_evt = torch.npu.current_stream().record_event() combine_results = self.token_dispatcher.token_combine( - hidden_states=mlp_output, - context_metadata=dispatch_results.context_metadata) + hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata + ) return FusedExpertsResult( routed_out=combine_results.routed_out, before_dispatch_evt=before_dispatch_evt, before_combine_evt=before_combine_evt, group_list_type=dispatch_results.group_list_type, - expert_tokens=dispatch_results.group_list) + expert_tokens=dispatch_results.group_list, + ) @abstractmethod def _get_token_dispatcher(self) -> MoETokenDispatcher: - raise NotImplementedError( - "_get_token_dispatcher function not implemented.") + raise NotImplementedError("_get_token_dispatcher function not implemented.") @abstractmethod def _get_prepare_finalize(self) -> PrepareAndFinalize: - raise NotImplementedError( - "_get_prepare_finalize function not implemented.") + raise NotImplementedError("_get_prepare_finalize function not implemented.") class AllGatherCommImpl(MoECommMethod): @@ -216,7 +217,8 @@ class AllGatherCommImpl(MoECommMethod): return TokenDispatcherWithAllGather( top_k=self.moe_config.experts_per_token, num_experts=self.moe_config.num_experts, - num_local_experts=self.moe_config.num_local_experts) + num_local_experts=self.moe_config.num_local_experts, + ) def _get_prepare_finalize(self): return PrepareAndFinalizeWithAllGather(self.moe_config) @@ -227,7 +229,7 @@ class MC2CommImpl(MoECommMethod): 1. `enable_expert_parallel=True`. 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. 3. `enable_expert_parallel=False` is not supported. - + This implementation uses the MC2 communication method, which is optimized for Communication and Computation parallelism on Ascend devices. """ @@ -253,7 +255,8 @@ class AlltoAllCommImpl(MoECommMethod): return TokenDispatcherWithAll2AllV( top_k=self.moe_config.experts_per_token, num_experts=self.moe_config.num_experts, - num_local_experts=self.moe_config.num_local_experts) + num_local_experts=self.moe_config.num_local_experts, + ) def _get_prepare_finalize(self): return PrepareAndFinalizeWithAll2All(self.moe_config) @@ -264,7 +267,7 @@ class FusedMC2CommImpl(MoECommMethod): 1. `enable_expert_parallel=True`. 2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available. 3. `enable_expert_parallel=False` is not supported. - + This implementation uses the MC2 communication method, which is optimized for Communication and Computation parallelism on Ascend devices. """ @@ -276,36 +279,36 @@ class FusedMC2CommImpl(MoECommMethod): return PrepareAndFinalizeWithMC2(self.moe_config) def fused_experts( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor | list[torch.Tensor], - w2: torch.Tensor | list[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - use_int4_w4a16: bool = False, - expert_map: Optional[torch.Tensor] = None, - w1_scale: Optional[list[torch.Tensor]] = None, - w2_scale: Optional[list[torch.Tensor]] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: Optional[torch.Tensor] = None, - w2_offset: Optional[torch.Tensor] = None, - # For load balance - log2phy: torch.Tensor = None, - need_trans: bool = False, - dynamic_eplb: bool = False, - mc2_mask: torch.Tensor = None, - pertoken_scale: Optional[torch.Tensor] = None): - assert not ( - w1_scale is None or w2_scale is None - ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." + self, + hidden_states: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_int8_w8a8: bool = False, + use_int4_w4a8: bool = False, + use_int4_w4a16: bool = False, + expert_map: torch.Tensor | None = None, + w1_scale: list[torch.Tensor] | None = None, + w2_scale: list[torch.Tensor] | None = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + w1_offset: torch.Tensor | None = None, + w2_offset: torch.Tensor | None = None, + # For load balance + log2phy: torch.Tensor = None, + need_trans: bool = False, + dynamic_eplb: bool = False, + mc2_mask: torch.Tensor = None, + pertoken_scale: torch.Tensor | None = None, + ): + assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." - assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ + assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), ( "token_dispatcher must be an instance of TokenDispatcherWithMC2." + ) # Apply log2phy if needed if log2phy is not None: @@ -346,10 +349,8 @@ class FusedMC2CommImpl(MoECommMethod): ep_rank_size=self.token_dispatcher.ep_world_size, ep_rank_id=self.token_dispatcher.ep_rank_id, moe_expert_num=self.moe_config.num_experts, - global_bs=self.token_dispatcher.global_bs) + global_bs=self.token_dispatcher.global_bs, + ) else: - raise ValueError( - f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") - return FusedExpertsResult(routed_out=out, - group_list_type=group_list_type, - expert_tokens=expert_tokens) + raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") + return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens) diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index e29945ea..65673b71 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -14,7 +14,6 @@ # limitations under the License. # This file is a part of the vllm-ascend project. -from typing import Optional import torch import torch_npu @@ -23,24 +22,22 @@ from vllm.forward_context import get_forward_context from vllm.triton_utils import HAS_TRITON from vllm_ascend.ascend_forward_context import MoECommType -from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, - enable_custom_op, get_ascend_device_type, - get_weight_prefetch_method) +from vllm_ascend.utils import ( + dispose_tensor, + enable_custom_op, + get_weight_prefetch_method, +) def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): return fusion and dynamic_eplb and enable_custom_op() -def cumsum_group_list(group_list: torch.Tensor, - src_list_type: int, - dst_list_type: int, - active_num: int = 0, - expert_num: int = 0) -> torch.Tensor: +def cumsum_group_list( + group_list: torch.Tensor, src_list_type: int, dst_list_type: int, active_num: int = 0, expert_num: int = 0 +) -> torch.Tensor: if src_list_type not in [0, 1, 2]: - raise ValueError( - f"group_list_type should be in [0, 1, 2], but received {src_list_type}" - ) + raise ValueError(f"group_list_type should be in [0, 1, 2], but received {src_list_type}") if src_list_type == dst_list_type: return group_list @@ -53,10 +50,9 @@ def cumsum_group_list(group_list: torch.Tensor, if src_list_type == 2 and dst_list_type == 0: experts = pad(group_list[:, 0], (1, 0)) tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) - cumsum_group_list = torch.full(size=(expert_num, ), - fill_value=active_num, - dtype=group_list.dtype, - device=group_list.device) + cumsum_group_list = torch.full( + size=(expert_num,), fill_value=active_num, dtype=group_list.dtype, device=group_list.device + ) for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): if end > start: @@ -65,30 +61,32 @@ def cumsum_group_list(group_list: torch.Tensor, return cumsum_group_list raise NotImplementedError( f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. " - "This feature is under development.") + "This feature is under development." + ) -def quant_apply_mlp(hidden_states: torch.Tensor, - w1: list[torch.Tensor], - w1_scale: list[torch.Tensor], - w2: list[torch.Tensor], - w2_scale: list[torch.Tensor], - group_list: torch.Tensor, - group_list_type: int = 1, - dynamic_scale: torch.Tensor = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: Optional[torch.Tensor] = None, - w2_offset: Optional[torch.Tensor] = None, - fusion: bool = False, - dynamic_eplb: bool = False) -> torch.Tensor: +def quant_apply_mlp( + hidden_states: torch.Tensor, + w1: list[torch.Tensor], + w1_scale: list[torch.Tensor], + w2: list[torch.Tensor], + w2_scale: list[torch.Tensor], + group_list: torch.Tensor, + group_list_type: int = 1, + dynamic_scale: torch.Tensor = None, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + w1_offset: torch.Tensor | None = None, + w2_offset: torch.Tensor | None = None, + fusion: bool = False, + dynamic_eplb: bool = False, +) -> torch.Tensor: if w1_offset is not None: unquantized_hidden_states = hidden_states quantized_hidden_states = None elif dynamic_scale is None: unquantized_hidden_states = hidden_states - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) # Dispose the original unquantized hidden states # to save npu memory because they're no longer used. dispose_tensor(unquantized_hidden_states) @@ -103,22 +101,18 @@ def quant_apply_mlp(hidden_states: torch.Tensor, weight_prefetch_method = get_weight_prefetch_method() if weight_prefetch_method: - weight_prefetch_method.maybe_prefetch_moe_weight_postprocess( - hidden_states) + weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 if w1_scale_bias is None and w1_offset is None and is_mc2: if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): # gmm1: gate_up_proj & act_fn: swiglu - hidden_states, swiglu_out_scale, _ = ( - torch.ops._C_ascend. - grouped_matmul_swiglu_quant_weight_nz_tensor_list( - x=hidden_states, - weight=w1, - weight_scale=w1_scale, - x_scale=pertoken_scale, - group_list=cumsum_group_list(group_list, group_list_type, - 0), - )) + hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type, 0), + ) elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( @@ -126,7 +120,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, weight=w1[0], group_list=cumsum_group_list(group_list, group_list_type, 0), weight_scale=w1_scale[0], - x_scale=pertoken_scale) + x_scale=pertoken_scale, + ) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: @@ -140,7 +135,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=torch.int32)[0] + output_dtype=torch.int32, + )[0] if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) # act_fn: swiglu @@ -165,7 +161,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=w2_scale[0].dtype)[0] + output_dtype=w2_scale[0].dtype, + )[0] elif w1_offset is not None: # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( @@ -177,7 +174,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=_output_dtype)[0] + output_dtype=_output_dtype, + )[0] dispose_tensor(unquantized_hidden_states) # act_fn: swiglu hidden_states = torch_npu.npu_swiglu(hidden_states) @@ -191,13 +189,12 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=_output_dtype)[0] + output_dtype=_output_dtype, + )[0] else: if w1_scale_bias is not None: if group_list_type == 0: - group_list = torch.cat( - [group_list[:1], - torch.diff(group_list, dim=0)]) + group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)]) group_list_type = 1 bias1 = [w1_scale_bias] if not fusion else w1_scale_bias bias2 = [w2_scale_bias] @@ -206,17 +203,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor, if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): # gmm1: gate_up_proj & act_fn: swiglu - hidden_states, swiglu_out_scale, _ = ( - torch.ops._C_ascend. - grouped_matmul_swiglu_quant_weight_nz_tensor_list( - x=hidden_states, - weight=w1, - weight_scale=w1_scale, - x_scale=pertoken_scale, - group_list=cumsum_group_list(group_list, group_list_type, - 0), - bias=bias1, - )) + hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list( + x=hidden_states, + weight=w1, + weight_scale=w1_scale, + x_scale=pertoken_scale, + group_list=cumsum_group_list(group_list, group_list_type, 0), + bias=bias1, + ) elif fusion and not dynamic_eplb: # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( @@ -225,7 +219,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor, bias=bias1, group_list=cumsum_group_list(group_list, group_list_type, 0), weight_scale=w1_scale[0], - x_scale=pertoken_scale) + x_scale=pertoken_scale, + ) if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) else: @@ -241,21 +236,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=_output_dtype)[0] + output_dtype=_output_dtype, + )[0] if quantized_hidden_states is not None: dispose_tensor(quantized_hidden_states) # act_fn: swiglu if HAS_TRITON: - from vllm_ascend.ops.triton.activation.swiglu_quant import \ - swiglu_quant + from vllm_ascend.ops.triton.activation.swiglu_quant import swiglu_quant + hidden_states, swiglu_out_scale = swiglu_quant( - hidden_states, - group_list=group_list, - group_list_type=group_list_type) + hidden_states, group_list=group_list, group_list_type=group_list_type + ) else: hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( - hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states) # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -267,18 +261,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, group_type=0, group_list=group_list, - output_dtype=_output_dtype)[0] + output_dtype=_output_dtype, + )[0] return hidden_states -def unquant_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1, - topk_scales: Optional[torch.Tensor] = None, - need_trans: bool = True) -> torch.Tensor: - +def unquant_apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + topk_scales: torch.Tensor | None = None, + need_trans: bool = True, +) -> torch.Tensor: if need_trans: w1 = w1.transpose(1, 2) w2 = w2.transpose(1, 2) @@ -307,44 +303,50 @@ def unquant_apply_mlp(hidden_states: torch.Tensor, return hidden_states -def unified_apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor | list[torch.Tensor], - w2: torch.Tensor | list[torch.Tensor], - group_list: torch.Tensor, - w1_scale: Optional[list[torch.Tensor]] = None, - w2_scale: Optional[list[torch.Tensor]] = None, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: Optional[torch.Tensor] = None, - w2_offset: Optional[torch.Tensor] = None, - topk_scales: Optional[torch.Tensor] = None, - with_quant: bool = False, - fusion: bool = False, - need_trans: bool = True, - dynamic_eplb: bool = False) -> torch.Tensor: +def unified_apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], + group_list: torch.Tensor, + w1_scale: list[torch.Tensor] | None = None, + w2_scale: list[torch.Tensor] | None = None, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + w1_offset: torch.Tensor | None = None, + w2_offset: torch.Tensor | None = None, + topk_scales: torch.Tensor | None = None, + with_quant: bool = False, + fusion: bool = False, + need_trans: bool = True, + dynamic_eplb: bool = False, +) -> torch.Tensor: if with_quant: assert w1_scale is not None and w2_scale is not None - return quant_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - w1_offset=w1_offset, - w2_offset=w2_offset, - fusion=fusion, - dynamic_eplb=dynamic_eplb) + return quant_apply_mlp( + hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + fusion=fusion, + dynamic_eplb=dynamic_eplb, + ) else: - return unquant_apply_mlp(hidden_states=hidden_states, - w1=w1, - w2=w2, - group_list=group_list, - group_list_type=group_list_type, - topk_scales=topk_scales, - need_trans=need_trans) + return unquant_apply_mlp( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + group_list_type=group_list_type, + topk_scales=topk_scales, + need_trans=need_trans, + ) diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index df8e4cf4..ce467cf6 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -16,22 +16,23 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Optional import torch import torch.distributed as dist import torch.nn as nn import torch_npu from vllm.distributed.parallel_state import ( - get_dp_group, get_pcp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_dp_group, + get_pcp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl -from vllm_ascend.utils import (enable_sp, npu_stream_switch, - prefill_context_parallel_enable) +from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable class QuantType(Enum): @@ -51,7 +52,8 @@ class PrepareAndFinalize(ABC): moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info, sizes, ranks, and communication settings. """ - quant_stream: Optional[torch.npu.Stream] = None + + quant_stream: torch.npu.Stream | None = None def __init__(self, moe_config: FusedMoEConfig): self.moe_config = moe_config @@ -67,9 +69,8 @@ class PrepareAndFinalize(ABC): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, - quant_type: QuantType = QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + quant_type: QuantType = QuantType.NONE, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """ Prepare tensors before MoE computation. May involve: - Padding to align communication boundaries @@ -92,10 +93,9 @@ class PrepareAndFinalize(ABC): """ raise NotImplementedError("Prepare not implemented.") - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + def finalize( + self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + ) -> torch.Tensor: """ Finalize MoE output. May involve: - Gathering sliced tensors across TP ranks @@ -135,9 +135,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, - quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + quant_type=QuantType.NONE, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """ Preparation steps: 1. Pad hidden_states and router_logits to next multiple of TP size. @@ -158,33 +157,24 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size)) padded_hidden_states_shape = hidden_states.shape if self.tp_size > 1: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) + split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0) + split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0) hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] - context_metadata = { - "padded_hidden_states_shape": padded_hidden_states_shape - } + context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape} return hidden_states, router_logits, None, context_metadata - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + def finalize( + self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + ) -> torch.Tensor: """ Finalization steps: 1. If TP > 1, all-gather slices to reconstruct full tensor. @@ -201,20 +191,16 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): # may share memory with original hidden_states. Since shared # experts may use the original tensor, reusing it would cause # in-place modification during all_gather, corrupting the data. - padded_hidden_states_shape = context_metadata[ - "padded_hidden_states_shape"] + padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"] gathered_hidden_states = torch.empty( - padded_hidden_states_shape, - device=hidden_states.device, - dtype=hidden_states.dtype) - split_hidden_states = torch.tensor_split( - gathered_hidden_states, self.tp_size, dim=0) - dist.all_gather(list(split_hidden_states), hidden_states, - self.moe_config.tp_group.device_group) + padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype + ) + split_hidden_states = torch.tensor_split(gathered_hidden_states, self.tp_size, dim=0) + dist.all_gather(list(split_hidden_states), hidden_states, self.moe_config.tp_group.device_group) hidden_states = gathered_hidden_states if self.num_tokens < hidden_states.shape[0]: - hidden_states = hidden_states[:self.num_tokens] + hidden_states = hidden_states[: self.num_tokens] return hidden_states @@ -246,9 +232,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, - quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + quant_type=QuantType.NONE, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """ Preparation steps: 1. Fetch `mc2_mask` and target padding length from forward context. @@ -278,20 +263,14 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): # Pad if necessary (unless shared expert DP is enabled) if pad_size > 0 and not self.enable_shared_expert_dp: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size)) padded_hidden_states_shape = hidden_states.shape # Slice across TP ranks if self.tp_size > 1 and not self.enable_shared_expert_dp: - split_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - split_router_logits = torch.tensor_split(router_logits, - self.tp_size, - dim=0) + split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0) + split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0) hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] @@ -330,9 +309,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, - quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + quant_type=QuantType.NONE, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """ Preparation steps: AllGather hidden_states and router_logits to form global tensors. @@ -341,46 +319,31 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): Tuple of (global_hidden_states, global_router_logits, None) """ if enable_sp(): - return self._prepare_with_ep_group(hidden_states, router_logits, - quant_type) + return self._prepare_with_ep_group(hidden_states, router_logits, quant_type) - return self._prepare_with_dp_group(hidden_states, router_logits, - enable_shared_expert_dp, - replace_allreduce) + return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce) def _prepare_with_ep_group( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: pertoken_scale = None if quant_type == QuantType.W8A8: - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( - hidden_states) + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) if self.multistream_overlap_gate: assert PrepareAndFinalize.quant_stream is not None - PrepareAndFinalize.quant_stream.wait_stream( - torch.npu.current_stream()) - with npu_stream_switch(PrepareAndFinalize.quant_stream, - enabled=self.multistream_overlap_gate): - hidden_states = fc3_all_gather_and_maybe_unpad_impl( - hidden_states) + PrepareAndFinalize.quant_stream.wait_stream(torch.npu.current_stream()) + with npu_stream_switch(PrepareAndFinalize.quant_stream, enabled=self.multistream_overlap_gate): + hidden_states = fc3_all_gather_and_maybe_unpad_impl(hidden_states) else: - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - hidden_states, True, True) - router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - router_logits, True, True) + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True, True) + router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True, True) if pertoken_scale is not None: - pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - pertoken_scale, True, True) + pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(pertoken_scale, True, True) if self.multistream_overlap_gate: - torch.npu.current_stream().wait_stream( - PrepareAndFinalize.quant_stream) + torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream) if pertoken_scale is not None: return (hidden_states, pertoken_scale), router_logits, None, None @@ -393,9 +356,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, - quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], - Optional[torch.Tensor]]: + quant_type=QuantType.NONE, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """ Preparation steps: 1. Fetch max token count across DP group from forward context. @@ -413,16 +375,12 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): self.num_tokens = hidden_states.shape[0] pad_size = max_tokens_across_dp - self.num_tokens if pad_size > 0: - hidden_states = nn.functional.pad(hidden_states, - (0, 0, 0, pad_size)) - router_logits = nn.functional.pad(router_logits, - (0, 0, 0, pad_size)) + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size)) + router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size)) # All-gather across DP group - hidden_states = self.moe_config.dp_group.all_gather( - hidden_states, 0) - router_logits = self.moe_config.dp_group.all_gather( - router_logits, 0) + hidden_states = self.moe_config.dp_group.all_gather(hidden_states, 0) + router_logits = self.moe_config.dp_group.all_gather(router_logits, 0) if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: hidden_states = get_pcp_group().all_gather( @@ -436,10 +394,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): return hidden_states, router_logits, None, None - def finalize(self, - hidden_states: torch.Tensor, - reduce_results: bool, - context_metadata: Optional[dict] = None) -> torch.Tensor: + def finalize( + self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + ) -> torch.Tensor: """ Finalization steps: Reduce Scatter hidden states. @@ -452,8 +409,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): return self._finalize_with_dp_group(hidden_states, reduce_results) - def _finalize_with_ep_group(self, - hidden_states: torch.Tensor) -> torch.Tensor: + def _finalize_with_ep_group(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled: 1. Reduce_results is False usually happens when models have shared experts and need to @@ -463,13 +419,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): 2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter here, then skip allreudce in FusedMoe. """ - hidden_states = torch.ops.vllm.maybe_pad_and_reduce( - hidden_states, True) + hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states, True) return hidden_states - def _finalize_with_dp_group(self, hidden_states: torch.Tensor, - reduce_results: bool) -> torch.Tensor: + def _finalize_with_dp_group(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor: """ Finalization steps: 1. If DP > 1 and not shared expert, reduce-scatter output across DP group. @@ -481,9 +435,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): """ if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) - hidden_states = hidden_states[:self.num_tokens] + hidden_states = hidden_states[: self.num_tokens] if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: - hidden_states = get_pcp_group().reduce_scatter(hidden_states, - dim=0) + hidden_states = get_pcp_group().reduce_scatter(hidden_states, dim=0) return hidden_states diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index a783da19..d909ab89 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -22,7 +22,6 @@ # limitations under the License. from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional import torch import torch_npu @@ -30,10 +29,8 @@ from vllm.config import get_current_vllm_config from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_mc2_group -from vllm_ascend.ops.fused_moe.comm_utils import ( - async_all_to_all, gather_from_sequence_parallel_region) -from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type, - is_hierarchical_communication_enabled) +from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled @dataclass @@ -52,7 +49,6 @@ class TokenCombineResult: class MoETokenDispatcher(ABC): - def __init__(self, **kwargs) -> None: """ Initialize the MoE Token Dispatcher. @@ -79,26 +75,24 @@ class MoETokenDispatcher(ABC): hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, + expert_map: torch.Tensor | None = None, global_redundant_expert_num: int = 0, - mc2_mask: Optional[torch.Tensor] = None, + mc2_mask: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, with_quant: bool = False, dynamic_eplb: bool = False, - pertoken_scale: Optional[torch.Tensor] = None, + pertoken_scale: torch.Tensor | None = None, ) -> TokenDispatchResult: raise NotImplementedError("Dispatch function not implemented.") @abstractmethod - def token_combine(self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor | None = None) -> TokenCombineResult: + def token_combine( + self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None + ) -> TokenCombineResult: raise NotImplementedError("Combine function not implemented.") class TokenDispatcherWithMC2(MoETokenDispatcher): - def __init__(self, **kwargs): super().__init__(**kwargs) device_group = get_mc2_group().device_group @@ -108,10 +102,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.ep_rank_id = get_mc2_group().rank_in_group self.ep_world_size = get_mc2_group().world_size - self.enable_dispatch_v2 = hasattr(torch_npu, - "npu_moe_distribute_dispatch_v2") - self.need_extra_args = ( - get_ascend_device_type() == AscendDeviceType.A3) + self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") + self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3 # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly @@ -126,10 +118,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): compilation_config = vllm_config.compilation_config speculative_config = vllm_config.speculative_config tp_size = vllm_config.parallel_config.tensor_parallel_size - uniform_decode_query_len = 1 if not speculative_config else \ - 1 + speculative_config.num_speculative_tokens - decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', - 0) + uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens + decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0) max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) if compilation_config.cudagraph_capture_sizes: max_num_tokens = compilation_config.max_cudagraph_capture_size @@ -167,44 +157,56 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "ep_rank_id": self.ep_rank_id, } if self.need_extra_args: - stage1_kwargs.update({ - "group_tp": self.moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) + stage1_kwargs.update( + { + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + } + ) if self.need_expert_scale: - stage1_kwargs.update({ - "expert_scales": - topk_weights.to(torch.float32), - }) + stage1_kwargs.update( + { + "expert_scales": topk_weights.to(torch.float32), + } + ) kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 - def token_dispatch(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: Optional[torch.Tensor] = None): + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + mc2_mask: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False, + pertoken_scale: torch.Tensor | None = None, + ): self.with_quant = with_quant - kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, - topk_ids, expert_map, - mc2_mask, - global_redundant_expert_num) - output = torch_npu.npu_moe_distribute_dispatch_v2( - **kwargs_mc2 - ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( - **kwargs_mc2) + kwargs_mc2 = self.get_dispatch_mc2_kwargs( + hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num + ) + output = ( + torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2) + if self.enable_dispatch_v2 + else torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) + ) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ - ep_recv_counts, tp_recv_counts, expand_scales = output[0:7] + ( + expand_x, + dynamic_scale, + assist_info_for_combine, + expert_token_nums, + ep_recv_counts, + tp_recv_counts, + expand_scales, + ) = output[0:7] context_metadata = { "topk_ids": topk_ids, @@ -213,18 +215,19 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "ep_recv_counts": ep_recv_counts, "tp_recv_counts": tp_recv_counts, "assist_info_for_combine": assist_info_for_combine, - "expand_scales": expand_scales + "expand_scales": expand_scales, } group_list_type = 0 - return TokenDispatchResult(hidden_states=expand_x, - dynamic_scale=dynamic_scale, - group_list=expert_token_nums, - group_list_type=group_list_type, - context_metadata=context_metadata) + return TokenDispatchResult( + hidden_states=expand_x, + dynamic_scale=dynamic_scale, + group_list=expert_token_nums, + group_list_type=group_list_type, + context_metadata=context_metadata, + ) - def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, - context_metadata: dict): + def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict): expert_map = context_metadata["expert_map"] topk_ids = context_metadata["topk_ids"] topk_weights = context_metadata["topk_weights"] @@ -246,9 +249,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): } if self.with_quant: - tp_recv_counts = torch.empty(1, - dtype=torch.int32, - device=hidden_states.device) + tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device) stage3_kwargs = { "ep_send_counts": ep_recv_counts, @@ -264,12 +265,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): stage3_kwargs["expand_idx"] = assist_info_for_combine if self.need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": self.moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) + stage3_kwargs.update( + { + "tp_send_counts": tp_recv_counts, + "group_tp": self.moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + } + ) kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 @@ -277,57 +280,58 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): def token_combine(self, hidden_states, context_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." - kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, - context_metadata) - combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ - if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata) + combined_output = ( + torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) + if self.enable_dispatch_v2 + else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + ) - return TokenCombineResult(routed_out=combined_output, ) + return TokenCombineResult( + routed_out=combined_output, + ) class TokenDispatcherWithAllGather(MoETokenDispatcher): - def __init__(self, **kwargs): super().__init__(**kwargs) self.apply_router_weight_on_input = False self.max_num_tokens = kwargs.get("max_num_tokens") num_experts_local = kwargs.get("num_local_experts", 0) - self.num_experts_local = num_experts_local.item() if torch.is_tensor( - num_experts_local) else int(num_experts_local) + self.num_experts_local = ( + num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local) + ) self.original_shape = None self.with_quant = False - def token_dispatch(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: Optional[torch.Tensor] = None): + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + mc2_mask: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False, + pertoken_scale: torch.Tensor | None = None, + ): self.with_quant = with_quant self.original_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() self.apply_router_weight_on_input = apply_router_weight_on_input if self.apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * \ - topk_weights.to(hidden_states.dtype) + assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) if expert_map is not None: global_num_experts = len(expert_map) + global_redundant_expert_num - mask = (expert_map[topk_ids] != -1) + mask = expert_map[topk_ids] != -1 topk_weights = topk_weights * mask - first_expert_idx = get_ep_group( - ).rank_in_group * self.num_experts_local + first_expert_idx = get_ep_group().rank_in_group * self.num_experts_local last_expert_idx = first_expert_idx + self.num_experts_local else: first_expert_idx = 0 @@ -344,15 +348,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expert_tokens_num_type=1, expert_tokens_num_flag=True, active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=1 - if self.with_quant and pertoken_scale is None else -1, - )) + quant_mode=1 if self.with_quant and pertoken_scale is None else -1, + ) + ) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode - context_metadata = { - "topk_weights": topk_weights, - "expanded_row_idx": expanded_row_idx - } + context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx} return TokenDispatchResult( hidden_states=sorted_hidden_states, @@ -367,7 +368,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, sorted_indices=torch.abs(context_metadata["expanded_row_idx"]), - probs=context_metadata["topk_weights"]) + probs=context_metadata["topk_weights"], + ) if len(self.original_shape) == 3: final_hidden_states = final_hidden_states.view(self.original_shape) @@ -398,35 +400,33 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): device=torch.npu.current_device(), ) - local_expert_indices_offset = (self.ep_rank * self.num_local_experts) + local_expert_indices_offset = self.ep_rank * self.num_local_experts - self.local_expert_indices = [ - local_expert_indices_offset + i - for i in range(self.num_local_experts) - ] - assert (len(self.local_expert_indices) == self.num_local_experts - ), "Invalid local expert indices" + self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)] + assert len(self.local_expert_indices) == self.num_local_experts, "Invalid local expert indices" for i in range(len(self.local_expert_indices) - 1): - assert (self.local_expert_indices[i] == - self.local_expert_indices[i + 1] - - 1), "local_expert_indices must be continuous" + assert self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1, ( + "local_expert_indices must be continuous" + ) # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=self.ep_group) backend = self.ep_group._get_backend(torch.device("npu")) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) - def token_dispatch(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - mc2_mask: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: Optional[torch.Tensor] = None): + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + mc2_mask: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False, + pertoken_scale: torch.Tensor | None = None, + ): self.with_quant = with_quant self.hidden_shape = hidden_states.shape @@ -442,35 +442,32 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): dynamic_scale_after_all2all = None if self.with_quant: - permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant( - permutated_local_input_tokens) + permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens) _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( - dynamic_scale, output_splits, input_splits, self.ep_group) + dynamic_scale, output_splits, input_splits, self.ep_group + ) permute2_ep_all_to_all_handle.wait() dynamic_scale.untyped_storage().resize_(0) _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( - permutated_local_input_tokens, output_splits, input_splits, - self.ep_group) + permutated_local_input_tokens, output_splits, input_splits, self.ep_group + ) permute1_ep_all_to_all_handle.wait() permutated_local_input_tokens.untyped_storage().resize_(0) # Postprocess - global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess( - global_input_tokens, dynamic_scale_after_all2all, - global_input_tokens_local_experts_indices) + global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = ( + self._dispatch_postprocess( + global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices + ) + ) context_metadata = { - "input_splits": - input_splits, - "output_splits": - output_splits, - "topk_weights": - topk_weights, - "reversed_local_input_permutation_mapping": - reversed_local_input_permutation_mapping, - "reversed_global_input_permutation_mapping": - reversed_global_input_permutation_mapping + "input_splits": input_splits, + "output_splits": output_splits, + "topk_weights": topk_weights, + "reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping, + "reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping, } return TokenDispatchResult( @@ -485,8 +482,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." # 1. Preprocess using metadata - hidden_states = self._combine_preprocess(hidden_states, - context_metadata) + hidden_states = self._combine_preprocess(hidden_states, context_metadata) # 2. AllToAll _, permutated_local_input_tokens, handle = async_all_to_all( @@ -499,8 +495,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): hidden_states.untyped_storage().resize_(0) # 3. Postprocess using metadata - output = self._combine_postprocess(permutated_local_input_tokens, - context_metadata) + output = self._combine_postprocess(permutated_local_input_tokens, context_metadata) return TokenCombineResult(routed_out=output) @@ -534,42 +529,39 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): ) def _preprocess(self, topk_ids: torch.Tensor): - num_local_tokens_per_expert = torch.histc(topk_ids, - bins=self.num_experts, - min=0, - max=self.num_experts) + num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts) ep_size = self.ep_size self.num_out_tokens = topk_ids.numel() - input_splits = (num_local_tokens_per_expert.reshape( - ep_size, - self.num_local_experts).sum(axis=1).to(torch.device("cpu"), - non_blocking=True).numpy()) + input_splits = ( + num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts) + .sum(axis=1) + .to(torch.device("cpu"), non_blocking=True) + .numpy() + ) num_global_tokens_per_expert = gather_from_sequence_parallel_region( - num_local_tokens_per_expert, - group=self.ep_group).reshape(ep_size, self.num_experts) - num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ - 0]:self.local_expert_indices[-1] + 1] + num_local_tokens_per_expert, group=self.ep_group + ).reshape(ep_size, self.num_experts) + num_global_tokens_per_local_expert = num_global_tokens_per_expert[ + :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 + ] if num_global_tokens_per_local_expert is None: - raise ValueError( - "num_global_tokens_per_local_expert must be set before sum.") + raise ValueError("num_global_tokens_per_local_expert must be set before sum.") - output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to( - torch.device("cpu"), non_blocking=True).numpy()) - num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum( - axis=0) + output_splits = ( + num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu"), non_blocking=True).numpy() + ) + num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(axis=0) global_input_tokens_local_experts_indices = None if self.num_local_experts > 1: if num_global_tokens_per_local_expert is None: - raise ValueError( - "num_global_tokens_per_local_expert must be set before operations." - ) + raise ValueError("num_global_tokens_per_local_expert must be set before operations.") global_input_tokens_local_experts_indices = torch.repeat_interleave( - self.expert_ids_per_ep_rank, - num_global_tokens_per_local_expert.ravel()) + self.expert_ids_per_ep_rank, num_global_tokens_per_local_expert.ravel() + ) else: torch.npu.synchronize() @@ -581,45 +573,41 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): global_input_tokens_local_experts_indices, ) - def _dispatch_postprocess(self, global_input_tokens, - dynamic_scale_after_all2all, - global_input_tokens_local_experts_indices): + def _dispatch_postprocess( + self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices + ): # Early return if no local experts or no tokens if self.num_local_experts <= 1: return global_input_tokens, dynamic_scale_after_all2all, None # Handle quantized case if self.with_quant: - assert global_input_tokens_local_experts_indices is not None, \ + assert global_input_tokens_local_experts_indices is not None, ( "global_input_tokens_local_experts_indices must be provided" + ) dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute( - dynamic_scale_after_all2all.unsqueeze(-1), - global_input_tokens_local_experts_indices) - dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze( - -1) + dynamic_scale_after_all2all.unsqueeze(-1), global_input_tokens_local_experts_indices + ) + dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(-1) # Non-quantized case global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( - global_input_tokens, global_input_tokens_local_experts_indices) + global_input_tokens, global_input_tokens_local_experts_indices + ) return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping - def _combine_preprocess(self, hidden_states: torch.Tensor, - context_metadata: dict) -> torch.Tensor: + def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor: # Unpermutation 2: expert output to AlltoAll input if hidden_states.shape[0] > 0 and self.num_local_experts > 1: - rev_global = context_metadata[ - "reversed_global_input_permutation_mapping"] - hidden_states = torch_npu.npu_moe_token_unpermute( - hidden_states, rev_global) + rev_global = context_metadata["reversed_global_input_permutation_mapping"] + hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global) return hidden_states - def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, - context_metadata: dict) -> torch.Tensor: + def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor: # Unpermutation 1: AlltoAll output to output output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, - sorted_indices=context_metadata[ - "reversed_local_input_permutation_mapping"].to(torch.int32), + sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32), probs=context_metadata["topk_weights"], restore_shape=self.hidden_shape_before_permute, )