Revert "GMM custom operator optimization in small batch scenarios (vllm-project#7100)" (#7557)
### What this PR does / why we need it? This reverts commit 42bcad7e9b. That commit caused an accuracy decrease for Qwen3-Next: on 150 items of GSM8K, the score dropped from 98 to 91. - vLLM version: v0.18.0 - vLLM main: 6a9cceb219 Signed-off-by: Your Name <you@example.com> Co-authored-by: Your Name <you@example.com>
This commit is contained in:
@@ -18,7 +18,6 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.functional import pad
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
@@ -331,26 +330,15 @@ def unquant_apply_mlp(
|
||||
w1 = w1.transpose(1, 2)
|
||||
w2 = w2.transpose(1, 2)
|
||||
|
||||
# In the small batch scenario, use _C_ascend.moe_grouped_matmul
|
||||
if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
|
||||
gate_up_out = torch.ops._C_ascend.moe_grouped_matmul(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if activation == "swigluoai":
|
||||
num_experts, _, hidden_size = w1.shape
|
||||
@@ -361,26 +349,15 @@ def unquant_apply_mlp(
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
# In the small batch scenario, use _C_ascend.moe_grouped_matmul
|
||||
if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
|
||||
hidden_states = torch.ops._C_ascend.moe_grouped_matmul(
|
||||
x=gate_up_out,
|
||||
weight=w2,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user