# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only NemotronH model."""

from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn

from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                   SupportsLoRA, SupportsPP,
                                                   SupportsQuant)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronHConfig


class NemotronHMLP(nn.Module):
    """Feed-forward mixer used by the ``"-"`` layers of the hybrid pattern.

    Computes ``down_proj(act(up_proj(x)))`` where ``act`` is
    ``ReLUSquaredActivation``; there is no gating projection.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the tensor-parallel up/down projections.

        Args:
            config: Model config; ``hidden_size``, ``intermediate_size`` and
                ``hybrid_override_pattern`` are read.
            layer_idx: Global index of this layer in
                ``config.hybrid_override_pattern``.
            quant_config: Optional quantization config for both linears.
            bias: Whether the projections carry a bias term.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        hybrid_override_pattern = config.hybrid_override_pattern
        # Ordinal of this layer among the MLP ("-") layers only; used to pick
        # a per-MLP width when ``intermediate_size`` is given as a list.
        mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
        if isinstance(config.intermediate_size, list):
            if len(config.intermediate_size) == 1:
                # A single-element list applies to every MLP layer.
                intermediate_size = config.intermediate_size[0]
            else:
                intermediate_size = config.intermediate_size[mlp_index]
        else:
            intermediate_size = config.intermediate_size

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=config.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = ReLUSquaredActivation()

    def forward(self, x: torch.Tensor):
        """Apply up projection, activation, then down projection to ``x``."""
        x, _ = self.up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


class NemotronHMLPDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a :class:`NemotronHMLP`."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norm and MLP mixer.

        ``model_config`` and ``cache_config`` are accepted so every layer type
        shares one constructor signature; they are not used here.
        """
        super().__init__()
        self.config = config
        self.mixer = NemotronHMLP(
            config,
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.mixer",
            layer_idx=layer_idx,
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize the input and run the MLP.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream, or ``None`` for the first
                layer processed, in which case ``hidden_states`` seeds it.
            **kwargs: Ignored (e.g. ``positions``).

        Returns:
            ``(mixer_output, residual)``.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            # Residual-aware RMSNorm returns the normalized input together
            # with the updated residual stream.
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual


class NemotronHMambaDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a :class:`MambaMixer2`."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norm and the Mamba2 mixer from the ``mamba_*`` and
        ``ssm_*`` fields of ``config``."""
        super().__init__()
        self.config = config
        self.mixer = MambaMixer2(
            hidden_size=config.hidden_size,
            ssm_state_size=config.ssm_state_size,
            conv_kernel_size=config.conv_kernel,
            intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.n_groups,
            num_heads=config.mamba_num_heads,
            head_dim=config.mamba_head_dim,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            model_config=model_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize the input and run the Mamba2 mixer.

        Args:
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream, or ``None`` to seed it from
                ``hidden_states``.
            **kwargs: Ignored (e.g. ``positions``).

        Returns:
            ``(mixer_output, residual)``.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        # The mixer writes its result into a caller-provided buffer.
        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output)
        return output, residual


class NemotronHAttention(nn.Module):
    """Tensor-parallel self-attention with a fused QKV projection.

    No rotary/positional embedding is applied to ``q``/``k`` here.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Compute per-rank head counts and build the projections.

        Raises:
            AssertionError: If the attention heads do not divide evenly
                across tensor-parallel ranks, or the KV heads can neither be
                partitioned nor replicated evenly.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        if hasattr(config, "head_dim") and config.head_dim is not None:
            self.head_dim = config.head_dim
        else:
            self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Project to QKV, attend, and project back to ``hidden_size``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class NemotronHAttentionDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose mixer is a :class:`NemotronHAttention`."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norm and attention mixer."""
        super().__init__()

        self.mixer = NemotronHAttention(
            config,
            layer_idx,
            model_config,
            cache_config,
            quant_config,
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Normalize the input and run attention.

        ``positions`` is accepted for interface uniformity but is not passed
        to the mixer.

        Returns:
            ``(mixer_output, residual)``.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states=hidden_states)
        return hidden_states, residual


# Maps each character of ``config.hybrid_override_pattern`` to the decoder
# layer class built at that position.
ALL_DECODER_LAYER_TYPES = {
    "M": NemotronHMambaDecoderLayer,
    "-": NemotronHMLPDecoderLayer,
    "*": NemotronHAttentionDecoderLayer,
}


@support_torch_compile
class NemotronHModel(nn.Module):
    """NemotronH backbone: embeddings, hybrid decoder stack, final norm.

    The layer stack has one layer per character of
    ``config.hybrid_override_pattern`` (see ``ALL_DECODER_LAYER_TYPES``).
    With pipeline parallelism only ``[start_layer, end_layer)`` is
    materialized on this rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: NemotronHConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Extra embedding rows reserved for LoRA-added tokens.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def get_layer(prefix: str):
            # ``prefix`` ends in ".<layer_idx>"; the pattern character at
            # that index selects the layer class.
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.hybrid_override_pattern[layer_idx]]
            return layer_class(
                config,
                layer_idx,
                model_config,
                cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            len(config.hybrid_override_pattern),
            get_layer,
            prefix=f"{prefix}.layers")
        # Name required by the SupportsPP interface.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))
        # Backward-compatible alias for the previous attribute name.
        self.make_empty_intmd_tensors = self.make_empty_intermediate_tensors

        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of the decoder stack.

        Args:
            input_ids: Token ids (used on the first PP rank when
                ``inputs_embeds`` is not given).
            positions: Token positions, forwarded to every layer.
            intermediate_tensors: ``hidden_states``/``residual`` produced by
                the previous PP rank; required on non-first ranks.
            inputs_embeds: Optional precomputed embeddings (first rank only).

        Returns:
            The final-normed hidden states on the last PP rank; otherwise an
            ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            # Keep the residual stream from the previous stage; resetting it
            # here would drop everything accumulated before this rank.
            residual = intermediate_tensors["residual"]

        # Only layers owned by this rank are real modules; the rest are
        # pipeline-parallel placeholders.
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written as
        shards of the fused ``qkv_proj`` parameter; every other tensor goes
        through the parameter's ``weight_loader`` (or
        ``default_weight_loader``).  Extra GPTQ biases with no matching
        parameter, and parameters of layers owned by another pipeline stage,
        are skipped.

        Args:
            weights: ``(name, tensor)`` pairs with names relative to this
                module.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # Rename at most once: the first matching shard wins.
            shard_id = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip parameters of layers owned by another pipeline stage.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # load stacked params
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                # load other params
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                           IsHybrid, SupportsQuant):
    """NemotronH backbone plus LM head and logits processor."""

    # Checkpoint-name rewrites applied before loading: the HF "backbone"
    # prefix becomes "model", "A_log" tensors load into "A", and
    # "embeddings" load into "embed_tokens".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"backbone": "model"},
        orig_to_new_substr={
            "A_log": "A",
            "embeddings": "embed_tokens"
        },
    )

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return the dtypes of the Mamba conv and SSM state caches."""
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.mamba_num_heads,
            head_dim=hf_config.mamba_head_dim,
            state_size=hf_config.ssm_state_size,
            conv_kernel=hf_config.conv_kernel,
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Raises:
            AssertionError: If prefix caching is enabled.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, \
            "NemotronH currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        self.config = config
        self.scheduler_config = scheduler_config
        self.model = NemotronHModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            # We need bigger padding if using lora for kernel compatibility.
            padding_size=(DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                          lora_config.lora_vocab_padding_size),
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

        # Name required by the SupportsPP interface; without it the
        # protocol's placeholder method would be used instead.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
        # Backward-compatible alias for the previous attribute name.
        self.make_empty_intmd_tensors = self.make_empty_intermediate_tensors

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        """Run the backbone; extra ``kwargs`` are accepted and ignored."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits with ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load HF checkpoint tensors, renaming via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)