Fuse two kernels of hidden states padding into quantization kernel (#9005)

Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
fzyzcjy
2025-08-12 16:20:13 +08:00
committed by GitHub
parent 5438886c87
commit 5190ba7f42
2 changed files with 5 additions and 9 deletions

View File

@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
) -> torch.Tensor:
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
x_quant, x_scale = mxfp8_quantize(
x, False, alignment=self.hidden_size
) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
assert x_quant.shape[-1] == self.hidden_size
top_k, router_logits = topk_output