fix: tmp revert gpt oss tp sharding on hopper (#9469)
This commit is contained in:
@@ -793,9 +793,12 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
intermediate_size % mxfp4_block == 0
|
intermediate_size % mxfp4_block == 0
|
||||||
), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
|
), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
|
||||||
intermediate_size_block = intermediate_size // mxfp4_block
|
intermediate_size_block = intermediate_size // mxfp4_block
|
||||||
per_rank_intermediate_size_block = math.ceil(
|
if _is_sm100_supported:
|
||||||
intermediate_size_block / moe_tp_size
|
per_rank_intermediate_size_block = math.ceil(
|
||||||
)
|
intermediate_size_block / moe_tp_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
|
||||||
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
||||||
|
|
||||||
# Calculate common slicing bounds for current rank
|
# Calculate common slicing bounds for current rank
|
||||||
|
|||||||
Reference in New Issue
Block a user