Use aclnnMoeGroupedMatmulWeightNz in the MoE grouped-matmul C++ binding and
remove the small-batch fast paths (npu_moe_init_routing_v2 and
torch.ops._C_ascend.moe_grouped_matmul), together with the now-unused
get_forward_context imports and the small_batch_gmm_batch_num class attribute.

NOTE(review): this patch was reconstructed from a whitespace-mangled,
single-line paste. Line breaks were restored and the missing context lines
(blank lines, closing brace) were inferred so that every hunk matches its
@@ line counts; the Python hunks all balance exactly, the C++ hunk's two
lost context lines are inferred. Verify against the original files before
applying.

diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index 386fb9af..959f10b3 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -691,7 +691,7 @@ std::vector moe_grouped_matmul(
 
   y.emplace_back(y_0);
   at::TensorList result = at::TensorList(y);
-  EXEC_NPU_CMD(aclnnMoeGroupedMatmul,
+  EXEC_NPU_CMD(aclnnMoeGroupedMatmulWeightNz,
               x_list, weight_list, group_list, transpose_weight, result);
   return y;
 }
diff --git a/vllm_ascend/device/device_op.py b/vllm_ascend/device/device_op.py
index 3a2b4fa1..46701d6a 100644
--- a/vllm_ascend/device/device_op.py
+++ b/vllm_ascend/device/device_op.py
@@ -17,7 +17,6 @@
 #
 import torch
 import torch_npu
-from vllm.forward_context import get_forward_context
 
 from vllm_ascend.device.mxfp_compat import (
     FLOAT4_E2M1FN_X2_DTYPE,
@@ -28,8 +27,6 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
 
 
 class BaseDeviceAdaptor:
-    small_batch_gmm_batch_num = 16
-
     @classmethod
     def reshape_and_cache(cls, key, value, key_cache, value_cache, slot_mapping):
         torch_npu._npu_reshape_and_cache(
@@ -49,32 +46,17 @@ class BaseDeviceAdaptor:
         active_expert_range=None,
         quant_mode: int = -1,
     ):
-        # In small batch and non-quantization scenarios, npu_moe_init_routing_v2 is more efficient.
-        # It is expected that further improvements will be made after it is incorporated into CANN on June 30th.
-        if quant_mode == -1 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
-            return torch_npu.npu_moe_init_routing_v2(
-                hidden_states,
-                topk_ids,
-                scale=scale,
-                active_num=active_num,
-                expert_num=expert_num,
-                expert_tokens_num_type=2,
-                expert_tokens_num_flag=expert_tokens_num_flag,
-                active_expert_range=active_expert_range,
-                quant_mode=quant_mode,
-            )
-        else:
-            return torch.ops._C_ascend.npu_moe_init_routing_custom(
-                hidden_states,
-                topk_ids,
-                scale=scale,
-                active_num=active_num,
-                expert_num=expert_num,
-                expert_tokens_num_type=expert_tokens_num_type,
-                expert_tokens_num_flag=expert_tokens_num_flag,
-                active_expert_range=active_expert_range,
-                quant_mode=quant_mode,
-            )
+        return torch.ops._C_ascend.npu_moe_init_routing_custom(
+            hidden_states,
+            topk_ids,
+            scale=scale,
+            active_num=active_num,
+            expert_num=expert_num,
+            expert_tokens_num_type=expert_tokens_num_type,
+            expert_tokens_num_flag=expert_tokens_num_flag,
+            active_expert_range=active_expert_range,
+            quant_mode=quant_mode,
+        )
 
     @staticmethod
     def npu_dynamic_quant(
diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
index 081b102c..00b2f41b 100644
--- a/vllm_ascend/ops/fused_moe/moe_mlp.py
+++ b/vllm_ascend/ops/fused_moe/moe_mlp.py
@@ -18,7 +18,6 @@
 import torch
 import torch_npu
 from torch.nn.functional import pad
-from vllm.forward_context import get_forward_context
 from vllm.triton_utils import HAS_TRITON
 
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
@@ -331,26 +330,15 @@ def unquant_apply_mlp(
     w1 = w1.transpose(1, 2)
     w2 = w2.transpose(1, 2)
 
-    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
-    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
-        gate_up_out = torch.ops._C_ascend.moe_grouped_matmul(
-            x=hidden_states,
-            weight=w1,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
-    else:
-        gate_up_out = torch_npu.npu_grouped_matmul(
-            x=[hidden_states],
-            weight=[w1],
-            bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
+    gate_up_out = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w1],
+        bias=[w1_bias.to(dtype=torch.float32)] if w1_bias is not None else None,
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+    )[0]
 
     if activation == "swigluoai":
         num_experts, _, hidden_size = w1.shape
@@ -361,26 +349,15 @@ def unquant_apply_mlp(
     if topk_scales is not None:
         gate_up_out *= topk_scales
 
-    # In the small batch scenario, use _C_ascend.moe_grouped_matmul
-    if group_list.dim() == 2 and get_forward_context().num_tokens <= DeviceOperator.small_batch_gmm_batch_num:
-        hidden_states = torch.ops._C_ascend.moe_grouped_matmul(
-            x=gate_up_out,
-            weight=w2,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
-    else:
-        hidden_states = torch_npu.npu_grouped_matmul(
-            x=[gate_up_out],
-            weight=[w2],
-            bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
-            split_item=2,
-            group_list_type=group_list_type,
-            group_type=0,
-            group_list=group_list,
-        )[0]
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[gate_up_out],
+        weight=[w2],
+        bias=[w2_bias.to(dtype=torch.float32)] if w2_bias is not None else None,
+        split_item=2,
+        group_list_type=group_list_type,
+        group_type=0,
+        group_list=group_list,
+    )[0]
 
     return hidden_states
 