moe_gating_top_k (#5271)

1. What this PR does / why we need it?
This PR adds support for the moe_gating_top_k operator, which enables
renormalization (renorm) applied after the softmax step of MoE gating
(see the sketch below).
2. Does this PR introduce any user-facing change?
No user-facing changes are required.
3. How was this patch tested?
This patch was tested with the test_npu_moe_gating_top_k test case.
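
For reference, a simplified NumPy sketch of the gating flow this operator covers (softmax over the expert axis, top-k selection, then renormalization of the selected weights), mirroring the reference implementation in the new test; bias and expert-group selection are omitted, and the shapes and values below are hypothetical:

```python
import numpy as np

# Hypothetical shapes: 4 tokens routed over 8 experts, top-2 selection
logits = np.random.randn(4, 8).astype(np.float32)
k = 2

# Softmax over the expert axis (norm_type=0)
z = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

# Top-k expert selection per token
topk_idx = np.argsort(-probs, axis=-1)[:, :k].astype(np.int32)
topk_w = np.take_along_axis(probs, topk_idx, axis=-1)

# Post-softmax renormalization (renorm=1): selected weights sum to 1 per token
topk_w = topk_w / (topk_w.sum(axis=-1, keepdims=True) + 1e-20)
```

With renorm enabled, the k selected weights of each token sum to 1 before any routed scaling factor is applied.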
vLLM version: release/v0.13.0
vLLM main:
ad32e3e19c

---------

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Author: ZCG12345
Date: 2025-12-30 09:28:01 +08:00
Committed by: GitHub
Parent: 15d73f248e
Commit: 45c3c279e2
34 changed files with 4791 additions and 22 deletions

@@ -179,7 +179,8 @@ class SmallOps(DecodeMoeOps):
shared_expert_rank_num=self.shared_expert_rank_num,
quant_mode=2,
global_bs=self.batch_size * self.ep_world_size,
- expert_token_nums_type=1, # 0代表前缀和1代表各自数量
+ expert_token_nums_type=
+ 1, # 0 represents prefix sum, 1 represents individual counts
)
expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums, ep_send_counts, tp_send_counts, expand_scales = outputs
output_dtype = x.dtype
@@ -188,8 +189,8 @@ class SmallOps(DecodeMoeOps):
x=[expand_x],
weight=[self.gmm1_weight],
split_item=3,
- group_list_type=1, # 默认为0代表前缀和形式
- group_type=0, # 0代表m轴分组
+ group_list_type=1, # Default is 0, represents prefix sum format
+ group_type=0, # 0 represents m-axis grouping
group_list=expert_token_nums,
output_dtype=torch.int32)[0]
y1, y1_scale = torch_npu.npu_dequant_swiglu_quant(
@@ -365,7 +366,7 @@ def run_once(local_rank_id,
with_mc2_mask=False):
log_file = redirect_output(f"local_rank_{local_rank_id}.log"
) if output_to_file(local_rank_id) else None
- global_rank_id = local_rank_id # 单机
+ global_rank_id = local_rank_id # Single machine
device_id = local_rank_id % 16
torch_npu.npu.set_device(device_id)
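
The comments translated in the hunks above distinguish the two encodings of per-expert token counts accepted by the dispatch and grouped-matmul calls; a minimal sketch with hypothetical counts:

```python
import numpy as np

# group_list_type=1 / expert_token_nums_type=1: per-expert token counts
counts = np.array([3, 0, 5, 2])   # tokens routed to each of 4 experts

# group_list_type=0 / expert_token_nums_type=0: prefix-sum (cumulative) form
prefix_sum = np.cumsum(counts)    # array([ 3,  3,  8, 10])
```

With group_list_type=1 (and expert_token_nums_type=1) the ops consume the per-expert counts directly; with 0 they expect the cumulative prefix-sum form.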

@@ -0,0 +1,322 @@
import itertools
import logging
import random
from typing import Optional, Tuple
import numpy as np
import torch
from torch_npu.testing.testcase import TestCase, run_tests
try:
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
except ImportError:
logging.warning(
"vllm_ascend.utils.enable_custom_op not found, skip custom op initialization"
)
def enable_custom_op() -> None:
pass
# Set random seed for reproducibility
SEED = 45
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if hasattr(torch, "npu") and torch.npu.is_available():
torch.npu.manual_seed_all(SEED)
# Configure logging
logging.basicConfig(level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)
def softmax_func(
x: np.ndarray,
axis: Optional[int] = None,
eps: float = 1e-20) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Stable softmax implementation for MOE gating.
Args:
x: Input array
axis: Axis to compute softmax
eps: Epsilon to avoid division by zero
Returns:
softmax_output: Softmax result
x_max: Max value for numerical stability
x_sum: Sum of exponentials
"""
if "float16" in x.dtype.name:
x = x.astype(np.float32)
x_max = x.max(axis=axis, keepdims=True)
x_sub = x - x_max
y = np.exp(x_sub)
x_sum = y.sum(axis=axis, keepdims=True)
softmax_output = y / (x_sum + eps)
return softmax_output, x_max, x_sum
class TestNpuMoeGatingTopK(TestCase):
"""Test suite for NPU MOE Gating Top-K operator compatibility."""
def moe_gating_top_k_np(
self,
x: np.ndarray,
k: int,
bias: Optional[np.ndarray] = None,
k_group: int = 1,
group_count: int = 1,
group_select_mode: int = 0,
renorm: int = 0,
norm_type: int = 0,
y2_flag: bool = False,
routed_scaling_factor: float = 1.0,
eps: float = 1e-20
) -> Tuple[torch.Tensor, np.ndarray, Optional[np.ndarray]]:
"""
NumPy reference implementation of MOE gating Top-K logic.
Args:
x: Input features, shape [batch_size, num_experts]
k: Number of experts to select per sample
bias: Gating bias, shape [num_experts]
k_group: Number of groups to select (group mode)
group_count: Number of expert groups
group_select_mode: 0 (max per group), 1 (sum of top-2 per group)
renorm: Whether to renormalize weights (1=enable, 0=disable)
norm_type: 0 (softmax), 1 (sigmoid)
y2_flag: Whether to return original x as y2
routed_scaling_factor: Weight scaling factor
eps: Epsilon for numerical stability
Returns:
y: Selected expert weights (Tensor)
indices: Selected expert indices (int32 numpy array)
y2: Original x if y2_flag=True, else None
"""
# Convert torch tensors to numpy arrays if needed (compatibility layer)
if isinstance(x, torch.Tensor):
x = x.cpu().numpy()
if isinstance(bias, torch.Tensor):
bias = bias.cpu().numpy()
# Type conversion for numerical stability
orig_dtype = x.dtype
if orig_dtype != np.float32:
x = x.astype(np.float32)
if bias is not None:
bias = bias.astype(np.float32)
# Apply normalization (softmax/sigmoid)
if norm_type == 0:
x, _, _ = softmax_func(x, axis=-1, eps=eps)
else:
x = 1 / (1 + np.exp(-x)) # Sigmoid
original_x = x.copy()
# Apply bias if provided
if bias is not None:
x = x + bias
# Group-based expert selection
if group_count > 1:
batch_size, num_experts = x.shape
if num_experts % group_count != 0:
raise ValueError(
f"num_experts ({num_experts}) must be divisible by group_count ({group_count})"
)
group_size = num_experts // group_count
# Reshape to [batch, groups, group_size]
x_reshaped = x.reshape(batch_size, group_count, group_size)
# Compute group scores
if group_select_mode == 0:
group_scores = np.amax(x_reshaped, axis=-1)
else:
# Sum of top-2 values per group
group_scores = np.partition(x_reshaped, -2,
axis=-1)[..., -2:].sum(axis=-1)
# Select top-k_group groups
top_groups = np.argsort(-group_scores, axis=-1,
kind="stable")[:, :k_group]
# Mask out non-selected groups with -inf
mask = np.ones((batch_size, group_count), dtype=bool)
mask[np.arange(batch_size)[:, None], top_groups] = False
x_reshaped = np.where(mask[..., None], float("-inf"), x_reshaped)
# Reshape back to original
x = x_reshaped.reshape(batch_size, num_experts)
# Select top-k experts
x_tensor = torch.from_numpy(x)
_, topk_indices = torch.sort(x_tensor,
dim=-1,
stable=True,
descending=True)
topk_indices = np.asarray(topk_indices[:, :k], dtype=np.int32)
# Extract weights for selected experts
selected_weights = np.take_along_axis(original_x, topk_indices, axis=1)
# Apply renormalization if needed
if norm_type == 1 or renorm == 1:
weight_sum = np.sum(selected_weights, axis=-1, keepdims=True)
selected_weights = selected_weights / (weight_sum + eps)
# Apply scaling factor
selected_weights *= routed_scaling_factor
# Prepare y2 output
y2 = original_x if y2_flag else None
# Convert back to a torch tensor with the original dtype
# (cast in NumPy first; torch.tensor's dtype argument expects a torch.dtype,
# not a NumPy dtype)
selected_weights_tensor = torch.from_numpy(
selected_weights.astype(orig_dtype))
return selected_weights_tensor, topk_indices, y2
def test_npu_moe_gating_topk_multi(self) -> None:
"""
Multi-case test for NPU MOE Gating Top-K operator.
Validates compatibility with different input shapes and parameter combinations.
"""
# Test parameter space (aligned with vllm-ascend use cases)
test_configs = {
"group_select_modes": [0, 1],
"renorms": [1],
"norm_types": [0, 1],
"group_counts": [1, 8],
"k_ranges": [4, 8, 12, 16, 6, 32],
"x_dim0": range(1, 17), # Batch size 1-16
"x_dim1": [256, 128, 64, 208, 192, 160] # Expert counts
}
# Generate parameter combinations
param_combinations = itertools.product(
test_configs["group_select_modes"], test_configs["renorms"],
test_configs["norm_types"], test_configs["group_counts"],
test_configs["k_ranges"], test_configs["x_dim0"],
test_configs["x_dim1"])
# Limit test cases to avoid excessive runtime (adjust as needed)
max_test_cases = 100
tested_cases = 0
for params in param_combinations:
if tested_cases >= max_test_cases:
break
(group_select_mode, renorm, norm_type, group_count, k, dim0,
dim1) = params
# Skip invalid configurations
if group_count > 1:
if dim1 % group_count != 0:
continue
if k > (dim1 // group_count):
continue
# Generate random inputs (consistent with vllm-ascend input distribution)
x_np = np.random.uniform(-2.0, 2.0,
(dim0, dim1)).astype(np.float32)
bias_np = np.random.uniform(-2.0, 2.0, (dim1, )).astype(np.float32)
# Convert to torch tensors
x_tensor = torch.tensor(x_np, dtype=torch.float32)
bias_tensor = torch.tensor(bias_np, dtype=torch.float32)
# Random k_group (within valid range)
k_group = random.randint(1, min(group_count, 4))
# Fixed parameters (aligned with NPU OP defaults)
y2_flag = False
routed_scaling_factor = 1.0
eps = 1e-20
try:
# Get NumPy reference result
ref_weights, ref_indices, ref_y2 = self.moe_gating_top_k_np(
x=x_tensor,
k=k,
bias=bias_tensor,
k_group=k_group,
group_count=group_count,
group_select_mode=group_select_mode,
renorm=renorm,
norm_type=norm_type,
y2_flag=y2_flag,
routed_scaling_factor=routed_scaling_factor,
eps=eps)
# Skip if NPU OP is not available
if not hasattr(torch.ops, "_C_ascend") or not hasattr(
torch.ops._C_ascend, "moe_gating_top_k"):
logger.warning(
"NPU MOE gating OP not found, skipping NPU test")
continue
# Get NPU OP result
npu_weights, npu_indices, npu_y2 = torch.ops._C_ascend.moe_gating_top_k(
x=x_tensor.npu(),
k=k,
kGroup=k_group,
groupCount=group_count,
groupSelectMode=group_select_mode,
renorm=renorm,
normType=norm_type,
outFlag=y2_flag,
routedScalingFactor=routed_scaling_factor,
eps=eps,
biasOptional=bias_tensor.npu()
if bias_tensor is not None else None)
# Convert NPU results to CPU for comparison
npu_weights_cpu = npu_weights.cpu()
npu_indices_cpu = npu_indices.cpu().numpy()
# Log test case info (vllm-ascend standard format)
logger.info(
f"Test Case {tested_cases + 1}: "
f"x_shape=({dim0},{dim1}), k={k}, group_count={group_count}, "
f"select_mode={group_select_mode}, norm_type={norm_type}, renorm={renorm}"
)
# Validate results (RTOL=1e-3 is standard for NPU numerical tolerance)
self.assertRtolEqual(ref_weights,
npu_weights_cpu,
rtol=1e-3,
atol=1e-5)
self.assertRtolEqual(ref_indices, npu_indices_cpu)
# Validate y2 if enabled
if y2_flag:
self.assertRtolEqual(ref_y2,
npu_y2.cpu().numpy(),
rtol=1e-3,
atol=1e-5)
tested_cases += 1
logger.info(f"Test Case {tested_cases} passed ")
except Exception as e:
logger.error(f"Test Case failed with error: {str(e)}",
exc_info=True)
continue
logger.info(f"Completed {tested_cases}/{max_test_cases} test cases")
if __name__ == "__main__":
# Run tests with vllm-ascend standard verbosity
run_tests(verbosity=2)