[fix] Fix mxfp4 triton MoE tp bug (#9473)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
hlu1
2025-08-23 01:48:40 -07:00
committed by GitHub
parent c9dd70fbde
commit ccd3fb946e
3 changed files with 14 additions and 12 deletions

View File

@@ -111,9 +111,8 @@ class FusedMoE(torch.nn.Module):
hidden_size: Input hidden state size of the transformer
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
reduce_results: Whether to apply all_reduce on the output of the layer
quant_config: Quantization configuration.
inplace: suggestion to compute inplace (modify input activation).
"""
@@ -182,9 +181,6 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = torch.full(
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
)
self.expert_map_cpu = torch.full(
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
)
# Create a expert map for the local experts
self.expert_map_cpu[
self.moe_ep_rank

View File

@@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, 64
)
elif has_triton_kernels:
# TODO: this is a hack to make
# intermediate_size_per_partition_after_pad the same as the
# per_rank_intermediate_size during weight loading
intermediate_size_per_partition_after_pad = round_up(
intermediate_size, mxfp4_block
)
self.intermediate_size = intermediate_size_per_partition_after_pad

View File

@@ -793,12 +793,11 @@ class GptOssForCausalLM(nn.Module):
intermediate_size % mxfp4_block == 0
), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
intermediate_size_block = intermediate_size // mxfp4_block
if _is_sm100_supported:
per_rank_intermediate_size_block = math.ceil(
intermediate_size_block / moe_tp_size
)
else:
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
per_rank_intermediate_size_block = math.ceil(
intermediate_size_block / moe_tp_size
)
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
# Calculate common slicing bounds for current rank