# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only GptOss model compatible with HuggingFace weights."""
import logging
from collections.abc import Iterable
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg
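# Minimal config shim: it only registers the "gpt_oss" model_type and forwards all
# fields from the HF config kwargs to PretrainedConfig.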
class GptOssConfig(PretrainedConfig):
model_type = "gpt_oss"
def __init__(self, **kwargs):
super().__init__(**kwargs)
logger = logging.getLogger(__name__)
# Aligned with HF's implementation: HF treats the sliding window as inclusive of
# the last token, while SGLang assumes it is exclusive, so subtract one.
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
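# Example: with config.sliding_window == 128, the attention backend receives a
# window of 127 previous tokens (the current token is handled separately on the
# SGLang side).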
class GptOssSparseMoeBlock(nn.Module):
def __init__(
self,
layer_id: int,
config: GptOssConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
self.activation = config.hidden_act
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.swiglu_limit = config.swiglu_limit
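        # When the flashinfer mxfp4 MoE path is enabled, top-k routing happens inside
        # the fused MoE kernel, so no standalone TopK module is built and
        # forward_normal passes (top_k, router_logits) to the experts directly.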
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
self.topk = None
else:
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
self.top_k = config.num_experts_per_tok
experts_type = get_moe_impl_class()
extra_kwargs = {}
if experts_type.__name__ == "FusedMoE":
quant_config_name = (
quant_config.get_name() if quant_config is not None else None
)
extra_kwargs = {
"enable_flashinfer_cutlass_moe": global_server_args_dict[
"enable_flashinfer_cutlass_moe"
],
                # Controls loading of the MoE gate_up_proj / down_proj weights and their biases.
"use_weight_loader_fused": quant_config_name != "mxfp4",
}
self.experts = experts_type(
num_experts=config.num_local_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
activation=self.activation,
activation_alpha=self.activation_alpha,
swiglu_limit=self.swiglu_limit,
with_bias=True,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
**extra_kwargs,
)
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=add_prefix("gate", prefix),
params_dtype=config.torch_dtype,
)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
return self.forward_normal(hidden_states, should_allreduce_fusion)
else:
raise Exception("forward_deepep branch not implemented yet")
def get_moe_weights(self):
return [
x.data
for name, x in self.experts.named_parameters()
if name not in ["correction_bias"]
]
def forward_normal(
self,
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs)
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
ans = final_hidden_states.view(num_tokens, hidden_dim)
return ans
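# When the fused set-KV-buffer path is enabled (CUDA only), the rotary-embedding
# kernel writes K/V into the KV cache directly, so RadixAttention is later called
# with save_kv_cache=False to avoid writing the cache twice.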
def _enable_fused_set_kv_buffer():
return _is_cuda
# TODO maybe move to a model-common utils
def _create_fused_set_kv_buffer_arg(
value: torch.Tensor,
layer: RadixAttention,
forward_batch: ForwardBatch,
):
layer_id = layer.layer_id
token_to_kv_pool = forward_batch.token_to_kv_pool
k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
return FusedSetKVBufferArg(
value=value,
k_buffer=k_buffer.view(k_buffer.shape[0], -1),
v_buffer=v_buffer.view(v_buffer.shape[0], -1),
k_scale=layer.k_scale,
v_scale=layer.v_scale,
cache_loc=forward_batch.out_cache_loc,
)
class GptOssAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
attention_bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
        sliding_window_size: int = -1,  # -1 means full attention; otherwise sliding-window attention
layer_type: str = "",
params_dtype: torch.dtype = torch.bfloat16,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.sliding_window_size = sliding_window_size
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.total_num_heads = num_heads
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= attn_tp_size:
            # The number of KV heads is greater than or equal to the attention TP
            # size, so we partition the KV heads across the tensor-parallel GPUs.
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
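        # Illustrative example (hypothetical numbers): with 8 KV heads and an
        # attention TP size of 2, each rank owns 4 KV heads; with 8 KV heads and a
        # TP size of 16, each rank holds 1 replicated KV head.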
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.tp_rank = get_tensor_model_parallel_rank()
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=attention_bias,
params_dtype=params_dtype,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
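        # Per-head attention-sink logits, passed to the attention kernel; sharded by
        # head across attention TP ranks (see the "sinks" branch in
        # _load_normal_weights).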
self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=torch.float32), requires_grad=False
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
params_dtype=params_dtype,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
assert layer_type in {"sliding_attention", "full_attention"}
use_sliding_window = layer_type == "sliding_attention"
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
sliding_window_size=(sliding_window_size if use_sliding_window else -1),
)
self.layer_id = layer_id
def forward_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
if hidden_states.shape[0] == 0:
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
_create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if _enable_fused_set_kv_buffer()
else None
),
)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
def forward_core(self, intermediate_state):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(
*inner_state,
sinks=self.sinks,
save_kv_cache=not _enable_fused_set_kv_buffer(),
)
output, _ = self.o_proj(attn_output)
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
return self.forward_core(s)
class GptOssDecoderLayer(nn.Module):
def __init__(
self,
config: GptOssConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
sliding_window_size: int | None = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
if sliding_window_size is None:
self.sliding_window_size = get_attention_sliding_window_size(self.config)
else:
self.sliding_window_size = sliding_window_size
self.self_attn = GptOssAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
prefix=add_prefix("self_attn", prefix),
sliding_window_size=self.sliding_window_size,
layer_type=config.layer_types[layer_id],
params_dtype=config.torch_dtype,
)
self.layer_id = layer_id
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
        # All GptOss layers are sparse (MoE), and there is no NextN layer for now.
self.is_layer_sparse = True
self.is_nextn = False
is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
if self.is_layer_sparse:
self.mlp = GptOssSparseMoeBlock(
layer_id=self.layer_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
else:
raise NotImplementedError(
"Dense MLP is not implemented for GptOssDecoderLayer. "
"Please use GptOssSparseMoeBlock instead."
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
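        # Precomputed decision table: for each batch size in [0, 128], whether the
        # MLP all-reduce can be fused with the next layer's residual RMSNorm. The
        # table is empty when the static conditions (TP > 1, flashinfer allreduce
        # fusion enabled, SM100, not the last layer) are not met, so lookups return
        # False by default.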
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
batch_size = (
forward_batch.input_ids.shape[0]
if hasattr(forward_batch, "input_ids")
else 0
)
if batch_size > 128:
return False
return self._fuse_allreduce_lookup_table.get(batch_size, False)
def _build_fuse_allreduce_lookup_table(self):
static_conditions_met = (
self.layer_id != self.config.num_hidden_layers - 1
and get_tensor_model_parallel_world_size() > 1
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_sm100_supported
and _is_flashinfer_available
)
if not static_conditions_met:
return {}
lookup_table = {}
for batch_size in range(129): # 0 to 128
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
lookup_table[batch_size] = should_fuse
return lookup_table
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
should_allreduce_fusion = (
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
and not self.is_nextn
)
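        # When fusion is enabled, the MoE block skips its own TP all-reduce and the
        # result is tagged below so that the later fused allreduce + RMSNorm
        # (flashinfer) step can pick it up.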
hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
if should_allreduce_fusion:
hidden_states._sglang_needs_allreduce_fusion = True
if not should_allreduce_fusion:
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
class GptOssModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to GptOssDecoderLayer
decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: decoder_layer_type(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
self.layers_to_capture = []
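        # Layer indices whose inputs (hidden_states + residual) are captured as aux
        # hidden states, e.g. for EAGLE3 speculative decoding; populated via
        # GptOssForCausalLM.set_eagle3_layers_to_capture.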
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class GptOssForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: GptOssConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
self.model = GptOssModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
# quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return hidden_states
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def _get_default_weight_mapping(self):
"""Generate default weight name mapping for GptOss safetensors."""
weight_mapping = {}
        # Top-level embedding, unembedding, and final-norm weights; per-layer
        # mappings (including mlp.gate -> mlp.router) are generated in the loop below.
weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
weight_mapping["unembedding.weight"] = "lm_head.weight"
weight_mapping["norm.scale"] = "model.norm.weight"
for layer_id in range(self.config.num_hidden_layers):
weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
f"model.layers.{layer_id}.self_attn.q_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
f"model.layers.{layer_id}.self_attn.q_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
f"model.layers.{layer_id}.self_attn.k_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
f"model.layers.{layer_id}.self_attn.k_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
f"model.layers.{layer_id}.self_attn.v_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
f"model.layers.{layer_id}.self_attn.v_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
f"model.layers.{layer_id}.self_attn.o_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
f"model.layers.{layer_id}.self_attn.o_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.sinks"] = (
f"model.layers.{layer_id}.self_attn.sinks"
)
weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
f"model.layers.{layer_id}.input_layernorm.weight"
)
weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
f"model.layers.{layer_id}.mlp.router.weight"
)
weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
f"model.layers.{layer_id}.mlp.router.bias"
)
weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
f"model.layers.{layer_id}.post_attention_layernorm.weight"
)
weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
)
weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
)
weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
)
weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
)
return weight_mapping
# TODO beautify code
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
is_nextn: bool = False,
weight_name_mapping: dict = None,
):
quant_config_name = (
self.quant_config.get_name() if self.quant_config is not None else None
)
if quant_config_name != "mxfp4":
self._load_normal_weights(
weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
)
else:
self._load_weights_mxfp4(
weights, is_nextn=is_nextn, weight_name_mapping=weight_name_mapping
)
def _load_weights_mxfp4(self, weights, is_nextn, weight_name_mapping):
mxfp4_weights = []
normal_weights = []
for name, weight in weights:
if (
".experts" in name
and self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
):
mxfp4_weights.append((name, weight))
else:
normal_weights.append((name, weight))
mxfp4_loaded_params = self._load_mxfp4_experts_weights(mxfp4_weights)
self._load_normal_weights(
normal_weights,
is_nextn=is_nextn,
weight_name_mapping=weight_name_mapping,
other_loaded_param_names=mxfp4_loaded_params,
)
def _load_mxfp4_experts_weights(self, weights):
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
mxfp4_block = 32
moe_tp_rank = get_moe_tensor_parallel_rank()
moe_tp_size = get_moe_tensor_parallel_world_size()
moe_ep_rank = get_moe_expert_parallel_rank()
moe_ep_size = get_moe_expert_parallel_world_size()
intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
# Calculate common slicing bounds for current rank
assert self.config.num_local_experts % moe_ep_size == 0
moe_num_global_experts = self.config.num_local_experts
moe_num_local_experts = self.config.num_local_experts // moe_ep_size
moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
moe_tp_rank_end = min(
(moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
)
moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
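        # Illustrative sharding example (hypothetical numbers): with
        # intermediate_size = 2880, mxfp4_block = 32, moe_tp_size = 2, and
        # moe_ep_size = 1, per_rank_intermediate_size = 1440, so TP rank 0 loads rows
        # [0, 1440) and TP rank 1 loads rows [1440, 2880), while every rank sees all
        # experts.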
for name, weight in weights:
weight = weight.cuda()
if "gate_up_proj_blocks" in name:
# Handle MLP gate and up projection weights
new_name = name.replace("gate_up_proj_blocks", "w13_weight")
                # Flatten the weight from (E, 2 * N, block_size, entries_per_block)
                # to (E, 2 * N, -1); this should not trigger a copy for contiguous tensors.
weight = weight.view(
moe_num_global_experts, 2 * intermediate_size, -1
).contiguous()
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
...,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "down_proj_blocks" in name:
# Handle MLP down projection weights
new_name = name.replace("down_proj_blocks", "w2_weight")
                # Same flattening as above, but since two mxfp4 values are packed
                # into one uint8, the inner dimension is halved.
weight = weight.view(
moe_num_global_experts, -1, intermediate_size // 2
).contiguous()
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
...,
moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weights scale
new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
...,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "down_proj_scales" in name:
                # Handle MLP down projection weight scales
new_name = name.replace("down_proj_scales", "w2_weight_scale")
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
...,
moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "gate_up_proj_bias" in name:
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
elif "down_proj_bias" in name:
narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
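                # The down-projection output is reduced across MoE TP ranks, so the
                # bias is kept only on rank 0 and zeroed elsewhere to avoid adding it
                # multiple times.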
if moe_tp_rank != 0:
narrow_weight = torch.zeros_like(narrow_weight)
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_weight_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
return loaded_params
def _load_normal_weights(
self,
weights,
is_nextn: bool,
weight_name_mapping: dict,
other_loaded_param_names=[],
):
tp_rank = get_tensor_model_parallel_rank()
if is_nextn:
            logger.warning(
                "Loading weights for nextn is currently not supported in GptOssForCausalLM."
            )
return
weights = _canonicalize_weights(self.config, weights)
weights = sorted(weights, key=lambda x: x[0]) # Sort by name for consistency
new_weights = []
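        # Split fused qkv checkpoint tensors into separate q/k/v entries so they can
        # be routed through the stacked-params mapping for QKVParallelLinear below.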
for name, p in weights:
if "qkv.weight" in name:
q_proj, k_proj, v_proj = p.split(
[
self.config.num_attention_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
],
dim=0,
)
new_weights.append(
(f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
)
new_weights.append(
(f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
)
new_weights.append(
(f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
)
elif "qkv.bias" in name:
q_bias, k_bias, v_bias = p.split(
[
self.config.num_attention_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
],
dim=0,
)
new_weights.append(
(f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
)
new_weights.append(
(f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
)
new_weights.append(
(f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
)
else:
new_weights.append((name, p))
weights = new_weights
# Use provided weight name mapping if available, otherwise use default
if weight_name_mapping is None:
weight_name_mapping = self._get_default_weight_mapping()
else:
# Merge with default mapping
default_mapping = self._get_default_weight_mapping()
default_mapping.update(weight_name_mapping)
weight_name_mapping = default_mapping
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
)
params_dict = dict(self.named_parameters())
params_checker = {k: False for k, v in params_dict.items()}
for other_loaded_param_name in other_loaded_param_names:
params_checker[other_loaded_param_name] = True
for name, loaded_weight in weights:
loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
# Apply weight name mapping if provided
if weight_name_mapping and name in weight_name_mapping:
name = weight_name_mapping[name]
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
params_checker[name] = True
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
if "bias" not in name:
loaded_weight = loaded_weight.transpose(-2, -1)
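                    # As in the mxfp4 path, the down-proj bias is applied only on MoE
                    # TP rank 0 so it is not double-counted after the reduction.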
if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
loaded_weight = loaded_weight.zero_()
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
)
params_checker[name] = True
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
if "sinks" in name:
start = tp_rank * param.numel()
param.data.copy_(
loaded_weight[start : start + param.numel()]
)
else:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
params_checker[name] = True
else:
logger.warning(f"Parameter {name} not found in params_dict")
not_loaded_params = [k for k, v in params_checker.items() if not v]
if tp_rank == 0:
if len(not_loaded_params) > 0:
raise Exception(f"Not all parameters loaded: {not_loaded_params}")
else:
                logger.info("All parameters loaded successfully.")
self.routed_experts_weights_of_layer = {
layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
for layer_id in range(self.start_layer, self.end_layer)
if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
}
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
if not self.pp_group.is_last_rank:
return
if layer_ids is None:
self.capture_aux_hidden_states = True
num_layers = self.config.num_hidden_layers
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
else:
self.capture_aux_hidden_states = True
            # Add 1 because in SGLang the i-th layer records the output of the
            # (i-1)-th layer as its aux hidden state.
self.model.layers_to_capture = [val + 1 for val in layer_ids]
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_local_experts,
num_groups=None,
)
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
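# _canonicalize_weights pairs each "<prefix>.blocks" / "<prefix>.scales" mxfp4 tensor
# into a lazy _WeightCreator entry, so the bf16 dequantization in _dequant_mlp_weight
# runs only when the weight is materialized in _load_normal_weights (one tensor at a
# time) rather than dequantizing the whole checkpoint up front.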
def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
weights_out_dict = dict(weights_in)
for layer_id in range(config.num_hidden_layers):
for name_chunk in ["mlp1_weight", "mlp2_weight"]:
name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
if w_blocks is not None:
weights_out_dict[name_prefix] = _WeightCreator(
partial(
_dequant_mlp_weight,
debug_name=name_prefix,
w_blocks=w_blocks,
w_scales=w_scales,
)
)
return list(weights_out_dict.items())
def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
if get_tensor_model_parallel_rank() == 0:
logger.info(f"Dequantize {debug_name} start")
original_device = w_blocks.device
w_blocks = w_blocks.cuda()
w_scales = w_scales.cuda()
w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
w_bf16 = w_bf16.transpose(-2, -1).contiguous()
if get_tensor_model_parallel_rank() == 0:
logger.info(
f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
)
return w_bf16.to(original_device)
class _WeightCreator:
def __init__(self, fn):
self._fn = fn
@staticmethod
def maybe_materialize(obj):
if isinstance(obj, _WeightCreator):
output = obj._fn()
obj._fn = None
return output
return obj
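# SGLang's model loader discovers this implementation through the module-level
# EntryClass symbol below.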
EntryClass = GptOssForCausalLM