Cleanup MoE Refactor (#9223)
@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
             x_quant, x_scale = mxfp8_quantize(
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             ) # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
             assert x_quant.shape[-1] == self.hidden_size
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
 
-            top_k, router_logits = topk_output
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None, # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None, # n_group
-                None, # topk_group
+                None, # n_group # TODO: support n_group
+                None, # topk_group # TODO: support topk_group
                 self.intermediate_size, # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
                 layer.num_local_experts, # local num experts
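
The substantive change in the middle hunk is how topk_output is consumed: instead of unpacking it as a (top_k, router_logits) tuple, the code first asserts the bypassed format via TopKOutputChecker.format_is_bypassed and then reads the fields by name. Below is a minimal sketch of that access pattern; _TopKConfig and _BypassedTopKOutput are hypothetical stand-ins for the real sglang types, with only the field names topk_config.top_k and router_logits taken from the diff above.

# Sketch only: stand-in types, not the actual sglang TopKOutput classes.
from dataclasses import dataclass

import torch


@dataclass
class _TopKConfig:
    top_k: int


@dataclass
class _BypassedTopKOutput:
    topk_config: _TopKConfig
    router_logits: torch.Tensor


def read_topk_output(topk_output: _BypassedTopKOutput) -> tuple[int, torch.Tensor]:
    # Old pattern: top_k, router_logits = topk_output  (positional unpacking)
    # New pattern: read named fields, which stays valid if the output type grows fields.
    top_k = topk_output.topk_config.top_k
    router_logits = topk_output.router_logits
    return top_k, router_logits


if __name__ == "__main__":
    out = _BypassedTopKOutput(
        topk_config=_TopKConfig(top_k=2),
        router_logits=torch.randn(4, 8, dtype=torch.bfloat16),
    )
    top_k, logits = read_topk_output(out)
    print(top_k, logits.shape)

Reading named fields also makes the new assertion meaningful: if a caller hands in a non-bypassed TopKOutput, the assert fails before the flashinfer path tries to use it.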