# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import typing from collections.abc import Callable, Iterable import torch import torch.distributed as dist from torch import nn from transformers import GptOssConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( get_dp_group, get_ep_group, get_pcp_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.utils import rocm_unquantized_gemm from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import AttentionType from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, WeightsMapper, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) class OAIAttention(nn.Module): def __init__( self, config: GptOssConfig, quant_config: QuantizationConfig | None = None, cache_config: CacheConfig | None = None, prefix: str = "", ): super().__init__() self.layer_idx = extract_layer_index(prefix) self.head_dim = config.head_dim self.num_attention_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.hidden_size = config.hidden_size self.rotary_emb = get_rope( self.head_dim, max_position=config.max_position_embeddings, dtype=torch.float32, rope_parameters={ "rope_theta": config.rope_parameters["rope_theta"], "rope_type": "yarn", "factor": config.rope_parameters["factor"], "original_max_position_embeddings": config.rope_parameters[ "original_max_position_embeddings" ], "beta_fast": config.rope_parameters["beta_fast"], "beta_slow": config.rope_parameters["beta_slow"], "truncate": config.rope_parameters.get("truncate", True), }, is_neox_style=True, ) tp_size = get_tensor_model_parallel_world_size() self.sinks = torch.nn.Parameter( torch.empty(config.num_attention_heads // tp_size, requires_grad=False) ) self.q_size = self.num_attention_heads * self.head_dim // tp_size self.kv_size = self.num_key_value_heads * self.head_dim // tp_size self.scaling = self.head_dim**-0.5 self.qkv_proj = QKVParallelLinear( hidden_size=self.hidden_size, head_size=self.head_dim, total_num_heads=self.num_attention_heads, total_num_kv_heads=self.num_key_value_heads, bias=True, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( 
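            # The input to o_proj is the concatenated per-head attention
            # output; RowParallelLinear shards that input dimension across TP
            # ranks and all-reduces the projected result.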
input_size=self.num_attention_heads * self.head_dim, output_size=self.hidden_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.num_local_attention_heads = config.num_attention_heads // tp_size self.num_local_key_value_heads = config.num_key_value_heads // tp_size # Only apply sliding window to every other layer sliding_window = config.sliding_window if self.layer_idx % 2 == 0 else None self.attn = Attention( self.num_local_attention_heads, self.head_dim, self.scaling, num_kv_heads=self.num_local_key_value_heads, cache_config=cache_config, quant_config=quant_config, per_layer_sliding_window=sliding_window, attn_type=AttentionType.DECODER, prefix=f"{prefix}.attn", sinks=self.sinks, ) def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output class MLPBlock(torch.nn.Module): def __init__( self, vllm_config: VllmConfig, layer_idx: int, prefix: str = "", ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe self.layer_idx = layer_idx self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, reduce_results=True, renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", apply_router_weight_on_input=False, has_bias=True, activation="swigluoai", is_sequence_parallel=self.is_sequence_parallel, ) def forward(self, x: torch.Tensor) -> torch.Tensor: num_tokens = x.shape[0] if self.is_sequence_parallel: x = sequence_parallel_chunk(x) if current_platform.is_rocm(): g = rocm_unquantized_gemm( self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: g = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: x = tensor_model_parallel_all_gather(x.contiguous(), 0) x = x[:num_tokens] return x class TransformerBlock(torch.nn.Module): def __init__( self, vllm_config: VllmConfig, quant_config: QuantizationConfig, prefix: str = "", ): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config self.layer_idx = extract_layer_index(prefix) self.attn = OAIAttention( config, prefix=f"{prefix}.attn", quant_config=quant_config, cache_config=cache_config, ) self.mlp = MLPBlock(vllm_config, self.layer_idx, prefix=f"{prefix}.mlp") self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) def forward( self, hidden_states: torch.Tensor, positions: torch.Tensor, residual: torch.Tensor | None, ) -> torch.Tensor: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = 
self.attn(hidden_states, positions) # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) output = self.mlp(hidden_states) return output, residual @support_torch_compile class GptOssModel(nn.Module): def __init__( self, *, vllm_config: VllmConfig, prefix: str = "", ): super().__init__() self.config = vllm_config.model_config.hf_config self.quant_config = vllm_config.quant_config self.parallel_config = vllm_config.parallel_config self.config.hidden_size = self.config.hidden_size self.embedding = VocabParallelEmbedding( self.config.vocab_size, self.config.hidden_size, ) self.start_layer, self.end_layer, self.layers = make_layers( self.config.num_hidden_layers, lambda prefix: TransformerBlock( vllm_config, prefix=prefix, quant_config=self.quant_config, ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(self.config.hidden_size, eps=1e-5) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], self.config.hidden_size ) self.aux_hidden_state_layers = tuple[int, ...]() def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embedding(input_ids) def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: x = inputs_embeds else: x = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None x = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] aux_hidden_states = [] for i in range(self.start_layer, self.end_layer): layer = self.layers[i] if i in self.aux_hidden_state_layers: aux_hidden_states.append(x if residual is None else x + residual) x, residual = layer(x, positions, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": x, "residual": residual}) x, _ = self.norm(x, residual) if len(aux_hidden_states) > 0: return x, aux_hidden_states return x def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, weight scales, activation scales # (param_name, weight_name, expert_id, shard_id) # NOTE: this is only used for quark. return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", num_experts=self.config.num_local_experts, num_redundant_experts=0, ) def _load_weights_mxfp4( self, ep_rank_end: int, ep_rank_start: int, heads_per_rank: int, head_start: int, weights: Iterable[tuple[str, torch.Tensor]], stacked_params_mapping: list[tuple[str, ...]], ) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() use_ep = self.parallel_config.enable_expert_parallel num_experts = self.config.num_local_experts # In MoE, we need to flatten the tensor parallel size across the data # parallel size when EP is disabled. 
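        # Concretely, each DP (and PCP) rank is treated as one more TP shard
        # of the expert weights, so the slicing bounds computed below index
        # into the flattened TP x DP x PCP shard grid rather than plain TP
        # shards.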
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( tp_size=get_tensor_model_parallel_world_size(), dp_size=get_dp_group().world_size, dp_rank=get_dp_group().rank_in_group, pcp_size=get_pcp_group().world_size, pcp_rank=get_pcp_group().rank_in_group, ) intermediate_size = self.config.intermediate_size intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) per_rank_intermediate_size = ( per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE ) # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) for name, weight in weights: # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue if ".w13_weight_scale" in name: # Handle MLP gate and up projection weights scale if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader( param, narrow_weight, weight_name=name, shard_id=None, expert_id=None, ) loaded_params.add(name) continue elif ".w2_weight_scale" in name: # Handle MLP down projection weights if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[ ..., tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end // OCP_MX_BLOCK_SIZE, ] param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader( param, narrow_weight, weight_name=name, shard_id=None, expert_id=None, ) loaded_params.add(name) continue elif ".w13_weight" in name: # Handle MLP gate and up projection weights # flat weight from (E, 2 * N, block_size, entry_per_block) # to (E, 2 * N, -1), shouldn't trigger copy for contiguous weight = weight.view( num_experts, 2 * intermediate_size, -1 ).contiguous() # Extract gate and up projection parts # since the weight is shuffled, we can slice directly if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...] param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader( param, narrow_weight, weight_name=name, shard_id=None, expert_id=None, ) loaded_params.add(name) continue elif ".w2_weight" in name: # Handle MLP down projection weights # same flatten here, but since 2 mx4 value are packed in 1 # uint8, divide by 2 weight = weight.view( num_experts, -1, intermediate_size // 2 ).contiguous() if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2] param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader( param, narrow_weight, weight_name=name, shard_id=None, expert_id=None, ) loaded_params.add(name) continue elif ".w13_bias" in name: # Handle MLP gate and up projection biases # Extract gate and up projection bias parts if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
else: narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader( param, narrow_weight, weight_name=name, shard_id=None, expert_id=None, ) loaded_params.add(name) continue elif ".w2_bias" in name: # Handle MLP down projection bias param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if use_ep: weight = weight[ep_rank_start:ep_rank_end, ...] else: # (only load on rank 0 to avoid duplication) if tp_rank != 0: weight.zero_() weight_loader( param, weight, weight_name=name, shard_id=None, expert_id=None ) loaded_params.add(name) continue elif "sinks" in name: # Handle attention sinks (distributed across ranks) param = params_dict[name] narrow_weight = weight.narrow(0, head_start, heads_per_rank) param.data.copy_(narrow_weight) loaded_params.add(name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, weight) else: weight_loader(param, weight, shard_id) break else: # Handle all other weights with potential renaming if name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(name) return loaded_params def _load_weights_quark( self, ep_rank_end: int, ep_rank_start: int, heads_per_rank: int, head_start: int, weights: Iterable[tuple[str, torch.Tensor]], stacked_params_mapping: list[tuple[str, ...]], ) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() use_ep = self.parallel_config.enable_expert_parallel num_experts = self.config.num_local_experts if use_ep: tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() else: tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( tp_size=get_tensor_model_parallel_world_size(), dp_size=get_dp_group().world_size, dp_rank=get_dp_group().rank_in_group, pcp_size=get_pcp_group().world_size, pcp_rank=get_pcp_group().rank_in_group, ) def _get_moe_weight_dtype(layer_id: int = 0) -> str | None: """Helper function to get MoE quantization weight dtype. 
Args: layer_id: Layer index to check (default 0, as all layers should have the same quantization method) Returns: Weight dtype string (e.g., "mxfp4", "fp8") or None if not available """ if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"): return self.layers[layer_id].mlp.experts.quant_method.weight_dtype return None intermediate_size = self.config.intermediate_size moe_weight_dtype = _get_moe_weight_dtype(layer_id=0) if moe_weight_dtype == "mxfp4": # MXFP4 requires OCP_MX_BLOCK_SIZE alignment intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) per_rank_intermediate_size = ( per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE ) else: # FP8 and other formats don't need alignment per_rank_intermediate_size = cdiv(intermediate_size, tp_size) tp_rank_start = tp_rank * per_rank_intermediate_size tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if is_pp_missing_parameter(name, self): continue layer_id, expert_id, fused_name = None, None, None moe_quant_method = None if "experts" in name: parts = name.split(".") ids = [s for s in parts if s.isdigit()] # for amd-quark format that each expert is separated # need to extract the parameter name with experts fused. # example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8 if len(ids) == 2: layer_id, expert_id = int(ids[0]), int(ids[-1]) parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id))) fused_name = ".".join(parts) # for openai mxfp4 format that all experts are combined # no need to extract the parameter name with experts fused. # models: openai/gpt-oss-20b, openai/gpt-oss-120b elif len(ids) == 1: layer_id, expert_id = int(ids[0]), None fused_name = name else: raise NameError( f"Layer {name} contains more than 2 numeric indices. This is " "an unexpected condition. Please open an issue if encountered." ) moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id) def kv_cache_scale_loader( quant_config: QuantizationConfig, name: str, params_dict: dict[str, typing.Any], weight: torch.Tensor, default_weight_loader: Callable[..., None], loaded_params: set[str], ) -> tuple[bool, set[str]]: """ Load KV cache output scales. Returns: Tuple of (bool, set): - bool: True if KV-cache scale was loaded into loaded_params - set: Updated set of loaded_params if True else the original set """ # load explicit cached KV output scale from quant_config if quant_config is not None and ( scale_name := quant_config.get_cache_scale(name) ): param = params_dict[scale_name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) if weight.numel() != 1: raise ValueError( f"KV cache scale '{scale_name}' is expected to be a " f"scalar, but got a tensor of shape {weight.shape}." ) # Ensure weight is a scalar before passing to loader. 
weight_loader(param, weight.flatten()[0]) loaded_params.add(scale_name) return True, loaded_params return False, loaded_params load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader( self.quant_config, name, params_dict, loaded_weight, default_weight_loader, loaded_params, ) if load_kv_cache_scale_completed: continue if ( all(key in name for key in ["input_scale", "mlp.experts"]) and expert_id is not None ): assert loaded_weight.numel() == 1 expert_data = params_dict[fused_name].data[expert_id] expert_data.copy_(loaded_weight) loaded_params.add(fused_name) continue # Unified handler for mxfp4 weights and scales elif moe_quant_method == "mxfp4" and any( name.endswith(suffix) for suffix in [ ".w13_weight_scale", ".w2_weight_scale", ".w13_weight", ".w2_weight", ] ): is_w13 = ".w13_" in name is_scale = "_scale" in name # Reshape weight for mxfp4 if needed (not for scales) if not is_scale and expert_id is None: if is_w13: if loaded_weight.dim() < 3: raise ValueError( f"Expected w13_weight to have at least 3 " f"dimensions, got shape " f"{loaded_weight.shape}" ) if loaded_weight.shape[0] != num_experts: raise ValueError( f"Expected w13_weight first dimension to be " f"{num_experts}, got " f"{loaded_weight.shape[0]}" ) loaded_weight = loaded_weight.view( num_experts, 2 * intermediate_size, -1 ).contiguous() else: if loaded_weight.dim() < 3: raise ValueError( f"Expected w2_weight to have at least 3 " f"dimensions, got shape " f"{loaded_weight.shape}" ) if loaded_weight.shape[0] != num_experts: raise ValueError( f"Expected w2_weight first dimension to be " f"{num_experts}, got " f"{loaded_weight.shape[0]}" ) loaded_weight = loaded_weight.view( num_experts, -1, intermediate_size // 2 ).contiguous() if use_ep: sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] else: if is_w13: if expert_id is None: sliced_weight = loaded_weight[ :, 2 * tp_rank_start : 2 * tp_rank_end, ... ] else: sliced_weight = loaded_weight[ 2 * tp_rank_start : 2 * tp_rank_end, ... ] else: if is_scale: sliced_weight = loaded_weight[ ..., tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end // OCP_MX_BLOCK_SIZE, ] else: sliced_weight = loaded_weight[ ..., tp_rank_start // 2 : tp_rank_end // 2 ] # NOTE(rob): because gpt-oss ckpt has "unique" structure with # fused gate_up_proj fused on disk, we cannot use the existing # weight loaders without added complexity, so just do the # direct load here. param = params_dict[fused_name] expert_data = param.data[expert_id] dim1 = sliced_weight.shape[0] dim2 = sliced_weight.shape[1] expert_data.data[:dim1, :dim2].copy_(sliced_weight) loaded_params.add(fused_name) continue elif name.endswith(".w13_weight") and moe_quant_method == "fp8": if use_ep: narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] else: if expert_id is None: narrow_weight = loaded_weight[ :, 2 * tp_rank_start : 2 * tp_rank_end, : ] else: narrow_weight = loaded_weight[ 2 * tp_rank_start : 2 * tp_rank_end, : ] assert fused_name is not None param = params_dict[fused_name] if expert_id is None: param.data.copy_(narrow_weight) else: param.data[expert_id].copy_(narrow_weight) loaded_params.add(fused_name) continue elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8": assert fused_name is not None param = params_dict[fused_name] # Check if this is per-channel or per-tensor scale if loaded_weight.numel() > 1 and loaded_weight.dim() == 1: if use_ep: narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] 
else: narrow_weight = loaded_weight[ 2 * tp_rank_start : 2 * tp_rank_end ] else: narrow_weight = loaded_weight if expert_id is None: param.data.copy_(narrow_weight) else: param.data[expert_id].copy_(narrow_weight) loaded_params.add(fused_name) continue elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8": assert fused_name is not None param = params_dict[fused_name] if expert_id is None: param.data.copy_(loaded_weight) else: param.data[expert_id].copy_(loaded_weight) loaded_params.add(fused_name) continue elif name.endswith(".w2_weight") and moe_quant_method == "fp8": if use_ep: narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] else: if expert_id is None: narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end] else: narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end] assert fused_name is not None param = params_dict[fused_name] if expert_id is None: param.data.copy_(narrow_weight) else: param.data[expert_id].copy_(narrow_weight) loaded_params.add(fused_name) continue elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8": assert fused_name is not None param = params_dict[fused_name] if use_ep: narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = loaded_weight if expert_id is None: param.data.copy_(narrow_weight) else: param.data[expert_id].copy_(narrow_weight) loaded_params.add(fused_name) continue # Unified handler for bias loading (w13_bias and w2_bias) elif name.endswith(".w13_bias") or name.endswith(".w2_bias"): is_w13_bias = name.endswith(".w13_bias") if use_ep: sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...] else: if is_w13_bias: if expert_id is None: sliced_weight = loaded_weight[ :, 2 * tp_rank_start : 2 * tp_rank_end ] else: sliced_weight = loaded_weight[ 2 * tp_rank_start : 2 * tp_rank_end ] else: sliced_weight = loaded_weight if tp_rank != 0: sliced_weight = sliced_weight.zero_() # NOTE(rob): because gpt-oss ckpt has "unique" structure with # fused gate_up_proj fused on disk, we cannot use the existing # weight loaders without added complexity, so just do the # direct load here. assert fused_name is not None param = params_dict[fused_name] expert_data = param.data[expert_id] dim1 = sliced_weight.shape[0] expert_data.data[:dim1].copy_(sliced_weight) loaded_params.add(fused_name) continue elif "sinks" in name: # Handle attention sinks (distributed across ranks) param = params_dict[name] narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank) param.data.copy_(narrow_weight) loaded_params.add(name) continue for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue # We have mlp.experts[0].gate_proj in the checkpoint. # Since we handle the experts below in expert_params_mapping, # we need to skip here BEFORE we update the name, otherwise # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) if name.endswith("scale"): # Remapping the name of FP8 kv-scale. 
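                    # maybe_remap_kv_scale_name maps checkpoint-style kv-cache
                    # scale names onto the attention layer's parameter names
                    # and returns None when there is no destination on this
                    # rank, in which case the weight is skipped.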
name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break else: for mapping in expert_params_mapping: # Anyway, this is an expert weight and should not be # attempted to load as other weights later param_name, weight_name, mapping_expert_id, shard_id = mapping weight_name = ( weight_name[:-1] if weight_name.endswith(".") else weight_name ) if weight_name not in name: continue param = params_dict[fused_name] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. weight_loader = typing.cast( Callable[..., bool], param.weight_loader ) # Use checkpoint's expert_id for quark format (when expert_id # is extracted from weight name), otherwise use mapping's expert_id actual_expert_id = ( expert_id if expert_id is not None else mapping_expert_id ) success = weight_loader( param, loaded_weight, fused_name, shard_id=shard_id, expert_id=actual_expert_id, return_success=True, ) if success: name = fused_name loaded_params.add(name) break else: if name not in params_dict: continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params def _load_weights_other( self, ep_rank_end: int, ep_rank_start: int, heads_per_rank: int, head_start: int, weights: Iterable[tuple[str, torch.Tensor]], stacked_params_mapping: list[tuple[str, ...]], ) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() use_ep = self.parallel_config.enable_expert_parallel # In MoE, we need to flatten the tensor parallel size across the data # parallel size when EP is disabled. tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp( tp_size=get_tensor_model_parallel_world_size(), dp_size=get_dp_group().world_size, dp_rank=get_dp_group().rank_in_group, pcp_size=get_pcp_group().world_size, pcp_rank=get_pcp_group().rank_in_group, ) intermediate_size = self.config.intermediate_size per_rank_intermediate_size = cdiv(intermediate_size, tp_size) # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) for name, weight in weights: # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue if ".w13_weight" in name: # Handle MLP gate and up projection weights # Extract gate and up projection parts if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end] narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() param = params_dict[name] param.copy_(narrow_weight) loaded_params.add(name) continue elif ".w2_weight" in name: # Handle MLP down projection weights if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] else: narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() param = params_dict[name] param.copy_(narrow_weight) loaded_params.add(name) continue elif ".w13_bias" in name: # Handle MLP gate and up projection biases # Extract gate and up projection bias parts if use_ep: narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
else: narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] param = params_dict[name] param.copy_(narrow_weight) loaded_params.add(name) continue elif ".w2_bias" in name: # Handle MLP down projection bias if use_ep: weight = weight[ep_rank_start:ep_rank_end, ...] else: # (only load on rank 0 to avoid duplication) if tp_rank != 0: weight.zero_() param = params_dict[name] param.copy_(weight) loaded_params.add(name) continue elif "sinks" in name: # Handle attention sinks (distributed across ranks) param = params_dict[name] narrow_weight = weight.narrow(0, head_start, heads_per_rank) param.data.copy_(narrow_weight) loaded_params.add(name) continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, weight) else: weight_loader(param, weight, shard_id) break else: # Handle all other weights with potential renaming if name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) loaded_params.add(name) return loaded_params def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".v_proj", "v"), ] tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() # Attention heads per rank heads_per_rank = self.config.num_attention_heads // tp_size head_start = tp_rank * heads_per_rank ep_size = get_ep_group().world_size ep_rank = get_ep_group().rank num_experts = self.config.num_local_experts experts_per_rank = num_experts // ep_size ep_rank_start = ep_rank * experts_per_rank ep_rank_end = (ep_rank + 1) * experts_per_rank quant_method = ( self.config.quantization_config["quant_method"] if hasattr(self.config, "quantization_config") else None ) if quant_method == "mxfp4": return self._load_weights_mxfp4( ep_rank_end, ep_rank_start, heads_per_rank, head_start, weights, stacked_params_mapping, ) elif quant_method == "quark": return self._load_weights_quark( ep_rank_end, ep_rank_start, heads_per_rank, head_start, weights, stacked_params_mapping, ) else: return self._load_weights_other( ep_rank_end, ep_rank_start, heads_per_rank, head_start, weights, stacked_params_mapping, ) class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): is_3d_moe_weight: bool = True packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ ".self_attn.": ".attn.", }, orig_to_new_suffix={ ".embed_tokens.weight": ".embedding.weight", # MoE MXFP4 weights ".gate_up_proj_blocks": ".w13_weight", ".down_proj_blocks": ".w2_weight", ".gate_up_proj_scales": ".w13_weight_scale", ".down_proj_scales": ".w2_weight_scale", # MoE other weights ".gate_up_proj": ".w13_weight", ".down_proj": ".w2_weight", # MoE Bias ".gate_up_proj_bias": ".w13_bias", ".down_proj_bias": ".w2_bias", # For quark format ".gate_up_proj.weight": ".w13_weight", ".gate_up_proj.weight_scale": ".w13_weight_scale", ".gate_up_proj.bias": ".w13_bias", ".gate_up_proj.input_scale": ".w13_input_scale", ".down_proj.weight": ".w2_weight", ".down_proj.weight_scale": ".w2_weight_scale", ".down_proj.bias": ".w2_bias", ".down_proj.input_scale": 
".w2_input_scale", }, ) def __init__( self, vllm_config: VllmConfig, prefix: str = "", ): super().__init__() self.vllm_config = vllm_config self.config = vllm_config.model_config.hf_config self.model = GptOssModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"), ) self.lm_head = ParallelLMHead( self.config.vocab_size, self.config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(self.config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: self.model.aux_hidden_state_layers = layers def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: num_layers = len(self.model.layers) return (2, num_layers // 2, num_layers - 3) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor: return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)