fix the break in FlashInferFusedMoE (#10356)

Co-authored-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
This commit is contained in:
chenqianfzh
2025-09-11 23:47:48 -07:00
committed by GitHub
parent b4c2c421e9
commit 4aa39d72c4

View File

@@ -26,6 +26,7 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     CombineInput,
     StandardDispatcher,
+    StandardDispatchOutput,
 )
 from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
@@ -981,8 +982,9 @@ class FlashInferFusedMoE(FusedMoE):
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
-            x=hidden_states,
-            topk_output=topk_output,
+            dispatch_output=StandardDispatchOutput(
+                hidden_states=hidden_states, topk_output=topk_output
+            ),
         )
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):