diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 12194ba18..74c6216ac 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -186,6 +186,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

         if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
+            if apply_router_weight_on_input:
+                assert (
+                    topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+                _, topk = topk_weights.shape
+                assert (
+                    topk == 1
+                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                x = x * topk_weights.to(x.dtype)
+                topk_weights = torch.ones_like(
+                    topk_weights, dtype=torch.float32
+                )  # topk_weights must be FP32 (float32)
+
             return ck_moe_2stages(
                 x,
                 layer.w13_weight,
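
For context (not part of the diff): a minimal, standalone sketch of the logic the new `apply_router_weight_on_input` branch adds. It validates that `topk_weights` is `(num_tokens, 1)`, folds the router weight into the activations, and hands the combine step all-ones FP32 weights so the weight is not applied a second time on the expert outputs. The helper name `fold_router_weight_into_input` and the toy shapes below are illustrative assumptions; in the PR this logic runs inline just before `ck_moe_2stages`.

import torch

def fold_router_weight_into_input(x, topk_weights):
    # Validate shape: exactly one routing weight per token (top-1 routing).
    assert topk_weights.dim() == 2, "`topk_weights` should be (num_tokens, topk)"
    _, topk = topk_weights.shape
    assert topk == 1, "Only topk=1 is supported when folding the weight into x"
    # Scale the activations up front, in the activation dtype ...
    x = x * topk_weights.to(x.dtype)
    # ... and replace the combine weights with all-ones FP32 so the router
    # weight is not applied again when expert outputs are combined.
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return x, topk_weights

# Toy usage: 4 tokens, hidden size 8, top-1 router weights.
x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.rand(4, 1)
x_scaled, combine_w = fold_router_weight_into_input(x, w)
assert combine_w.dtype == torch.float32 and bool((combine_w == 1).all())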