diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md
index 4d2c6eecb..3647e56e0 100644
--- a/docs/supported_models/generative_models.md
+++ b/docs/supported_models/generative_models.md
@@ -51,3 +51,4 @@ in the GitHub search bar.
 | **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. |
 | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. |
 | **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
+| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) reasoning models built on strong open base models. The Super 49B variants are derived from Llama 3.3 70B Instruct via neural architecture search (individual blocks may replace attention or the FFN with a no-op) and post-trained on NVIDIA's open synthetic datasets, targeting accurate, efficient, right-sized AI agents. |
diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index bdb124e51..6aa7e39e1 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -341,6 +341,19 @@ class ModelConfig:
             "kv_n_heads",
             self.hf_config.num_attention_heads,
         )
+        if self.hf_config.model_type in ["nemotron-nas"]:
+            nkvh = {
+                self.hf_config.num_attention_heads // block.attention.n_heads_in_group
+                for block in self.hf_config.block_configs
+                if not block.attention.no_op
+            }
+            if len(nkvh) == 0:
+                raise RuntimeError("Couldn't determine number of kv heads")
+            if len(nkvh) > 1:
+                raise ValueError(
+                    "Variable GQA (VGQA) is not yet supported for nemotron-nas in sglang"
+                )
+            return next(iter(nkvh))
 
         attributes = [
             # For Falcon:
diff --git a/python/sglang/srt/models/nemotron_nas.py b/python/sglang/srt/models/nemotron_nas.py
new file mode 100644
index 000000000..bda70a2b1
--- /dev/null
+++ b/python/sglang/srt/models/nemotron_nas.py
@@ -0,0 +1,435 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_nas.py
+
+"""Inference-only DeciLM (Nemotron NAS) model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.llama import LlamaAttention, LlamaMLP
+from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import logger
+
+
+def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+    # DeciLM-specific code
+    intermediate_size = int(2 * ffn_mult * n_embd / 3)
+    return _find_multiple(intermediate_size, 256)
+
+
+def _find_multiple(n: int, k: int) -> int:
+    # DeciLM-specific code
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+class DeciLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        block_config = config.block_configs[layer_idx]
+        self._is_no_op_attention = block_config.attention.no_op
+        self._is_no_op_ffn = block_config.ffn.no_op
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, "qkv_bias"):
+            attention_bias = config.qkv_bias
+
+        if not self._is_no_op_attention:
+            num_kv_heads = (
+                config.num_attention_heads // block_config.attention.n_heads_in_group
+            )
+            self.self_attn = LlamaAttention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=num_kv_heads,
+                layer_id=layer_idx,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                rope_is_neox_style=rope_is_neox_style,
+                max_position_embeddings=max_position_embeddings,
+                quant_config=quant_config,
+                prefix=add_prefix("self_attn", prefix),
+                bias=attention_bias,
+            )
+            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if not self._is_no_op_ffn:
+            ffn_mult = block_config.ffn.ffn_mult
+            intermediate_size = _ffn_mult_to_intermediate_size(
+                ffn_mult, config.hidden_size
+            )
+            self.mlp =
LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + + if self._is_no_op_attention: + pass + else: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Fully Connected + if not self._is_no_op_ffn: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeciModel(nn.Module): + def __init__( + self, + *, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer, + ): + super().__init__() + + lora_config = None + self.config = config + self.quant_config = quant_config + self.padding_idx = config.pad_token_id + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + vocab_size = config.vocab_size + lora_vocab + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + def get_layer(idx: int, prefix: str): + return layer_type( + config, + layer_idx=idx, + quant_config=quant_config, + prefix=prefix, + ) + + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + get_layer, + pp_rank=get_pp_group().rank_in_group, + pp_size=get_pp_group().world_size, + prefix=add_prefix("layers", prefix), + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer(return_tuple=True) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + kv_cache_index = 0 + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + if not layer._is_no_op_attention: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + kv_cache_index += 1 + else: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + if not get_pp_group().is_last_rank: + return PPProxyTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = 
self.norm(hidden_states, residual) + return hidden_states + + +class DeciLMForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm", + } + + def __init__( + self, + *, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + lora_config = None + self.config = config + self.lora_config = lora_config + + self.model = self._init_model( + config=config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size + ), + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def _init_model( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + return DeciModel(config=config, quant_config=quant_config, prefix=prefix) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> LogitsProcessorOutput: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + inputs_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + if get_pp_group().is_last_rank: + if not get_embedding: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + else: + return self.pooler(hidden_states, forward_batch) + else: + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in 
name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if self.model.quant_config is not None and ( + scale_name := self.model.quant_config.get_cache_scale(name) + ): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = ( + loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0] + ) + weight_loader(param, loaded_weight) + continue + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name in params_dict.keys(): + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + else: + logger.warning(f"Parameter {name} not found in params_dict") + + +EntryClass = [DeciLMForCausalLM] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d15ef2a93..bcba0503b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -449,8 +449,10 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None: def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: - device = next(module.parameters()).device + if (params := next(module.parameters(), None)) is None: + return module + device = params.device if device == torch.device("cpu"): return module diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index ba1519951..248ba7285 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -231,11 +231,14 @@ class HFRunner: # Load the model and tokenizer if self.model_type == "generation": - config = AutoConfig.from_pretrained(model_path) - if model_archs := getattr(config, "architectures"): - model_cls = getattr(transformers, model_archs[0]) - else: + config = AutoConfig.from_pretrained( + model_path, trust_remote_code=self.trust_remote_code + ) + if self.trust_remote_code: model_cls = AutoModelForCausalLM + else: + model_arch = getattr(config, "architectures")[0] + model_cls = getattr(transformers, model_arch) self.base_model = model_cls.from_pretrained( model_path, torch_dtype=torch_dtype, diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index eb6763c67..fa55de947 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -77,6 +77,12 @@ ALL_MODELS = [ trust_remote_code=True, skip_long_prompt=True, ), + ModelCase( + "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5", + tp_size=2, + trust_remote_code=True, + skip_long_prompt=True, + ), ] TORCH_DTYPES = [torch.float16]
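A note on the `ModelConfig` change: nemotron-nas (DeciLM) checkpoints describe every decoder layer in `hf_config.block_configs`, and a layer's KV-head count is `num_attention_heads // n_heads_in_group`. The set comprehension collapses that per-layer information into a single value and errors out if the blocks disagree (variable GQA). A minimal sketch of the same logic outside sglang; the `BlockConfig`/`AttentionConfig` classes below are stand-ins for the attributes the HF config exposes, not real transformers classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttentionConfig:
    no_op: bool
    n_heads_in_group: Optional[int] = None  # None when the block has no attention


@dataclass
class BlockConfig:
    attention: AttentionConfig


# Illustrative values, not taken from a real checkpoint.
num_attention_heads = 64
block_configs = [
    BlockConfig(AttentionConfig(no_op=False, n_heads_in_group=8)),  # 64 // 8 = 8 KV heads
    BlockConfig(AttentionConfig(no_op=True)),                       # skipped: no-op attention
    BlockConfig(AttentionConfig(no_op=False, n_heads_in_group=8)),
]

# Same shape as the new branch in ModelConfig: every real attention block must
# agree on the KV-head count, otherwise it is variable GQA (not yet supported).
nkvh = {
    num_attention_heads // block.attention.n_heads_in_group
    for block in block_configs
    if not block.attention.no_op
}
assert len(nkvh) == 1, "variable GQA would be rejected here"
print(next(iter(nkvh)))  # -> 8
```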
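`_ffn_mult_to_intermediate_size` is the DeciLM-specific recovery of the MLP width: each block stores an `ffn_mult` rather than an explicit `intermediate_size`, and the width is `int(2 * ffn_mult * hidden_size / 3)` rounded up to the next multiple of 256. A worked example with illustrative numbers (not taken from a real block config):

```python
def _find_multiple(n: int, k: int) -> int:
    # Round n up to the next multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)


def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    intermediate_size = int(2 * ffn_mult * n_embd / 3)
    return _find_multiple(intermediate_size, 256)


# hidden_size = 8192, ffn_mult = 1.0:
#   int(2 * 1.0 * 8192 / 3) = 5461, rounded up to 5632 (22 * 256)
print(_ffn_mult_to_intermediate_size(1.0, 8192))    # -> 5632
# hidden_size = 8192, ffn_mult = 2.625:
#   int(2 * 2.625 * 8192 / 3) = 14336, already a multiple of 256
print(_ffn_mult_to_intermediate_size(2.625, 8192))  # -> 14336
```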
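The `maybe_offload_to_cpu` guard exists because `next(module.parameters())` raises `StopIteration` on a module with no parameters, and, presumably, a nemotron-nas decoder block whose attention and FFN are both no-ops is exactly such a module. A small standalone repro of the guarded pattern (the actual offloading body is elided here):

```python
import torch
from torch import nn


def maybe_offload_to_cpu(module: nn.Module) -> nn.Module:
    # A module without parameters (e.g. a fully no-op decoder block) has nothing
    # to offload; returning early also avoids the StopIteration that a bare
    # next(module.parameters()) would raise.
    if (param := next(module.parameters(), None)) is None:
        return module
    if param.device == torch.device("cpu"):
        return module
    # ... offloading logic elided in this sketch ...
    return module


class NoOpBlock(nn.Module):
    def forward(self, x):
        return x


maybe_offload_to_cpu(NoOpBlock())      # early return: no parameters at all
maybe_offload_to_cpu(nn.Linear(4, 4))  # early return: weights already on CPU
```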
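Finally, the new test case exercises `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` with `tp_size=2` and `trust_remote_code=True`; the checkpoint ships custom configuration/modeling code, which is also why `HFRunner` now falls back to `AutoModelForCausalLM` when `trust_remote_code` is set. For a quick manual check, something along these lines should work with the offline engine API, assuming it forwards these keyword arguments to the server args as usual and two sufficiently large GPUs are available:

```python
import sglang as sgl

# Assumed arguments: model_path / tp_size / trust_remote_code mirror the
# corresponding server arguments; adjust to your environment.
llm = sgl.Engine(
    model_path="nvidia/Llama-3_3-Nemotron-Super-49B-v1_5",
    tp_size=2,
    trust_remote_code=True,
)
out = llm.generate(
    ["Briefly explain what a no-op attention block is."],
    {"temperature": 0.0, "max_new_tokens": 64},
)
print(out[0]["text"])
llm.shutdown()
```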