From 6e38461af66cdb48563d1b29d0c2a1050825b401 Mon Sep 17 00:00:00 2001 From: Chranos <826995883@qq.com> Date: Thu, 5 Feb 2026 17:11:09 +0800 Subject: [PATCH] testing dynamic register --- .../models/transformers/__init__.py | 30 +- .../models/transformers/causal.py | 558 +++++++++++++++--- .../models/transformers/utils.py | 167 ++++++ 3 files changed, 657 insertions(+), 98 deletions(-) create mode 100644 vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py index cf98e27..7b92369 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py @@ -2,13 +2,39 @@ # Copyright 2024 The vLLM team. """Wrapper around `transformers` models for vLLM v0.6.2. -This module provides a simplified Transformers modeling backend that wraps +This module provides an advanced Transformers modeling backend that wraps any HuggingFace model with the vLLM interface, enabling support for custom models that define their implementation via `auto_map` in config.json. + +Key optimizations and features: +- Meta device initialization for memory efficiency +- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions +- VocabParallelEmbedding for input embeddings +- vLLM Attention instances for proper KV cache allocation +- AutoWeightsLoader for efficient weight loading with name mapping +- vLLM attention backend integration (when supported by upgraded transformers) """ -from vllm.model_executor.models.transformers.causal import TransformersForCausalLM +from vllm.model_executor.models.transformers.causal import ( + TransformersForCausalLM, + is_backend_compatible, +) +from vllm.model_executor.models.transformers.utils import ( + init_on_device_without_buffers, + replace_linear_class, + replace_rms_norm_class, + log_replacement, + maybe_prefix, +) __all__ = [ + # Main wrapper classes "TransformersForCausalLM", + "is_backend_compatible", + # Utility functions + "init_on_device_without_buffers", + "replace_linear_class", + "replace_rms_norm_class", + "log_replacement", + "maybe_prefix", ] diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py index 8104562..4f2b299 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py @@ -6,11 +6,15 @@ This module provides a wrapper class that enables vLLM to use any HuggingFace causal language model, including custom models that define their implementation via `auto_map` in config.json. -The key insight is that we use HuggingFace's AutoModelForCausalLM to load the -actual model, then wrap it with the vLLM interface (compute_logits, sample, etc). +Key optimizations: +1. Use meta device for delayed memory allocation +2. Replace nn.Linear with vLLM's optimized Linear classes +3. Replace RMSNorm with vLLM's fused RMSNorm +4. Replace input embeddings with VocabParallelEmbedding +5. Use vLLM's weight loading infrastructure (AutoWeightsLoader) """ -from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn @@ -19,8 +23,22 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from vllm.attention.layer import Attention from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, +) +from vllm.model_executor.models.transformers.utils import ( + init_on_device_without_buffers, + replace_linear_class, + replace_rms_norm_class, + log_replacement, + maybe_prefix, +) from vllm.sequence import IntermediateTensors if TYPE_CHECKING: @@ -30,25 +48,153 @@ if TYPE_CHECKING: logger = init_logger(__name__) +# Note: In v0.6.2, the vLLM Attention.forward requires (query, key, value, kv_cache, attn_metadata). +# The transformers backend integration works differently than in latest vLLM. +# We keep the vllm_flash_attention_forward for reference, but it may not be compatible +# with all transformers versions or MLU backends. + +# Global variable to store current attention metadata (set during forward pass) +_current_attn_metadata = None +_current_kv_caches = None + + +def set_attention_context(attn_metadata, kv_caches): + """Set the current attention context for vLLM attention functions.""" + global _current_attn_metadata, _current_kv_caches + _current_attn_metadata = attn_metadata + _current_kv_caches = kv_caches + + +def clear_attention_context(): + """Clear the current attention context.""" + global _current_attn_metadata, _current_kv_caches + _current_attn_metadata = None + _current_kv_caches = None + + +def vllm_flash_attention_forward( + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: float = None, + # vLLM kwargs + attention_instances: Dict[int, Attention] = None, + **kwargs, +): + """ + vLLM's optimized attention function that replaces HuggingFace's attention. + This function is registered to transformers' ALL_ATTENTION_FUNCTIONS. + + Note: In v0.6.2, this function may have limited functionality due to + API differences in the Attention layer. For full functionality, the model + should fall back to HuggingFace's native attention when vLLM attention + is not properly configured. + """ + # Get the attention instance for this layer + layer_idx = getattr(module, 'layer_idx', 0) + + if attention_instances is None or layer_idx not in attention_instances: + # Fall back to standard attention computation + logger.debug("No attention instance for layer %d, using standard attention", layer_idx) + # Standard scaled dot-product attention + attn_weights = torch.matmul(query, key.transpose(-2, -1)) + if scaling is not None: + attn_weights = attn_weights * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value) + return attn_output, None + + self_attn = attention_instances[layer_idx] + + # v0.6.2 Attention.forward requires: (query, key, value, kv_cache, attn_metadata) + # We need to get these from the global context + global _current_attn_metadata, _current_kv_caches + + if _current_attn_metadata is None or _current_kv_caches is None: + # No context set, fall back to standard attention + logger.debug("No attention context, using standard attention for layer %d", layer_idx) + attn_weights = torch.matmul(query, key.transpose(-2, -1)) + if scaling is not None: + attn_weights = attn_weights * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value) + return attn_output, None + + # Update scale if provided + if scaling is not None: + self_attn.impl.scale = float(scaling) + + # Reshape tensors for vLLM: [batch, heads, seq, head_dim] -> [seq, heads * head_dim] + hidden = query.shape[-2] + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) + + # Get KV cache for this layer + kv_cache = _current_kv_caches[layer_idx] if layer_idx < len(_current_kv_caches) else None + + # Call vLLM attention + output = self_attn.forward(query, key, value, kv_cache, _current_attn_metadata) + + return output, None + + +# Try to register vLLM attention to transformers +_vllm_attention_registered = False +try: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward + _vllm_attention_registered = True + logger.info("Registered vLLM attention function to transformers") +except (ImportError, AttributeError) as e: + logger.warning("Could not register vLLM attention function - " + "transformers version may not support custom attention: %s", e) + + class TransformersForCausalLM(nn.Module): """ A wrapper class that adapts any HuggingFace causal language model - to the vLLM interface. + to the vLLM interface with memory optimizations. - This class provides: - 1. forward() - processes input through the model - 2. compute_logits() - computes output logits - 3. sample() - samples tokens from logits - 4. load_weights() - loads model weights - - The actual HuggingFace model is loaded using AutoModelForCausalLM and - stored in self.model. + Key optimizations (following latest vLLM): + 1. Meta device initialization - no GPU memory until weights are loaded + 2. Module replacement - Linear/RMSNorm replaced with vLLM optimized versions + 3. VocabParallelEmbedding for input embeddings + 4. AutoWeightsLoader for efficient weight loading with name mapping Interface compliance: - Implements VllmModel protocol (vllm_config init, forward with required args) - Implements VllmModelForTextGeneration protocol (compute_logits, sample) """ + # Weight name mapping from HuggingFace to vLLM format + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Add `model.` prefix for base model checkpoints + "": "model.", + "model.model.": "model.", + # Heads will be adjacent to `model` + "model.lm_head.": "lm_head.", + "lm_head.": "lm_head.", + # Handle different model architectures + "transformer.": "model.", + "model.transformer.": "model.", + # Handle embeddings + "embed_tokens.": "model.embed_tokens.", + "model.embed_tokens.": "model.embed_tokens.", + # Handle attention weights + "self_attn.": "self_attn.", + "attention.": "self_attn.", + } + ) + def __init__( self, vllm_config: VllmConfig, @@ -61,42 +207,250 @@ class TransformersForCausalLM(nn.Module): quant_config = vllm_config.quant_config self.config = config.hf_config + self.text_config = getattr(self.config, "text_config", self.config) self.model_config = config self.cache_config = cache_config + self.device_config = vllm_config.device_config self.quant_config = quant_config self.prefix = prefix + # Get model dimensions from config + self.hidden_size = getattr(self.text_config, "hidden_size", 4096) + self.vocab_size = getattr(self.text_config, "vocab_size", 32000) + + # Weight loading configuration + self.skip_prefixes: List[str] = [] + self.ignore_unexpected_prefixes: List[str] = [ + "model.layers.*.self_attn.rotary_emb.inv_freq", # Skip RoPE weights + "model.norm.bias", # Some models don't have bias in final norm + ] + logger.info("Using Transformers modeling backend for %s", config.hf_config.architectures) - # Load the actual HuggingFace model - self._load_hf_model() + # Configure vLLM attention backend + self._configure_attention_backend() + + # Load the HuggingFace model structure on meta device (no memory allocation) + self._load_hf_model_on_meta() + + # Replace modules with vLLM optimized versions + self._replace_modules() + + # Replace input embeddings with VocabParallelEmbedding + self._replace_input_embeddings() + + # Create attention instances for KV cache + self.attention_instances = self._create_attention_instances() + + # Initialize parameters (allocate memory on target device) + self._init_parameters() # Setup logits processor and sampler self.logits_processor = LogitsProcessor( - self.config.vocab_size, + self.vocab_size, logits_as_input=False, ) self.sampler = Sampler() - def _load_hf_model(self) -> None: - """Load the HuggingFace model using AutoModelForCausalLM.""" + def _load_hf_model_on_meta(self) -> None: + """ + Load the HuggingFace model structure on meta device. + + This creates the model structure without allocating GPU memory. + Memory will be allocated later during weight loading. + """ from transformers import AutoModelForCausalLM - # We load with minimal config first - weights will be loaded separately - # by vLLM's weight loader - logger.info("Loading HuggingFace model from config...") + logger.info("Creating model structure on meta device...") - self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config( - self.config, - torch_dtype=self.model_config.dtype, - trust_remote_code=self.model_config.trust_remote_code, - ) + # Create model on meta device - no GPU memory allocated + with init_on_device_without_buffers("meta"): + self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config( + self.config, + torch_dtype=self.model_config.dtype, + trust_remote_code=self.model_config.trust_remote_code, + ) # Disable gradient computation for inference self.model.eval() for param in self.model.parameters(): param.requires_grad = False + + logger.info("Model structure created on meta device") + + def _replace_modules(self) -> None: + """ + Replace HuggingFace modules with vLLM optimized versions. + + This replaces: + - nn.Linear with ReplicatedLinear (memory efficient, supports quantization) + - RMSNorm variants with vLLM's fused RMSNorm + """ + logger.info("Replacing modules with vLLM optimized versions...") + replaced_count = 0 + + def _recursive_replace(module: nn.Module, prefix: str = ""): + nonlocal replaced_count + for name, child in list(module.named_children()): + qual_name = maybe_prefix(prefix, name) + + if isinstance(child, nn.Linear): + # Replace Linear with vLLM's ReplicatedLinear + new_module = replace_linear_class( + child, + style="replicate", + quant_config=self.quant_config, + prefix=qual_name, + ) + setattr(module, name, new_module) + log_replacement(qual_name, child, new_module) + replaced_count += 1 + + elif child.__class__.__name__.endswith("RMSNorm"): + # Replace RMSNorm with vLLM's optimized version + new_module = replace_rms_norm_class(child, self.hidden_size) + setattr(module, name, new_module) + log_replacement(qual_name, child, new_module) + replaced_count += 1 + + elif child.__class__.__name__.endswith(("LayerNorm", "GroupNorm")): + # Also handle other normalization layers + logger.debug("Found normalization layer %s: %s", qual_name, type(child).__name__) + # Could add specialized replacement here if needed + _recursive_replace(child, qual_name) + + elif "Attention" in child.__class__.__name__: + # Mark attention layers for potential replacement + logger.debug("Found attention layer %s: %s", qual_name, type(child).__name__) + # Note: We don't replace the attention module itself, + # but we create separate vLLM Attention instances + _recursive_replace(child, qual_name) + + else: + # Recursively process children + _recursive_replace(child, qual_name) + + _recursive_replace(self.model, "model") + logger.info("Replaced %d modules with vLLM optimized versions", replaced_count) + + def _replace_input_embeddings(self) -> None: + """ + Replace the input embeddings with VocabParallelEmbedding. + + This provides memory efficiency for large vocabularies. + """ + input_embeddings = self.model.get_input_embeddings() + if input_embeddings is None: + logger.warning("Could not find input embeddings to replace") + return + + # Get embedding dimension + if hasattr(input_embeddings, "embedding_dim"): + embedding_dim = input_embeddings.embedding_dim + elif hasattr(input_embeddings, "weight"): + embedding_dim = input_embeddings.weight.shape[1] + else: + embedding_dim = self.hidden_size + + # Store embed_scale if present (some models scale embeddings) + self.embed_scale = getattr(input_embeddings, "embed_scale", None) + + logger.info("Replacing input embeddings with VocabParallelEmbedding " + "(vocab_size=%d, embedding_dim=%d)", self.vocab_size, embedding_dim) + + new_embeddings = VocabParallelEmbedding( + self.vocab_size, + embedding_dim, + org_num_embeddings=self.vocab_size, + quant_config=self.quant_config, + ) + self.model.set_input_embeddings(new_embeddings) + + def _init_parameters(self) -> None: + """ + Initialize parameters from meta device to target device. + + This allocates the actual GPU memory for all parameters. + """ + logger.info("Initializing parameters on target device...") + + # Use device_config to get the correct device (supports MLU, CUDA, etc.) + device = self.device_config.device + if device is None: + # Fallback for TPU or other special cases + device = torch.device("cpu") + dtype = self.model_config.dtype + + def _init_params(module: nn.Module): + for name, param in list(module.named_parameters(recurse=False)): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype, + device=device, + ), + requires_grad=False, + ) + setattr(module, name, new_param) + + for child in module.children(): + _init_params(child) + + _init_params(self.model) + logger.info("Parameters initialized on %s", device) + + def _create_attention_instances(self) -> Dict[int, Attention]: + """ + Create vLLM Attention instances for each layer. + + This enables proper KV cache allocation and vLLM's optimized attention. + Returns a dict mapping layer_idx to Attention instance. + """ + attention_instances: Dict[int, Attention] = {} + num_layers = getattr(self.text_config, "num_hidden_layers", + getattr(self.text_config, "num_layers", 32)) + + num_heads = getattr(self.text_config, "num_attention_heads", 32) + head_size = self.hidden_size // num_heads + scale = 1.0 / (head_size ** 0.5) + num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads) + + logger.info("Creating %d attention instances for KV cache " + "(num_heads=%d, head_size=%d, num_kv_heads=%d)", + num_layers, num_heads, head_size, num_kv_heads) + + for layer_idx in range(num_layers): + attention = Attention( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + cache_config=self.cache_config, + quant_config=self.quant_config, + prefix=f"model.layers.{layer_idx}.self_attn", + ) + attention_instances[layer_idx] = attention + + return attention_instances + + def _configure_attention_backend(self) -> None: + """ + Configure vLLM attention backend for the model. + + This sets up the attention implementation BEFORE model creation. + Only sets 'vllm' implementation if the attention function was registered. + """ + global _vllm_attention_registered + + if _vllm_attention_registered: + # Set vLLM attention implementation in config (must be before from_config) + self.text_config._attn_implementation = "vllm" + logger.info("Set attention implementation to 'vllm' in text_config") + else: + # Use default eager attention if vLLM attention is not available + logger.info("Using default HuggingFace attention (vLLM attention not registered)") def forward( self, @@ -114,49 +468,66 @@ class TransformersForCausalLM(nn.Module): This method conforms to the VllmModel protocol by accepting: - input_ids: Token IDs - positions: Position IDs - - kv_caches: KV cache tensors (not used in basic HF forward) - - attn_metadata: Attention metadata (not used in basic HF forward) - - Note: This is a simplified implementation that does not use vLLM's - optimized attention mechanisms. For production use with KV caching, - a more sophisticated implementation would be needed. + - kv_caches: KV cache tensors + - attn_metadata: Attention metadata """ - # For simplicity, we use HuggingFace's native forward - # This won't have vLLM's optimizations but will work + # Set attention context for vLLM attention function + set_attention_context(attn_metadata, kv_caches) - if inputs_embeds is not None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.unsqueeze(0) if input_ids.dim() == 1 else input_ids} - - # Position IDs - if positions is not None: - model_inputs["position_ids"] = positions.unsqueeze(0) if positions.dim() == 1 else positions - - # Run the model - with torch.no_grad(): - outputs = self.model( - **model_inputs, - use_cache=False, - return_dict=True, - ) - - # Get hidden states from the last layer - # For CausalLM, we typically want the hidden states before the LM head - if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None: - hidden_states = outputs.hidden_states[-1] - else: - # Fall back to running without output_hidden_states - # and getting logits directly - hidden_states = outputs.logits - if hidden_states.dim() == 3: + try: + # Prepare inputs - add batch dimension if needed + if inputs_embeds is not None: + if inputs_embeds.dim() == 2: + inputs_embeds = inputs_embeds.unsqueeze(0) + model_inputs = {"inputs_embeds": inputs_embeds} + else: + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + model_inputs = {"input_ids": input_ids} + + # Position IDs - add batch dimension + if positions is not None: + if positions.dim() == 1: + positions = positions.unsqueeze(0) + model_inputs["position_ids"] = positions + + # Apply embed_scale if needed + if ( + self.embed_scale is not None + and "input_ids" in model_inputs + and "inputs_embeds" not in model_inputs + ): + inputs_embeds = self.model.get_input_embeddings()(model_inputs["input_ids"]) + inputs_embeds = inputs_embeds * self.embed_scale + model_inputs = {"inputs_embeds": inputs_embeds} + if positions is not None: + model_inputs["position_ids"] = positions + + # Run the model with vLLM attention instances + with torch.no_grad(): + outputs = self.model( + **model_inputs, + use_cache=False, + return_dict=True, + output_hidden_states=True, + attention_instances=self.attention_instances, + ) + + # Get hidden states from the last layer + if outputs.hidden_states is not None: + hidden_states = outputs.hidden_states[-1] + else: + # Fallback: use logits directly + hidden_states = outputs.logits + + # Remove batch dimension + if hidden_states.dim() == 3 and hidden_states.size(0) == 1: hidden_states = hidden_states.squeeze(0) + return hidden_states - - if hidden_states.dim() == 3: - hidden_states = hidden_states.squeeze(0) - - return hidden_states + finally: + # Clear attention context + clear_attention_context() def compute_logits( self, @@ -168,12 +539,24 @@ class TransformersForCausalLM(nn.Module): This method conforms to the VllmModelForTextGeneration protocol. """ - # If hidden_states are already logits (from forward), process them - if hidden_states.shape[-1] == self.config.vocab_size: + # Check if hidden_states are already logits + if hidden_states.shape[-1] == self.vocab_size: logits = hidden_states else: # Apply the LM head - logits = self.model.lm_head(hidden_states) + lm_head = getattr(self.model, "lm_head", None) + if lm_head is None: + lm_head = getattr(self.model, "embed_out", None) + + if lm_head is not None: + output = lm_head(hidden_states) + # Handle tuple output from vLLM Linear layers + if isinstance(output, tuple): + logits = output[0] + else: + logits = output + else: + logits = hidden_states return self.logits_processor(None, logits, sampling_metadata) @@ -195,40 +578,23 @@ class TransformersForCausalLM(nn.Module): weights: Iterable[Tuple[str, torch.Tensor]], ) -> Set[str]: """ - Load weights into the model. + Load weights into the model using AutoWeightsLoader. - This method loads weights from an iterable of (name, tensor) pairs - into the HuggingFace model. + This uses vLLM's efficient weight loading infrastructure with + automatic name mapping. """ - loaded_params: Set[str] = set() - model_params = dict(self.model.named_parameters()) - - for name, loaded_weight in weights: - # Try to find the parameter in the model - if name in model_params: - param = model_params[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - else: - # Try common prefixes - for prefix in ["model.", ""]: - full_name = f"{prefix}{name}" if prefix else name - if full_name in model_params: - param = model_params[full_name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - break - - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=self.skip_prefixes, + ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, + ) + loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + logger.info("Loaded %d weight tensors", len(loaded)) + return set(loaded) def is_backend_compatible() -> bool: """ Check if the current model is compatible with the Transformers backend. - - This is a simplified check - in practice, compatibility depends on - whether the model follows standard HuggingFace conventions. """ return True diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py new file mode 100644 index 0000000..54dd58f --- /dev/null +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The vLLM team. +"""Transformers modeling backend utilities for v0.6.2. + +This module provides utility functions for the Transformers backend, +including context managers for meta device initialization and +module replacement functions. +""" + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Literal, Optional, Union + +import torch +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + +logger = init_logger(__name__) + + +@contextmanager +def init_on_device_without_buffers(device: Union[str, torch.device]): + """ + A context manager under which models are initialized with all + parameters on the specified device. However buffers are not + initialized on specified device. + + This is useful for creating model structure without allocating + GPU memory, which is essential for memory efficiency. + + Args: + device: Device to initialize all parameters on (e.g., "meta"). + + Example: + with init_on_device_without_buffers("meta"): + model = AutoModel.from_config(config) + # Now model is on meta device, no GPU memory allocated + """ + if isinstance(device, str): + device = torch.device(device) + + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs + ) + + try: + nn.Module.register_parameter = register_empty_parameter + yield + finally: + nn.Module.register_parameter = old_register_parameter + + +# Linear replacement styles +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] + + +def replace_linear_class( + linear: nn.Linear, + style: Style = "replicate", + quant_config: Optional["QuantizationConfig"] = None, + prefix: str = "", +) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: + """ + Replace nn.Linear with one of vLLM's tensor parallel linear classes. + + This replacement provides: + - Memory efficiency through proper tensor allocation + - Support for quantization + - Tensor parallel support (when using ColumnParallel/RowParallel) + + Args: + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear: + - "colwise": Column parallel (split output dim) + - "colwise_rep": Column parallel with gather output + - "rowwise": Row parallel (split input dim) + - "rowwise_rep": Row parallel without parallel input + - "replicate": Replicated (no parallelism) + quant_config: Quantization config for the new linear. + prefix: The name of the layer for weight loading. + + Returns: + The new vLLM linear layer. + """ + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + vllm_linear_cls, vllm_linear_kwargs = { + "colwise": (ColumnParallelLinear, {}), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), + "rowwise": (RowParallelLinear, {}), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), + "replicate": (ReplicatedLinear, {}), + }.get(style, (ReplicatedLinear, {})) + + return vllm_linear_cls( + input_size=linear.in_features, + output_size=linear.out_features, + bias=linear.bias is not None, + quant_config=quant_config, + prefix=prefix, + **vllm_linear_kwargs, + ) + + +def replace_rms_norm_class( + rms_norm: nn.Module, + hidden_size: int, +) -> RMSNorm: + """ + Replace a Transformers RMSNorm with vLLM's optimized RMSNorm. + + vLLM's RMSNorm provides: + - Fused CUDA kernels for better performance + - Support for fused add + norm operations + + Args: + rms_norm: The RMSNorm module to replace. + hidden_size: The hidden size of the model. + + Returns: + The new vLLM RMSNorm layer. + """ + # Try to get epsilon from various attribute names + eps = getattr(rms_norm, "eps", None) + if eps is None: + eps = getattr(rms_norm, "variance_epsilon", None) + if eps is None: + eps = 1e-6 + + # Check if weight exists and get its size + weight = getattr(rms_norm, "weight", None) + if weight is not None: + hidden_size = weight.size(0) + + return RMSNorm(hidden_size=hidden_size, eps=eps) + + +def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): + """Log module replacement for debugging.""" + logger.debug("Replaced %s: %s -> %s", name, type(old_module).__name__, type(new_module).__name__) + + +def maybe_prefix(prefix: str, name: str) -> str: + """Combine prefix and name with a dot separator.""" + if prefix: + return f"{prefix}.{name}" + return name