From 0db6670bfab8cb1d84c9e7270df0a1d42d6ce7ca Mon Sep 17 00:00:00 2001
From: yiz-liu <136800916+yiz-liu@users.noreply.github.com>
Date: Tue, 11 Mar 2025 21:08:02 +0800
Subject: [PATCH] [Feature] Implement EP-compatible fused_moe (#121)

### What this PR does / why we need it?
Enable expert parallelism (EP) for Ascend devices.

### Does this PR introduce _any_ user-facing change?
To enable EP, add `enable_expert_parallel=True` to your offline inference script, like this:

```python
llm = LLM(
    model="/path/to/model",
    trust_remote_code=True,
    tensor_parallel_size=4,
    max_model_len=4096,
    enforce_eager=True,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
)
```

### How was this patch tested?
Please test with the `main` branch of vLLM.

---------

Signed-off-by: Yizhou Liu
Co-authored-by: Yizhou Liu
---
 tests/ops/test_fused_moe.py  |  96 +++++++++
 vllm_ascend/ops/fused_moe.py | 397 ++++++++++++++++++++++++-----------
 2 files changed, 365 insertions(+), 128 deletions(-)
 create mode 100644 tests/ops/test_fused_moe.py

diff --git a/tests/ops/test_fused_moe.py b/tests/ops/test_fused_moe.py
new file mode 100644
index 0000000..c8ac3f4
--- /dev/null
+++ b/tests/ops/test_fused_moe.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/kernels/test_moe.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the MOE layers.
+
+Run `pytest tests/ops/test_fused_moe.py`.
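+
+These tests compare the Ascend fused_experts implementation against a plain
+PyTorch reference (torch_moe below), both with and without an expert_map,
+which emulates expert parallelism by mapping global expert IDs to local ones.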
+""" + +import pytest +import torch +from vllm.model_executor.layers.activation import SiluAndMul + +from vllm_ascend.ops.fused_moe import fused_experts + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [2, 6] +DEVICE = ["npu"] + + +def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + topk_weights = topk_weights.view(-1) + topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * + topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 511, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICE) +def test_fused_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + ep_size: int, + dtype: torch.dtype, + device: str, +): + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 + + score = torch.randn((m, e), device=device, dtype=dtype) + + if ep_size > 1: + local_e = e // ep_size + e_ids = torch.randint(0, + e, (local_e, ), + device=device, + dtype=torch.int32) + e_map = torch.full((e, ), -1, device=device, dtype=torch.int32) + e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32) + w1 = w1[e_ids] + w2 = w2[e_ids] + else: + e_map = None + + score = torch.softmax(score, dim=-1, dtype=dtype) + topk_weights, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.to(torch.int32) + + output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map) + torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map) + # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem + torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index e216f92..2061b68 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1,6 +1,7 @@ -# # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/kernels/test_moe.py +# Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-#
 
 from typing import Callable, Optional
 
@@ -23,167 +23,308 @@
 from vllm.model_executor.layers.fused_moe.layer import \
     UnquantizedFusedMoEMethod
 
 
-def group_topk(hidden_states: torch.Tensor,
-               gating_output: torch.Tensor,
-               topk: int,
-               renormalize: bool,
-               num_expert_group: Optional[int] = 0,
-               topk_group: Optional[int] = 0,
-               scoring_func: str = "softmax",
-               e_score_correction_bias: Optional[torch.Tensor] = None):
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Fused experts with top-k routing.
 
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    Args:
+        hidden_states: Hidden states of shape (num_tokens, hidden_size).
+        w1: Gate/up projection expert weights of shape
+            (num_experts, intermediate_size * 2, hidden_size).
+        w2: Down projection expert weights of shape
+            (num_experts, hidden_size, intermediate_size).
+        topk_weights: Routing weights of shape (num_tokens, top_k).
+        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
+        top_k: Number of experts to select.
+        expert_map: Optional mapping from global to local expert IDs, of shape
+            (num_experts,); entries of -1 mark experts not hosted on this rank.
 
-    if scoring_func == "softmax":
-        scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
-    else:
-        raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
-    if e_score_correction_bias is not None:
-        # Store original scores before applying correction bias. We use biased
-        # scores for expert selection but original scores for routing weights
-        original_scores = scores
-        scores = scores + e_score_correction_bias.unsqueeze(0)
-
-    topk_group = 0 if topk_group is None else topk_group
-    num_expert_group = 0 if num_expert_group is None else num_expert_group
-
-    # TODO: Replace this piece of code to npu_group_topk when CANN and NNAL version is update
-    num_token = scores.shape[0]
-    group_scores = scores.view(num_token, num_expert_group,
-                               -1).max(dim=-1).values
-    group_idx = torch.topk(group_scores.to(torch.float32),
-                           k=topk_group,
-                           dim=-1,
-                           sorted=False)[1]
-    group_mask = torch.zeros_like(group_scores)
-    group_mask.scatter_(1, group_idx, 1)
-    score_mask = group_mask.unsqueeze(-1).expand(
-        num_token, num_expert_group,
-        scores.shape[-1] // num_expert_group).reshape(num_token, -1)
-    scores = scores.masked_fill(~score_mask.bool(), 0.0)
-
-    if e_score_correction_bias is not None:
-        topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
-        # Use original unbiased scores for the routing weights
-        topk_weights = original_scores.gather(1, topk_ids)
-    else:
-        topk_weights, topk_ids = torch.topk(scores,
-                                            k=topk,
-                                            dim=-1,
-                                            sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids.to(torch.int32)
-
-
-def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
-                  w2: torch.Tensor, topk_weights: torch.Tensor,
-                  topk_ids: torch.Tensor, top_k: int):
+    Returns:
+        Hidden states after expert computation and weighted combination, of
+        shape (num_tokens, hidden_size).
+    """
     # Check constraints.
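+    # Overview: tokens are gathered per expert, run through the grouped
+    # gate/up matmul + SwiGLU + down matmul, then scattered back to their
+    # original rows. With an expert_map (EP), the permutation is built from
+    # portable torch ops and combined via index_add_; without one, the
+    # dedicated NPU routing kernels are used.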
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    ori_shape = hidden_states.shape
-    if len(ori_shape) == 3:
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    num_tokens, _ = hidden_states.shape
-    E, N, _ = w1.shape
+    original_shape = hidden_states.shape
+    assert len(original_shape) == 2
 
-    row_idx_len = num_tokens * top_k
-    row_idx = torch.arange(0,
-                           row_idx_len,
-                           dtype=torch.int32,
-                           device=topk_weights.device).view(top_k, -1).permute(
-                               1, 0).contiguous()
-    expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
-        hidden_states,
-        row_idx=row_idx,
-        expert_idx=topk_ids,
-        active_num=num_tokens)
+    num_tokens = hidden_states.shape[:-1].numel()
+    num_experts = w1.shape[0]
+    dtype = hidden_states.dtype
+    device = hidden_states.device
+    assert dtype in [torch.float32, torch.float16, torch.bfloat16
+                     ], "Only float32, float16, and bfloat16 are supported"
 
-    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
-        expanded_expert_idx, E)
-    expert_tokens = expert_tokens.to(torch.int64)
+    if expert_map is not None:
+        # Generate token indices and flatten
+        token_indices = (torch.arange(num_tokens,
+                                      device=device,
+                                      dtype=torch.int64).unsqueeze(1).expand(
+                                          -1, top_k).reshape(-1))
+
+        # Flatten token-to-expert mappings and map to local experts
+        weights_flat = topk_weights.view(-1)
+        experts_flat = topk_ids.view(-1)
+        local_experts_flat = expert_map[experts_flat]
+
+        # Filter valid token-expert pairs; invalid pairs are routed to a
+        # sentinel slot (num_experts) with zero weight, so they drop out below.
+        mask = local_experts_flat != -1
+        filtered_weights = torch.where(
+            mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
+        filtered_experts = torch.where(
+            mask, local_experts_flat,
+            torch.full_like(local_experts_flat,
+                            num_experts)).to(topk_ids.dtype)
+
+        # Sort by local expert IDs
+        sort_indices = torch.argsort(filtered_experts)
+        sorted_token_indices = token_indices[sort_indices]
+        sorted_weights = filtered_weights[sort_indices]
+
+        # Count tokens per local expert (the sentinel slot is trimmed off).
+        # This is equivalent to but faster than:
+        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts + 1)[:-1]
+        token_counts = torch.zeros(num_experts + 1,
+                                   device=device,
+                                   dtype=torch.int64)
+        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
+        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
+        token_counts = token_counts[:num_experts]
+        # Cumulative end offsets per local expert, the group_list layout
+        # expected by npu_grouped_matmul with group_list_type=0.
+        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)
+
+        # Rearrange hidden_states
+        sorted_hidden_states = hidden_states[sorted_token_indices]
+    else:
+        row_idx_len = num_tokens * top_k
+        row_idx = (torch.arange(0,
+                                row_idx_len,
+                                dtype=torch.int32,
+                                device=device).view(top_k, -1).permute(
+                                    1, 0).contiguous())
+        sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+            hidden_states,
+            row_idx=row_idx,
+            expert_idx=topk_ids,
+            active_num=num_tokens)
+
+        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+            expanded_expert_idx, num_experts)
+        expert_tokens = expert_tokens.to(torch.int64)
 
     w1 = w1.transpose(1, 2)
-    gate_up_out_list = torch_npu.npu_grouped_matmul(x=[expanded_x],
-                                                    weight=[w1],
-                                                    split_item=2,
-                                                    group_list_type=0,
-                                                    group_type=0,
-                                                    group_list=expert_tokens)
+    gate_up_out_list = torch_npu.npu_grouped_matmul(
+        x=[sorted_hidden_states],
+        weight=[w1],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+    )
 
     # TODO: Remove this in the future.
     gate_up_out = torch.cat(gate_up_out_list, dim=0)
     gate_up_out = torch_npu.npu_swiglu(gate_up_out)
 
     w2 = w2.transpose(1, 2)
-    down_out_list = torch_npu.npu_grouped_matmul(x=[gate_up_out],
-                                                 weight=[w2],
-                                                 split_item=2,
-                                                 group_list_type=0,
-                                                 group_type=0,
-                                                 group_list=expert_tokens)
+    down_out_list = torch_npu.npu_grouped_matmul(
+        x=[gate_up_out],
+        weight=[w2],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+    )
     down_out_list = torch.cat(down_out_list, dim=0)
 
-    # TODO: Reorder device memory 2 times here, replace the current
-    # implementation here when suitable operators become available.
-    hidden_states = torch_npu.npu_moe_finalize_routing(
-        down_out_list,
-        skip1=None,
-        skip2=None,
-        bias=None,
-        scales=topk_weights,
-        expanded_src_to_dst_row=expanded_row_idx,
-        export_for_source_row=topk_ids)
-    if len(ori_shape) == 3:
-        hidden_states = hidden_states.view(ori_shape)
-    return hidden_states
+
+    if expert_map is not None:
+        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
+
+        final_hidden_states = torch.zeros(*original_shape,
+                                          device=hidden_states.device,
+                                          dtype=dtype)
+        final_hidden_states.index_add_(0, sorted_token_indices,
+                                       weighted_down_out)
+        # TODO: This should not happen; investigate the root cause.
+        # As a workaround, replace NaNs with 0.0.
+        final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
+    else:
+        # TODO: This reorders device memory twice; replace the current
+        # implementation when suitable operators become available.
+        final_hidden_states = torch_npu.npu_moe_finalize_routing(
+            down_out_list,
+            skip1=None,
+            skip2=None,
+            bias=None,
+            scales=topk_weights,
+            expanded_src_to_dst_row=expanded_row_idx,
+            export_for_source_row=topk_ids,
+        )
+
+    return final_hidden_states
+
+
+def native_grouped_topk(
+    topk_weights: torch.Tensor,
+    num_expert_group: Optional[int],
+    topk_group: Optional[int],
+):
+    topk_group = 0 if topk_group is None else topk_group
+    num_expert_group = 0 if num_expert_group is None else num_expert_group
+
+    num_token = topk_weights.shape[0]
+    grouped_weights = topk_weights.view(num_token, num_expert_group,
+                                        -1).max(dim=-1).values
+    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
+                                    k=topk_group,
+                                    dim=-1,
+                                    sorted=False)[1]
+    topk_group_mask = torch.zeros_like(grouped_weights)
+    topk_group_mask.scatter_(1, topk_group_indices, 1)
+    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
+        num_token, num_expert_group,
+        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
+    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
+
+    return topk_weights
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Select top-k experts based on router logits.
+
+    Args:
+        hidden_states: Hidden states of shape (num_tokens, hidden_size).
+        router_logits: Router logits of shape (num_tokens, num_experts).
+        top_k: Number of experts to select.
+        use_grouped_topk: Whether to group experts before selecting top-k.
+        renormalize: Whether to renormalize the routing weights.
+        topk_group: Number of expert groups to select.
+        num_expert_group: Total number of expert groups.
+        custom_routing_function: Custom routing function (not yet supported).
+        scoring_func: Scoring function to use, "softmax" or "sigmoid".
+        e_score_correction_bias: Correction bias to apply to expert scores.
+
+    Returns:
+        topk_weights: Routing weights of shape (num_tokens, top_k).
+        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
+
+    Raises:
+        ValueError: If an unsupported scoring function is provided.
+    """
+    assert hidden_states.shape[0] == router_logits.shape[0], (
+        "Number of tokens mismatch")
+
+    if custom_routing_function is not None:
+        raise NotImplementedError(
+            "Custom routing function is not supported yet")
+
+    if scoring_func == "softmax":
+        # NOTE: vLLM uses dtype=torch.float here
+        topk_weights = router_logits.softmax(dim=-1)
+    elif scoring_func == "sigmoid":
+        topk_weights = router_logits.sigmoid()
+    else:
+        raise ValueError(f"Unsupported scoring function: {scoring_func}")
+
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+
+        if e_score_correction_bias is not None:
+            # Store original scores before applying correction bias. We use biased
+            # scores for expert selection but original scores for routing weights
+            original_weights = topk_weights
+            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
+
+        # TODO: Change to npu_group_topk when the latest CANN and NNAL are available
+        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
+        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
+                                           topk_group)
+
+        if e_score_correction_bias is not None:
+            topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
+                                  sorted=False)[1]
+            # Use original unbiased scores for the routing weights
+            topk_weights = original_weights.gather(1, topk_ids)
+        else:
+            topk_weights, topk_ids = torch.topk(topk_weights,
+                                                k=top_k,
+                                                dim=-1,
+                                                sorted=False)
+    else:
+        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
+        topk_weights = topk_weights.to(hidden_states.dtype)
+
+    # Required by npu_moe_init_routing
+    topk_ids = topk_ids.to(torch.int32)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
 
 
 def forward_oot(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
-) -> torch.Tensor:
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    assert router_logits.shape[
+        1] == global_num_experts, "Number of global experts mismatch"
 
-    topk_weights, topk_ids = group_topk(
+    topk_weights, topk_ids = select_experts(
         hidden_states=x,
-        gating_output=router_logits,
-        topk=top_k,
+        router_logits=router_logits,
+        top_k=top_k,
+        use_grouped_topk=use_grouped_topk,
         renormalize=renormalize,
-        num_expert_group=num_expert_group,
         topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
         scoring_func=scoring_func,
-        e_score_correction_bias=e_score_correction_bias)
+        e_score_correction_bias=e_score_correction_bias,
+    )
 
     return fused_experts(hidden_states=x,
                          w1=layer.w13_weight,
                          w2=layer.w2_weight,
                          topk_weights=topk_weights,
                          topk_ids=topk_ids,
-                         top_k=top_k)
+                         top_k=top_k,
+                         expert_map=expert_map)
 
 
 UnquantizedFusedMoEMethod.forward_oot = forward_oot
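
---

For reference, a minimal sketch of how an `expert_map` accepted by the new
`fused_experts` can be built for one EP rank, assuming a contiguous split of
experts across ranks; the helper name `build_expert_map` is illustrative and
is not part of this patch:

```python
import torch


def build_expert_map(global_num_experts: int, ep_size: int,
                     ep_rank: int) -> torch.Tensor:
    """Map global expert IDs to local IDs; -1 marks experts on other ranks."""
    local_num_experts = global_num_experts // ep_size
    start = ep_rank * local_num_experts
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)
    return expert_map


# Example: 8 experts over 4 ranks -> rank 1 hosts global experts 2 and 3.
print(build_expert_map(8, 4, 1))
# tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)
```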