[fix] Fix mxfp4 triton MoE tp bug (#9473)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -111,9 +111,8 @@ class FusedMoE(torch.nn.Module):
|
||||
hidden_size: Input hidden state size of the transformer
|
||||
intermediate_size: Intermediate size of the experts
|
||||
params_dtype: Data type for the parameters.
|
||||
reduce_results: Whether to all all_reduce on the output of the layer
|
||||
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||
quant_config: Quantization configure.
|
||||
reduce_results: Whether to apply all_reduce on the output of the layer
|
||||
quant_config: Quantization configuration.
|
||||
inplace: suggestion to compute inplace (modify input activation).
|
||||
"""
|
||||
|
||||
@@ -182,9 +181,6 @@ class FusedMoE(torch.nn.Module):
|
||||
self.expert_map_cpu = torch.full(
|
||||
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
self.expert_map_cpu = torch.full(
|
||||
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
# Create a expert map for the local experts
|
||||
self.expert_map_cpu[
|
||||
self.moe_ep_rank
|
||||
|
||||
@@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, 64
|
||||
)
|
||||
elif has_triton_kernels:
|
||||
# TODO: this is a hack to make
|
||||
# intermediate_size_per_partition_after_pad the same as the
|
||||
# per_rank_intermediate_size during weight loading
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, mxfp4_block
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
|
||||
|
||||
@@ -793,12 +793,11 @@ class GptOssForCausalLM(nn.Module):
|
||||
intermediate_size % mxfp4_block == 0
|
||||
), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
if _is_sm100_supported:
|
||||
per_rank_intermediate_size_block = math.ceil(
|
||||
intermediate_size_block / moe_tp_size
|
||||
)
|
||||
else:
|
||||
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
|
||||
|
||||
per_rank_intermediate_size_block = math.ceil(
|
||||
intermediate_size_block / moe_tp_size
|
||||
)
|
||||
|
||||
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
|
||||
Reference in New Issue
Block a user