Revert "GMM custom operator optimization in small batch scenarios (vllm-project#7100)" (#7557)

### What this PR does / why we need it?
This reverts commit 42bcad7e9b. That commit caused an accuracy drop for Qwen3-Next: on a 150-item gsm8k subset, the score fell from 98 to 91.

- vLLM version: v0.18.0
- vLLM main: 6a9cceb219

Signed-off-by: Your Name <you@example.com>
Co-authored-by: Your Name <you@example.com>
Author: LeeWenquan
Date: 2026-03-24 14:24:44 +08:00 (committed by GitHub)
Parent: 83bd77c983
Commit: 475b4b0cea
3 changed files with 30 additions and 71 deletions

View File

@@ -691,7 +691,7 @@ std::vector<at::Tensor> moe_grouped_matmul(
   y.emplace_back(y_0);
   at::TensorList result = at::TensorList(y);
-  EXEC_NPU_CMD(aclnnMoeGroupedMatmul,
+  EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
               x_list, weight_list, group_list, transpose_weight, result);
   return y;

View File

@@ -17,7 +17,6 @@
 #
 import torch
 import torch_npu
-from vllm.forward_context import get_forward_context
 from vllm_ascend.device.mxfp_compat import (
     FLOAT4_E2M1FN_X2_DTYPE,
@@ -28,8 +27,6 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 class BaseDeviceAdaptor:
-    small_batch_gmm_batch_num = 16
     @classmethod
     def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
         torch_npu._npu_reshape_and_cache(
@@ -49,32 +46,17 @@ class BaseDeviceAdaptor:
         active_expert_range=None,
         quant_mode: int = -1,
     ):
-        # In small batch and non-quantization scenarios, npu_moe_init_routing_v2 is more efficient.
-        # It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
-        if quant_mode == -1 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
-            return torch_npu.npu_moe_init_routing_v2(
-                hidden_states,
-                topk_ids,
-                scale=scale,
-                active_num=active_num,
-                expert_num=expert_num,
-                expert_tokens_num_type=2,
-                expert_tokens_num_flag=expert_tokens_num_flag,
-                active_expert_range=active_expert_range,
-                quant_mode=quant_mode,
-            )
-        else:
-            return torch.ops._C_ascend.npu_moe_init_routing_custom(
-                hidden_states,
-                topk_ids,
-                scale=scale,
-                active_num=active_num,
-                expert_num=expert_num,
-                expert_tokens_num_type=expert_tokens_num_type,
-                expert_tokens_num_flag=expert_tokens_num_flag,
-                active_expert_range=active_expert_range,
-                quant_mode=quant_mode,
-            )
+        return torch.ops._C_ascend.npu_moe_init_routing_custom(
+            hidden_states,
+            topk_ids,
+            scale=scale,
+            active_num=active_num,
+            expert_num=expert_num,
+            expert_tokens_num_type=expert_tokens_num_type,
+            expert_tokens_num_flag=expert_tokens_num_flag,
+            active_expert_range=active_expert_range,
+            quant_mode=quant_mode,
+        )
     @staticmethod
     def npu_dynamic_quant(
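
For readers skimming the diff: the branch removed above gated the routing op on the runtime token count. Below is a minimal sketch of that dispatch, not the exact source; the op names, the threshold of 16, and the `quant_mode == -1` check are taken from the hunk above, while the `init_routing` wrapper name and `num_tokens` parameter are illustrative. Running it assumes an Ascend build of torch_npu plus the vllm-ascend `_C_ascend` extension.

```python
import torch
import torch_npu  # Ascend PyTorch adapter; provides npu_moe_init_routing_v2

SMALL_BATCH_GMM_BATCH_NUM = 16  # threshold used by the reverted optimization


def init_routing(hidden_states, topk_ids, num_tokens, quant_mode=-1, **kwargs):
    """Dispatch removed by this revert: tiny non-quantized batches used the
    torch_npu v2 kernel, everything else the custom vllm-ascend op. After the
    revert, the custom op is called unconditionally."""
    if quant_mode == -1 and num_tokens <= SMALL_BATCH_GMM_BATCH_NUM:
        # The original code also hard-coded expert_tokens_num_type=2 on this path.
        return torch_npu.npu_moe_init_routing_v2(
            hidden_states, topk_ids, quant_mode=quant_mode, **kwargs)
    return torch.ops._C_ascend.npu_moe_init_routing_custom(
        hidden_states, topk_ids, quant_mode=quant_mode, **kwargs)
```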

View File

@@ -18,7 +18,6 @@
 import torch
 import torch_npu
 from torch.nn.functional import pad
-from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
@@ -331,26 +330,15 @@ def unquant_apply_mlp(
     w1 = w1.transpose(1, 2)
     w2 = w2.transpose(1, 2)
-    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
-    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
-        gate_up_out = torch.ops._C_ascend.moe_grouped_matmul(
-            x=hidden_states,
-            weight=w1,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
-    else:
-        gate_up_out = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[w1],
-            bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
+    gate_up_out = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+    )[0]
     if activation == "swigluoai":
         num_experts, _, hidden_size = w1.shape
@@ -361,26 +349,15 @@ def unquant_apply_mlp(
     if topk_scales is not None:
         gate_up_out *= topk_scales
-    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
-    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
-        hidden_states = torch.ops._C_ascend.moe_grouped_matmul(
-            x=gate_up_out,
-            weight=w2,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
-    else:
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[gate_up_out],
-            weight=[w2],
-            bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[gate_up_out],
+        weight=[w2],
+        bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+    )[0]
     return hidden_states
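
Net effect on `unquant_apply_mlp`: after this revert both expert matmuls always go through `torch_npu.npu_grouped_matmul`, independent of batch size. Below is a condensed sketch of the resulting flow; argument names follow the diff above, while the `grouped_expert_mlp` wrapper name is illustrative, the activation between the two matmuls is elided, and running it requires an Ascend NPU with torch_npu installed.

```python
import torch_npu  # Ascend PyTorch adapter; provides npu_grouped_matmul


def grouped_expert_mlp(hidden_states, w1, w2, group_list, group_list_type,
                       w1_bias=None, w2_bias=None):
    # Gate/up projection over all experts in one grouped matmul.
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    # (swiglu / swigluoai activation is applied here in the real function,
    # halving the gate_up hidden dimension before the down projection.)
    # Down projection, again unconditionally via torch_npu.npu_grouped_matmul.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
    )[0]
    return hidden_states
```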