From 16ff3d4b05766a44ee821c55bcb66cd0591f4569 Mon Sep 17 00:00:00 2001 From: wenhuipeng <75769315+wenhuipeng@users.noreply.github.com> Date: Tue, 9 Sep 2025 12:45:00 +0800 Subject: [PATCH] Support opt model (#10165) --- python/sglang/srt/models/opt.py | 637 ++++++++++++++++++++++ test/srt/models/test_generation_models.py | 1 + 2 files changed, 638 insertions(+) create mode 100644 python/sglang/srt/models/opt.py diff --git a/python/sglang/srt/models/opt.py b/python/sglang/srt/models/opt.py new file mode 100644 index 000000000..a571e8937 --- /dev/null +++ b/python/sglang/srt/models/opt.py @@ -0,0 +1,637 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Inference-only OPT model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import OPTConfig + +from sglang.srt.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers + + +def get_activation(name="relu"): + """Select an activation function by name + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". + """ + name = name.lower() + if name == "relu": + return nn.ReLU() + if name == "gelu": + return nn.GELU() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the + # embedding ids by 2 and adjust num_embeddings appropriately. 
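+        # (For example, position 0 is looked up at embedding row 2, and the
+        # table is allocated with num_embeddings + 2 rows.)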
Other + # models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, positions: torch.Tensor): + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + layer_id: int = 0, + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_dim = embed_dim + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() + total_num_heads = num_heads + assert num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = embed_dim // total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + embed_dim, + self.head_dim, + total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_heads, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.out_proj(attn_output) + return output + + +class OPTDecoderLayer(nn.Module): + + def __init__( + self, + config: OPTConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + layer_id=layer_id, + bias=config.enable_bias, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.do_layer_norm_before = config.do_layer_norm_before + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + self.fc1 = ColumnParallelLinear( + self.embed_dim, + config.ffn_dim, + bias=config.enable_bias, + quant_config=quant_config, + prefix=add_prefix("fc1", prefix), + ) + self.activation_fn = get_activation(config.activation_function) + self.fc2 = RowParallelLinear( + config.ffn_dim, + self.embed_dim, + bias=config.enable_bias, + quant_config=quant_config, + prefix=add_prefix("fc2", prefix), + ) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine + ) + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, forward_batch=forward_batch + ) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: 
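+            # Pre-LN checkpoints (do_layer_norm_before=True) normalize here,
+            # in front of the feed-forward block:
+            #     x = x + FFN(LN(x))
+            # OPT-350m (do_layer_norm_before=False) instead normalizes after it:
+            #     x = LN(x + FFN(x))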
+ hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class OPTDecoder(nn.Module): + + def __init__( + self, + config: OPTConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.pp_group = get_pp_group() + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.word_embed_proj_dim, + prefix=add_prefix("embed_tokens", prefix), + ) + # Positional embeddings are replicated (not sharded). + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size + ) + + # Project out & in will be replicated if they exist. + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = ReplicatedLinear( + config.hidden_size, + config.word_embed_proj_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("project_out", prefix), + ) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = ReplicatedLinear( + config.word_embed_proj_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("project_in", prefix), + ) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to + # keep backward compatibility with checkpoints that have been fine-tuned + # before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, + elementwise_affine=config.layer_norm_elementwise_affine, + ) + else: + self.final_layer_norm = None + + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: OPTDecoderLayer( + config=config, layer_id=idx, quant_config=quant_config, prefix=prefix + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix="model.layers", + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + input_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + input_embeds = self.embed_tokens(input_ids) + pos_embeds = self.embed_positions(positions) + if self.project_in is not None: + input_embeds, _ = self.project_in(input_embeds) + hidden_states = input_embeds + pos_embeds + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + + for layer in self.layers[self.start_layer : self.end_layer]: + hidden_states = layer( + hidden_states=hidden_states, forward_batch=forward_batch + ) + if not self.pp_group.is_last_rank: + return PPProxyTensors({"hidden_states": hidden_states}) + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + # 没有经过这里 + if self.project_out is not None: + hidden_states, _ = self.project_out(hidden_states) + return 
hidden_states + + +class OPTModel(nn.Module): + + def __init__( + self, + config: OPTConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + # config = vllm_config.model_config.hf_config + # quant_config = vllm_config.quant_config + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.pp_group = get_pp_group() + + self.decoder = OPTDecoder( + config=config, + quant_config=quant_config, + prefix=add_prefix("decoder", prefix), + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors], + input_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + return self.decoder( + input_ids, + positions, + pp_proxy_tensors=pp_proxy_tensors, + input_embeds=input_embeds, + forward_batch=forward_batch, + ) + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.decoder.layers[layer_idx], nn.Identity): + layer_self_attn = self.decoder.layers[layer_idx].self_attn + + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" + ) + + +class OPTForCausalLM(nn.Module): + # BitandBytes specific attributes + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + + def __init__( + self, + config: OPTConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.model = OPTModel( + config=config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + if self.config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.word_embed_proj_dim, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.capture_aux_hidden_states = False + self.pp_group = get_pp_group() + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + ] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + + if self.pp_group.is_last_rank: + if not get_embedding: + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states=aux_hidden_states, + ) + 
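+            # get_embedding=True requests pooled embeddings rather than logits;
+            # the pooler configured in __init__ does last-token pooling with
+            # normalization.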
else: + return self.pooler(hidden_states, forward_batch) + else: + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + for name, loaded_weight in weights: + if name.startswith("decoder"): + name = name.replace("decoder.", "model.decoder.") + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # if is_pp_missing_parameter(name, self): + # continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # if is_pp_missing_parameter(name, self): + # continue + if name not in params_dict: + continue + if name in params_dict.keys(): + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + else: + logger.warning(f"Parameter {name} not found in params_dict") + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + def get_module_name_from_weight_name(self, name): + for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def get_weights_by_name( + self, name: str, truncate_size: int = 100, tp_size: int = 1 + ) -> Optional[torch.Tensor]: + """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face. + + Only used for unit test with an unoptimized performance. + For optimized performance, please use torch.save and torch.load. + """ + try: + if name == "lm_head.weight" and self.config.tie_word_embeddings: + logger.info( + "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight." 
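+                    # With tie_word_embeddings, lm_head shares the decoder's
+                    # embedding table, so its rows are returned in place of a
+                    # separate lm_head weight.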
+ ) + return ( + self.model.embed_tokens.weight.cpu() + .to(torch.float32) + .numpy() + .tolist()[:truncate_size] + ) + + mapped_name = name + mapped_shard_id = None + for param_name, weight_name, shard_id in self.stacked_params_mapping: + if weight_name in name: + mapped_name = name.replace(weight_name, param_name) + mapped_shard_id = shard_id + break + params_dict = dict(self.named_parameters()) + param = params_dict[mapped_name] + if mapped_shard_id is not None: + if mapped_shard_id in ["q", "k", "v"]: + num_heads = self.config.num_attention_heads // tp_size + num_kv_heads = self.config.num_attention_heads // tp_size + head_dim = ( + self.config.hidden_size // self.config.num_attention_heads + ) + if mapped_shard_id == "q": + offset = 0 + size = num_heads * head_dim + elif mapped_shard_id == "k": + offset = num_heads * head_dim + size = num_kv_heads * head_dim + elif mapped_shard_id == "v": + offset = (num_heads + num_kv_heads) * head_dim + size = num_kv_heads * head_dim + weight = param.data.narrow(0, offset, size) + elif mapped_shard_id in [0, 1]: + intermediate_size = self.config.ffn_dim + slice_size = intermediate_size // tp_size + if mapped_shard_id == 0: # gate_proj + offset = 0 + size = slice_size + elif mapped_shard_id == 1: # up_proj + offset = slice_size + size = slice_size + + weight = param.data.narrow(0, offset, size) + else: + weight = param.data + else: + weight = param.data + if tp_size > 1 and ("o_proj" in name or "down_proj" in name): + gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)] + torch.distributed.all_gather(gathered_weights, weight) + weight = torch.cat(gathered_weights, dim=1) + return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] + + except Exception: + logger.error( + f"Error getting weights by name {name} in OPTForCausalLM: {get_exception_traceback()}" + ) + return None + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def get_embed(self): + return self.model.embed_tokens.weight + + def set_embed(self, embed): + # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3 + if ( + hasattr(self.config, "target_hidden_size") + and self.config.target_hidden_size != self.config.hidden_size + ): + return + del self.model.embed_tokens.weight + self.model.embed_tokens.weight = embed + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + + +EntryClass = [OPTForCausalLM] diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 6d79d35aa..a9d8fe0df 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -77,6 +77,7 @@ ALL_MODELS = [ trust_remote_code=True, skip_long_prompt=True, ), + ModelCase("facebook/opt-125m", skip_long_prompt=True), ModelCase( "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5", tp_size=2,