[AMD] Fix Llama 4 Scout and Maverick accuracy issues on MI300X (#6274)
This commit is contained in:
@@ -186,6 +186,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
|
||||
assert not no_combine, "unsupported"
|
||||
if apply_router_weight_on_input:
|
||||
assert (
|
||||
topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
x = x * topk_weights.to(x.dtype)
|
||||
topk_weights = torch.ones_like(
|
||||
topk_weights, dtype=torch.float32
|
||||
) # topk_weights must be FP32 (float32)
|
||||
|
||||
return ck_moe_2stages(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
|
||||
Reference in New Issue
Block a user