[Bugfix] Qwen3MoE aclrtMemcpy failed with NPUGraph (#10013)
This commit is contained in:
@@ -384,19 +384,83 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
|
||||
import torch_npu
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
topk_weights, topk_ids, _ = dispatch_output.topk_output
|
||||
|
||||
output = moe_forward_native(
|
||||
layer,
|
||||
x,
|
||||
topk_output,
|
||||
self.moe_runner_config,
|
||||
original_dtype = x.dtype
|
||||
num_tokens = x.shape[0]
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
num_experts = layer.num_experts
|
||||
top_k = layer.top_k
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (
|
||||
torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
|
||||
.view(top_k, -1)
|
||||
.permute(1, 0)
|
||||
.contiguous()
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = (
|
||||
torch_npu.npu_moe_init_routing(
|
||||
x, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
|
||||
)
|
||||
)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, num_experts
|
||||
)
|
||||
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
if layer.w13_weight.shape[-1] == layer.hidden_size:
|
||||
w13 = layer.w13_weight.transpose(1, 2)
|
||||
w2 = layer.w2_weight.transpose(1, 2)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w13],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=original_dtype,
|
||||
)[0]
|
||||
|
||||
# act_fn:
|
||||
if self.moe_runner_config.activation == "silu":
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
else:
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
|
||||
hidden_states = GeluAndMul()(hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=expert_tokens,
|
||||
output_dtype=original_dtype,
|
||||
)[0]
|
||||
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
|
||||
return StandardCombineInput(hidden_states=final_hidden_states)
|
||||
|
||||
def forward_tpu(self, *args, **kwargs) -> CombineInput:
    # Stub: there is no fused-MoE implementation for the TPU backend.
    # Raising here fails fast with a clear message instead of silently
    # falling through to an unsupported code path.
    raise NotImplementedError("The TPU backend currently does not support MoE.")
|
||||
|
||||
@@ -17,7 +17,11 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
|
||||
_is_npu = is_npu()
|
||||
if not _is_npu:
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
|
||||
Reference in New Issue
Block a user