[fix] Fix mxfp4 triton MoE tp bug (#9473)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
@@ -111,9 +111,8 @@ class FusedMoE(torch.nn.Module):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
-        renomalize: Whether to renormalize the logits in the fused_moe kernel
-        quant_config: Quantization configure.
+        reduce_results: Whether to apply all_reduce on the output of the layer
+        quant_config: Quantization configuration.
         inplace: suggestion to compute inplace (modify input activation).
     """
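The reduce_results flag named in the docstring controls whether the layer all_reduces its tensor-parallel partial output. As context, a minimal sketch of that semantic using plain torch.distributed; this is illustrative only, not sglang's actual code path, and finish_moe_output is a made-up name:

import torch
import torch.distributed as dist

def finish_moe_output(partial_output: torch.Tensor,
                      reduce_results: bool,
                      tp_size: int) -> torch.Tensor:
    # Under tensor parallelism each rank computes a partial sum over its
    # slice of the intermediate dimension; summing across ranks recovers
    # the full MoE output.
    if reduce_results and tp_size > 1:
        dist.all_reduce(partial_output, op=dist.ReduceOp.SUM)
    return partial_output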
@@ -182,9 +181,6 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = torch.full(
             (self.num_experts,), -1, dtype=torch.int32, device="cpu"
         )
-        self.expert_map_cpu = torch.full(
-            (self.num_experts,), -1, dtype=torch.int32, device="cpu"
-        )
         # Create a expert map for the local experts
         self.expert_map_cpu[
             self.moe_ep_rank
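The hunk above only drops a duplicated initialization; the map itself is unchanged. For reference, a self-contained sketch of what that map encodes, assuming an even split of experts across EP ranks (build_expert_map is a hypothetical helper, not sglang API):

import torch

def build_expert_map(num_experts: int, ep_rank: int, ep_size: int) -> torch.Tensor:
    # -1 marks experts owned by other EP ranks.
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32, device="cpu")
    experts_per_rank = num_experts // ep_size  # assumes an even split
    start = ep_rank * experts_per_rank
    # Local experts are renumbered 0..experts_per_rank-1 on this rank.
    expert_map[start : start + experts_per_rank] = torch.arange(
        experts_per_rank, dtype=torch.int32
    )
    return expert_map

With num_experts=8 and ep_size=2, rank 0 gets [0, 1, 2, 3, -1, -1, -1, -1] and rank 1 gets [-1, -1, -1, -1, 0, 1, 2, 3].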
@@ -309,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size, 64
             )
+        elif has_triton_kernels:
+            # TODO: this is a hack to make
+            # intermediate_size_per_partition_after_pad the same as the
+            # per_rank_intermediate_size during weight loading
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size, mxfp4_block
+            )
         self.intermediate_size = intermediate_size_per_partition_after_pad
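The new triton-kernels branch pads to a multiple of mxfp4_block rather than 64 so that the padded size matches the per_rank_intermediate_size computed at weight loading (see the last hunk). A sketch of the arithmetic, assuming round_up has the usual meaning and mxfp4_block = 32 (the MX format's 32-element scaling blocks):

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return (x + multiple - 1) // multiple * multiple

mxfp4_block = 32

# Illustrative numbers: a per-rank intermediate size of 720
# (e.g. 2880 split across TP=4) rounds differently on each path.
assert round_up(720, mxfp4_block) == 736   # triton-kernels path
assert round_up(720, 64) == 768            # 64-aligned path

The two roundings can disagree (736 vs 768 here), which is why the kernel-side padding and the loader must use the same block size.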
@@ -793,12 +793,11 @@ class GptOssForCausalLM(nn.Module):
             intermediate_size % mxfp4_block == 0
         ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        if _is_sm100_supported:
-            per_rank_intermediate_size_block = math.ceil(
-                intermediate_size_block / moe_tp_size
-            )
-        else:
-            per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
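This last hunk is the actual TP bug: on the non-SM100 path, floor division dropped the trailing blocks whenever intermediate_size_block did not divide evenly by moe_tp_size, so each rank loaded truncated expert weights. Ceil division keeps every block and lets padding absorb the remainder. A worked example, assuming gpt-oss's 2880 MoE intermediate size for illustration:

import math

mxfp4_block = 32
intermediate_size = 2880
moe_tp_size = 4
intermediate_size_block = intermediate_size // mxfp4_block  # 90 blocks

# Fixed path: ceil keeps all 90 blocks (23 * 4 = 92 slots, last rank padded).
ceil_blocks = math.ceil(intermediate_size_block / moe_tp_size)   # 23
# Old non-SM100 path: floor covers only 22 * 4 = 88 blocks, silently
# losing 2 blocks (64 columns) of every expert's weights.
floor_blocks = intermediate_size_block // moe_tp_size            # 22

print(ceil_blocks * mxfp4_block)   # 736 columns per rank
print(floor_blocks * mxfp4_block)  # 704 columns per rank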