Fix the break in FlashInferFusedMoE by passing dispatch_output to apply_with_router_logits (#10356)
Co-authored-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from sglang.srt.layers.moe import (
|
|||||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||||
CombineInput,
|
CombineInput,
|
||||||
StandardDispatcher,
|
StandardDispatcher,
|
||||||
|
StandardDispatchOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
|
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -981,8 +982,9 @@ class FlashInferFusedMoE(FusedMoE):
|
|||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
final_hidden_states = self.quant_method.apply_with_router_logits(
|
final_hidden_states = self.quant_method.apply_with_router_logits(
|
||||||
layer=self,
|
layer=self,
|
||||||
x=hidden_states,
|
dispatch_output=StandardDispatchOutput(
|
||||||
topk_output=topk_output,
|
hidden_states=hidden_states, topk_output=topk_output
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||||
|
|||||||
Reference in New Issue
Block a user