fix: tmp revert gpt oss tp sharding on hopper (#9469)

This commit is contained in:
Yineng Zhang
2025-08-21 17:03:21 -07:00
committed by GitHub
parent cded039b57
commit 849957bc76

View File

@@ -793,9 +793,12 @@ class GptOssForCausalLM(nn.Module):
             intermediate_size % mxfp4_block == 0
         ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = math.ceil(
-            intermediate_size_block / moe_tp_size
-        )
+        if _is_sm100_supported:
+            per_rank_intermediate_size_block = math.ceil(
+                intermediate_size_block / moe_tp_size
+            )
+        else:
+            per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
         # Calculate common slicing bounds for current rank