fix: tmp revert gpt oss tp sharding on hopper (#9469)
@@ -793,9 +793,12 @@ class GptOssForCausalLM(nn.Module):
             intermediate_size % mxfp4_block == 0
         ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = math.ceil(
-            intermediate_size_block / moe_tp_size
-        )
+        if _is_sm100_supported:
+            per_rank_intermediate_size_block = math.ceil(
+                intermediate_size_block / moe_tp_size
+            )
+        else:
+            per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
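For context, below is a minimal standalone sketch of the sizing logic this hunk changes. The function name `per_rank_intermediate_size`, the `is_sm100` parameter, and the example numbers are hypothetical illustrations, not part of the actual module; the real code reads `_is_sm100_supported`, `intermediate_size`, and `moe_tp_size` from the model config and environment.

```python
import math

# Illustrative constant: mxfp4 quantizes weights in blocks of 32 elements.
MXFP4_BLOCK = 32

def per_rank_intermediate_size(intermediate_size: int,
                               moe_tp_size: int,
                               is_sm100: bool) -> int:
    """Hypothetical sketch of the per-rank TP sharding computation above."""
    assert intermediate_size % MXFP4_BLOCK == 0, (
        f"{intermediate_size=} must be divisible by {MXFP4_BLOCK=}"
    )
    blocks = intermediate_size // MXFP4_BLOCK
    if is_sm100:
        # SM100 path: round the per-rank block count up, so ranks may be
        # padded when blocks % moe_tp_size != 0.
        rank_blocks = math.ceil(blocks / moe_tp_size)
    else:
        # Hopper path (this commit's temporary revert): round down; any
        # remainder must be covered by the per-rank slicing bounds
        # computed afterwards.
        rank_blocks = blocks // moe_tp_size
    return rank_blocks * MXFP4_BLOCK

# Example: 2880 elements = 90 mxfp4 blocks, split across 4 TP ranks.
print(per_rank_intermediate_size(2880, 4, is_sm100=True))   # 736 (23 blocks, padded)
print(per_rank_intermediate_size(2880, 4, is_sm100=False))  # 704 (22 blocks)
```

When the block count does not divide evenly by the TP size, the ceil path pads every rank to a uniform size, while the floor path leaves a remainder that the subsequent "common slicing bounds" computation has to account for.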