Fuse two kernels of hidden states padding into quantization kernel (#9005)
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
```diff
@@ -210,13 +210,13 @@ class FusedMoE(torch.nn.Module):
         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
             "enable_flashinfer_mxfp4_moe", False
         )
         # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
             and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
```
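For context, the round-up only changes hidden sizes that are not already multiples of 256. `round_up` here is the project's own helper; a minimal stand-in with the behavior assumed by this call site:

```python
def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x (assumed behavior of the helper).
    return ((x + multiple - 1) // multiple) * multiple

print(round_up(2880, 256))  # 2944 -> hidden states need padding before the MXFP4 MoE kernels
print(round_up(4096, 256))  # 4096 -> already aligned, nothing to pad
```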
```diff
@@ -796,13 +796,6 @@ class FusedMoE(torch.nn.Module):
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
-        origin_hidden_states_dim = hidden_states.shape[-1]
-        if self.hidden_size != origin_hidden_states_dim:
-            hidden_states = torch.nn.functional.pad(
-                hidden_states,
-                (0, self.hidden_size - origin_hidden_states_dim),
-                mode="constant",
-                value=0.0,
-            )
         assert self.quant_method is not None

         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
```
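The deleted block is one of the two fused kernels named in the title: an explicit zero-pad of the hidden states up to the rounded-up `self.hidden_size` on every MoE forward. A minimal sketch of that operation in isolation (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

hidden_size = 2944                       # rounded-up size stored on the layer
hidden_states = torch.randn(8, 2880)     # original model dim, not a multiple of 256

# Pad only the last dimension, on the right, exactly like the removed code.
pad_amount = hidden_size - hidden_states.shape[-1]
padded = F.pad(hidden_states, (0, pad_amount), mode="constant", value=0.0)

print(padded.shape)                      # torch.Size([8, 2944])
assert torch.equal(padded[..., :2880], hidden_states)
```

Running this as a standalone pad (a separate kernel launch plus an extra copy of the activations) is exactly what the change avoids; the padding now happens inside the quantization kernel instead.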
```diff
@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size

             top_k, router_logits = topk_output
```
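The `alignment=self.hidden_size` argument is what absorbs the padding: the quantization kernel writes its output at the aligned width, so the standalone pad in `forward` is no longer needed. The real `mxfp8_quantize` is a flashinfer CUDA kernel that emits FP8 data plus block scales; the pure-PyTorch sketch below only illustrates the fusion idea, and the block size (32) and E4M3 max (448) are assumptions about the MX format rather than details taken from this diff:

```python
import torch
import torch.nn.functional as F

def mxfp8_quantize_ref(x: torch.Tensor, alignment: int = 1, block: int = 32):
    """Reference-only stand-in: pad the last dim to a multiple of `alignment`,
    then apply a simple per-block absmax scaling. Not the flashinfer kernel or
    its signature; assumes the padded width is a multiple of `block`."""
    pad = (-x.shape[-1]) % alignment
    if pad:
        # Padding is folded into the quantization step instead of a separate F.pad call.
        x = F.pad(x, (0, pad), mode="constant", value=0.0)
    blocks = x.reshape(*x.shape[:-1], -1, block)
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    x_quant = (blocks / scale).reshape(x.shape)   # the real kernel would cast this to FP8
    return x_quant, scale.reshape(*x.shape[:-1], -1)

x = torch.randn(4, 2880)
x_quant, x_scale = mxfp8_quantize_ref(x, alignment=2944)
print(x_quant.shape, x_scale.shape)  # torch.Size([4, 2944]) torch.Size([4, 92])
assert x_quant.shape[-1] == 2944     # mirrors the assert added in the diff
```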