[fix] Fix mxfp4 weight loading bug with TP sharding in GPT-OSS (#9433)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -737,8 +737,8 @@ class ResponsesRequest(BaseModel):
         else:
             max_tokens = default_max_tokens

-            # Avoid exceed the context length by minus 1 token
-            max_tokens -= 1
+            # Avoid exceed the context length by minus 2 token
+            max_tokens -= 2

         # Get parameters with defaults
         temperature = self.temperature
@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""

 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -788,18 +789,25 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()

         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size

         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )

         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
Reference in New Issue
Block a user