[Bugfix] Qwen3MoE aclrtMemcpy failed with NPUGraph (#10013)

This commit is contained in:
Even Zhou
2025-09-08 12:50:49 +08:00
committed by GitHub
parent 8116804e4f
commit b67c277f86
5 changed files with 180 additions and 10 deletions

View File

@@ -384,19 +384,83 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
import torch_npu
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = dispatch_output.topk_output
output = moe_forward_native(
layer,
x,
topk_output,
self.moe_runner_config,
original_dtype = x.dtype
num_tokens = x.shape[0]
topk_weights = topk_weights.to(x.dtype)
topk_ids = topk_ids.to(torch.int32)
num_experts = layer.num_experts
top_k = layer.top_k
row_idx_len = num_tokens * top_k
row_idx = (
torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
.view(top_k, -1)
.permute(1, 0)
.contiguous()
)
return StandardCombineInput(hidden_states=output)
hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(
x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
)
)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts
)
expert_tokens = expert_tokens.to(torch.int64)
if layer.w13_weight.shape[-1] == layer.hidden_size:
w13 = layer.w13_weight.transpose(1, 2)
w2 = layer.w2_weight.transpose(1, 2)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=original_dtype,
)[0]
# act_fn:
if self.moe_runner_config.activation == "silu":
hidden_states = torch_npu.npu_swiglu(hidden_states)
else:
from sglang.srt.layers.activation import GeluAndMul
hidden_states = GeluAndMul()(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=original_dtype,
)[0]
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
return StandardCombineInput(hidden_states=final_hidden_states)
def forward_tpu(self, *args, **kwargs) -> CombineInput:
    """Unsupported execution path: there is no MoE implementation for TPU.

    Always raises:
        NotImplementedError: unconditionally, for any arguments.
    """
    message = "The TPU backend currently does not support MoE."
    raise NotImplementedError(message)

View File

@@ -17,7 +17,11 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_npu, set_weight_attrs
_is_npu = is_npu()
if not _is_npu:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig