Fuse two kernels of hidden states padding into quantization kernel (#9005)
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor:
|
||||
if self.use_flashinfer:
|
||||
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
x_quant, x_scale = mxfp8_quantize(
|
||||
x, False, alignment=self.hidden_size
|
||||
) # to mxfp8
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
assert x_quant.shape[-1] == self.hidden_size
|
||||
|
||||
top_k, router_logits = topk_output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user