[fix] Fix mxfp4 weight loading bug with TP sharding in GPT-OSS (#9433)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Author: hlu1
Date: 2025-08-21 03:50:51 -07:00
Committed by: GitHub
Parent: e85cb1ce9d
Commit: dae9a80f43
2 changed files with 11 additions and 3 deletions


@@ -737,8 +737,8 @@ class ResponsesRequest(BaseModel):
         else:
             max_tokens = default_max_tokens
-        # Avoid exceed the context length by minus 1 token
-        max_tokens -= 1
+        # Avoid exceeding the context length by subtracting 2 tokens
+        max_tokens -= 2
         # Get parameters with defaults
         temperature = self.temperature

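The hunk above widens the headroom reserved against the model's context window from 1 token to 2. For context, a minimal sketch of the clamping logic, assuming `context_len` and `prompt_len` as hypothetical names; only the last two statements mirror the diff:

```python
from typing import Optional


def resolve_max_tokens(
    requested: Optional[int], context_len: int, prompt_len: int
) -> int:
    # Hypothetical reconstruction of the surrounding logic; only the
    # comment and the `-= 2` below correspond to this commit's change.
    default_max_tokens = context_len - prompt_len
    if requested is not None:
        max_tokens = min(requested, default_max_tokens)
    else:
        max_tokens = default_max_tokens
    # Avoid exceeding the context length by subtracting 2 tokens
    max_tokens -= 2
    return max_tokens
```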

@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""
 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -788,18 +789,25 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()

         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size

         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )

         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
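To see why the switch to ceiling division matters, here is a worked example of the sharding arithmetic. The sizes are illustrative assumptions, not taken from this diff (GPT-OSS configs use an MoE intermediate_size of 2880, and MXFP4 packs 32 values per scale block):

```python
import math

# Illustrative values, assumed for this sketch.
intermediate_size = 2880  # MoE expert intermediate dimension
mxfp4_block = 32          # MXFP4 elements sharing one scale
moe_tp_size = 4           # tensor-parallel world size

intermediate_size_block = intermediate_size // mxfp4_block  # 90 blocks

# Old floor division: 90 // 4 = 22 blocks -> 704 columns per rank,
# so the 4 ranks cover only 4 * 704 = 2816 of 2880 columns.
old_per_rank = (intermediate_size_block // moe_tp_size) * mxfp4_block
assert old_per_rank * moe_tp_size == 2816  # 64 columns silently dropped

# Fixed ceiling division: ceil(90 / 4) = 23 blocks -> 736 columns per rank.
per_rank = math.ceil(intermediate_size_block / moe_tp_size) * mxfp4_block
for moe_tp_rank in range(moe_tp_size):
    start = moe_tp_rank * per_rank
    # min() clamps the last rank, which gets the (smaller) remainder
    end = min((moe_tp_rank + 1) * per_rank, intermediate_size)
    print(f"rank {moe_tp_rank}: columns [{start}, {end})")
# rank 0: columns [0, 736)
# rank 1: columns [736, 1472)
# rank 2: columns [1472, 2208)
# rank 3: columns [2208, 2880)  <- 672 columns, nothing lost
```

The new assert guards the precondition (intermediate_size divisible by the MXFP4 block size), and the pre-existing min() clamp is what makes the uneven last shard safe under ceiling division.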