Fuse the two hidden-states padding kernels into the quantization kernel (#9005)
Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -210,13 +210,13 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
|
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
|
||||||
"enable_flashinfer_mxfp4_moe", False
|
"enable_flashinfer_mxfp4_moe", False
|
||||||
)
|
)
|
||||||
|
# TODO: maybe we should remove this `if`, since `Mxfp4MoEMethod` performs its own round-up logic
|
||||||
if (
|
if (
|
||||||
self.quant_config is not None
|
self.quant_config is not None
|
||||||
and self.quant_config.get_name() == "mxfp4"
|
and self.quant_config.get_name() == "mxfp4"
|
||||||
and self.use_enable_flashinfer_mxfp4_moe
|
and self.use_enable_flashinfer_mxfp4_moe
|
||||||
):
|
):
|
||||||
hidden_size = round_up(hidden_size, 256)
|
hidden_size = round_up(hidden_size, 256)
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.quant_method.create_weights(
|
self.quant_method.create_weights(
|
||||||
layer=self,
|
layer=self,
|
||||||
num_experts=self.num_local_experts,
|
num_experts=self.num_local_experts,
|
||||||
@@ -796,13 +796,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
|
||||||
origin_hidden_states_dim = hidden_states.shape[-1]
|
origin_hidden_states_dim = hidden_states.shape[-1]
|
||||||
if self.hidden_size != origin_hidden_states_dim:
|
|
||||||
hidden_states = torch.nn.functional.pad(
|
|
||||||
hidden_states,
|
|
||||||
(0, self.hidden_size - origin_hidden_states_dim),
|
|
||||||
mode="constant",
|
|
||||||
value=0.0,
|
|
||||||
)
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
||||||
|
|||||||
@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.use_flashinfer:
|
if self.use_flashinfer:
|
||||||
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
x_quant, x_scale = mxfp8_quantize(
|
||||||
|
x, False, alignment=self.hidden_size
|
||||||
|
) # to mxfp8
|
||||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||||
|
assert x_quant.shape[-1] == self.hidden_size
|
||||||
|
|
||||||
top_k, router_logits = topk_output
|
top_k, router_logits = topk_output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user