From dae9a80f43e8dce2b27f93134fa78909a9a9feef Mon Sep 17 00:00:00 2001
From: hlu1 <14827759+hlu1@users.noreply.github.com>
Date: Thu, 21 Aug 2025 03:50:51 -0700
Subject: [PATCH] [fix] Fix mxfp4 weight loading bug with TP sharding in
 GPT-OSS (#9433)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
Signed-off-by: Xinyuan Tong
Co-authored-by: Xinyuan Tong
---
 python/sglang/srt/entrypoints/openai/protocol.py |  4 ++--
 python/sglang/srt/models/gpt_oss.py              | 10 +++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 9360993df..d36a7f80c 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -737,8 +737,8 @@ class ResponsesRequest(BaseModel):
         else:
             max_tokens = default_max_tokens
 
-        # Avoid exceed the context length by minus 1 token
-        max_tokens -= 1
+        # Avoid exceed the context length by minus 2 token
+        max_tokens -= 2
 
         # Get parameters with defaults
         temperature = self.temperature
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index f3734d735..f068c9d1b 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -16,6 +16,7 @@
 """Inference-only GptOss model compatible with HuggingFace weights."""
 
 import logging
+import math
 from collections.abc import Iterable
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -788,18 +789,25 @@ class GptOssForCausalLM(nn.Module):
         moe_ep_size = get_moe_expert_parallel_world_size()
 
         intermediate_size = self.config.intermediate_size
+        assert (
+            intermediate_size % mxfp4_block == 0
+        ), f"{intermediate_size=} must be divisible by {mxfp4_block=}"
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
+        per_rank_intermediate_size_block = math.ceil(
+            intermediate_size_block / moe_tp_size
+        )
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
 
         # Calculate common slicing bounds for current rank
         assert self.config.num_local_experts % moe_ep_size == 0
         moe_num_global_experts = self.config.num_local_experts
         moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+
         moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
         moe_tp_rank_end = min(
             (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
         )
+
         moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
         moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts