# SPDX-License-Identifier: Apache-2.0 # Copyright 2024 The vLLM team. """Transformers modeling backend for causal language models. This module provides a wrapper class that enables vLLM to use any HuggingFace causal language model, including custom models that define their implementation via `auto_map` in config.json. Key optimizations: 1. Use meta device for delayed memory allocation 2. Replace nn.Linear with vLLM's optimized Linear classes 3. Replace RMSNorm with vLLM's fused RMSNorm 4. Replace input embeddings with VocabParallelEmbedding 5. Use vLLM's weight loading infrastructure (AutoWeightsLoader) """ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from vllm.attention.layer import Attention from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.models.utils import ( AutoWeightsLoader, WeightsMapper, ) from vllm.model_executor.models.transformers.utils import ( init_on_device_without_buffers, replace_linear_class, replace_rms_norm_class, log_replacement, maybe_prefix, ) from vllm.sequence import IntermediateTensors if TYPE_CHECKING: from transformers import PreTrainedModel from vllm.attention import AttentionMetadata logger = init_logger(__name__) # Note: In v0.6.2, the vLLM Attention.forward requires (query, key, value, kv_cache, attn_metadata). # The transformers backend integration works differently than in latest vLLM. # We keep the vllm_flash_attention_forward for reference, but it may not be compatible # with all transformers versions or MLU backends. 
# Attention context for the forward pass currently in flight.
#
# vLLM v0.6.2's ``Attention.forward`` needs ``kv_cache`` and ``attn_metadata``
# arguments, but transformers calls the registered attention function without
# them. ``TransformersForCausalLM.forward`` publishes both here for the duration
# of one model call and clears them again in a ``finally`` block.
# NOTE(review): this is process-global state shared by every model instance.
_current_attn_metadata = None
_current_kv_caches = None


def set_attention_context(attn_metadata, kv_caches):
    """Publish the attention metadata and KV caches for the current forward pass.

    Args:
        attn_metadata: vLLM attention metadata for this step.
        kv_caches: Sequence of per-layer KV cache tensors, indexed by layer index.
    """
    global _current_attn_metadata, _current_kv_caches
    _current_attn_metadata = attn_metadata
    _current_kv_caches = kv_caches


def clear_attention_context():
    """Reset the context installed by ``set_attention_context`` to ``None``."""
    global _current_attn_metadata, _current_kv_caches
    _current_attn_metadata = None
    _current_kv_caches = None


def _eager_attention_fallback(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float],
) -> Tuple[torch.Tensor, None]:
    """Plain PyTorch attention used when vLLM attention cannot run.

    Mirrors transformers' eager attention so the calling HF attention module
    can reshape the result exactly as it does for its built-in implementations.

    Args:
        module: The HF attention module. Its ``is_causal`` attribute (default
            ``True``) decides whether a causal mask is applied when no
            ``attention_mask`` is supplied.
        query: ``[batch, num_heads, q_len, head_dim]``.
        key: ``[batch, num_kv_heads, kv_len, head_dim]``. ``num_heads`` is
            assumed to be a multiple of ``num_kv_heads`` (grouped-query attention).
        value: Same layout as ``key``.
        attention_mask: Additive float mask, or boolean mask where ``True``
            means "attend". Broadcastable to ``[batch, num_heads, q_len, kv_len]``.
        scaling: Softmax scale. Defaults to ``head_dim ** -0.5``.

    Returns:
        ``(attn_output, None)``, where ``attn_output`` is shaped
        ``[batch, q_len, num_heads, head_dim]`` (the layout transformers expects
        from an attention function).
    """
    num_heads = query.shape[1]
    num_kv_heads = key.shape[1]
    if num_kv_heads != num_heads:
        # GQA/MQA: each KV head serves a contiguous group of query heads.
        n_rep = num_heads // num_kv_heads
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)

    if scaling is None:
        scaling = query.shape[-1] ** -0.5

    q_len, kv_len = query.shape[-2], key.shape[-2]
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling

    if attention_mask is not None:
        mask = attention_mask[..., :kv_len]
        if mask.dtype == torch.bool:
            attn_weights = attn_weights.masked_fill(~mask, float("-inf"))
        else:
            attn_weights = attn_weights + mask
    elif q_len > 1 and getattr(module, "is_causal", True):
        # No mask was supplied but the layer is causal: mask future positions
        # (bottom-right aligned so it is also correct when kv_len > q_len).
        causal = torch.ones(
            q_len, kv_len, dtype=torch.bool, device=query.device
        ).tril(diagonal=kv_len - q_len)
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))

    # Softmax in float32 for numerical stability, then back to the input dtype.
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), None


def vllm_flash_attention_forward(
    # Transformers args
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    # Transformers kwargs
    scaling: float = None,
    # vLLM kwargs
    attention_instances: Dict[int, Attention] = None,
    **kwargs,
):
    """Attention function registered with transformers under the name ``"vllm"``.

    Transformers calls this in place of its own attention implementation. When
    a vLLM ``Attention`` instance exists for the calling layer and an attention
    context has been published via ``set_attention_context``, the call is
    routed to vLLM's paged attention with that layer's KV cache. Otherwise it
    falls back to ``_eager_attention_fallback``.

    Args:
        module: The HF attention module. Its ``layer_idx`` attribute selects
            the vLLM ``Attention`` instance and KV cache.
        query, key, value: ``[batch, heads, seq, head_dim]`` tensors.
        attention_mask: Mask from transformers. It is only used by the
            fallback path; vLLM attention derives masking from its metadata.
        scaling: Softmax scale from the HF module. It is pushed into the vLLM
            attention implementation when provided.
        attention_instances: Map from layer index to vLLM ``Attention``.

    Returns:
        ``(attn_output, None)``. On the vLLM path ``attn_output`` is
        ``[seq, heads * head_dim]``, which assumes batch size 1 (as produced by
        ``TransformersForCausalLM.forward``).
    """
    layer_idx = getattr(module, "layer_idx", None)
    # Without a layer index we cannot pick the right KV cache; defaulting to
    # layer 0 would read and overwrite another layer's cache.
    if (
        layer_idx is None
        or attention_instances is None
        or layer_idx not in attention_instances
    ):
        logger.debug("No attention instance for layer %s, using standard attention", layer_idx)
        return _eager_attention_fallback(module, query, key, value, attention_mask, scaling)

    # v0.6.2 Attention.forward needs (query, key, value, kv_cache, attn_metadata);
    # the latter two come from the context published by the model's forward().
    if _current_attn_metadata is None or _current_kv_caches is None:
        logger.debug("No attention context, using standard attention for layer %d", layer_idx)
        return _eager_attention_fallback(module, query, key, value, attention_mask, scaling)

    self_attn = attention_instances[layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)

    # [batch, heads, seq, head_dim] -> [seq, heads * head_dim] (batch size 1).
    seq_len = query.shape[-2]
    query, key, value = (
        x.transpose(1, 2).reshape(seq_len, -1) for x in (query, key, value)
    )

    kv_cache = _current_kv_caches[layer_idx] if layer_idx < len(_current_kv_caches) else None
    output = self_attn.forward(query, key, value, kv_cache, _current_attn_metadata)
    return output, None


# Try to register vLLM attention with transformers. Older transformers releases
# lack ALL_ATTENTION_FUNCTIONS, and some may expose an immutable mapping.
_vllm_attention_registered = False
try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
    _vllm_attention_registered = True
    logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError, TypeError) as e:
    logger.warning("Could not register vLLM attention function - "
                   "transformers version may not support custom attention: %s", e)


class TransformersForCausalLM(nn.Module):
    """
    Adapt any HuggingFace causal language model to the vLLM model interface.

    Construction:
    1. The HF ``*ForCausalLM`` module tree is built on the meta device, so no
       parameter memory is allocated until ``_init_parameters``.
    2. ``nn.Linear`` / ``*RMSNorm`` children are swapped for vLLM layers.
    3. The input embedding becomes a ``VocabParallelEmbedding``.
    4. One vLLM ``Attention`` per decoder layer is created for KV caching.
    5. Weights are loaded through ``AutoWeightsLoader``.

    Interface compliance:
    - VllmModel protocol (vllm_config init, forward with required args)
    - VllmModelForTextGeneration protocol (compute_logits, sample)
    """

    # ``self.model`` is the complete HF ``*ForCausalLM`` module, so every
    # parameter lives at ``model.<HF checkpoint name>`` (for example
    # ``model.model.layers.0...``, ``model.lm_head.weight`` or
    # ``model.transformer.h.0...``). Prepending ``model.`` is therefore the
    # only rename needed. Upstream vLLM's WeightsMapper applies every matching
    # prefix rule in sequence, so rules such as "model.model." -> "model."
    # (meant for a layout where ``self.model`` is the bare base model) would
    # strip the prefix again and break loading.
    # NOTE(review): confirm the project's WeightsMapper has the same semantics.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "": "model.",
        }
    )

    # Checkpoint tensors holding RoPE caches. They are dropped in load_weights()
    # rather than matched against parameters.
    _ROTARY_EMBED_WEIGHT_SUFFIXES = (
        "rotary_emb.inv_freq",
        "rotary_emb.cos_cached",
        "rotary_emb.sin_cached",
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the wrapped HF model and its vLLM companions.

        Args:
            vllm_config: Full vLLM configuration. The model, cache, quant and
                device configs are read from it.
            prefix: Name prefix of this module inside the vLLM model tree.
        """
        super().__init__()
        config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config.hf_config
        # Multimodal configs nest the language-model settings in text_config.
        self.text_config = getattr(self.config, "text_config", self.config)
        self.model_config = config
        self.cache_config = cache_config
        self.device_config = vllm_config.device_config
        self.quant_config = quant_config
        self.prefix = prefix

        # Get model dimensions from config
        self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
        self.vocab_size = getattr(self.text_config, "vocab_size", 32000)

        # Overwritten by _replace_input_embeddings() when the HF embedding
        # carries a scale; must exist even if no embedding is found because
        # forward() reads it unconditionally.
        self.embed_scale = None

        # Weight loading configuration (prefixes are qualified names after
        # hf_to_vllm_mapper, compared with str.startswith - no globs).
        self.skip_prefixes: List[str] = []
        self.ignore_unexpected_prefixes: List[str] = [
            # Some checkpoints carry a bias for the final norm that the HF
            # module does not have.
            "model.model.norm.bias",
            "model.norm.bias",
        ]

        logger.info("Using Transformers modeling backend for %s",
                    config.hf_config.architectures)

        # Must run before the model is created so from_config picks it up.
        self._configure_attention_backend()
        self._load_hf_model_on_meta()
        self._replace_modules()
        self._replace_input_embeddings()
        self.attention_instances = self._create_attention_instances()
        # Allocate real storage for everything still on the meta device.
        self._init_parameters()

        # compute_logits() prunes hidden states and applies the HF lm_head
        # itself, so the processor receives ready-made logits.
        self.logits_processor = LogitsProcessor(
            self.vocab_size,
            logits_as_input=True,
        )
        self.sampler = Sampler()

    def _load_hf_model_on_meta(self) -> None:
        """Instantiate the HF causal-LM module tree on the meta device.

        Parameters created here have no storage. ``_init_parameters`` later
        allocates them on the target device.
        """
        from transformers import AutoModelForCausalLM

        logger.info("Creating model structure on meta device...")

        with init_on_device_without_buffers("meta"):
            self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Inference only: eval mode, no gradients.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        logger.info("Model structure created on meta device")

    def _replace_modules(self) -> None:
        """Swap HF submodules for vLLM implementations, in place.

        - ``nn.Linear`` -> vLLM replicated linear (via ``replace_linear_class``)
        - classes whose name ends in ``RMSNorm`` -> vLLM RMSNorm

        All other children are searched recursively.
        """
        logger.info("Replacing modules with vLLM optimized versions...")
        replaced_count = 0

        def _norm_size(norm: nn.Module) -> int:
            # Per-head norms (e.g. q/k norms) are narrower than hidden_size;
            # the weight length is the authoritative width when present.
            weight = getattr(norm, "weight", None)
            if isinstance(weight, torch.Tensor) and weight.dim() == 1:
                return weight.shape[0]
            return self.hidden_size

        def _recursive_replace(module: nn.Module, prefix: str = ""):
            nonlocal replaced_count
            # Snapshot the children because setattr() mutates the module.
            for name, child in list(module.named_children()):
                qual_name = maybe_prefix(prefix, name)
                class_name = child.__class__.__name__

                if isinstance(child, nn.Linear):
                    new_module = replace_linear_class(
                        child,
                        style="replicate",
                        quant_config=self.quant_config,
                        prefix=qual_name,
                    )
                elif class_name.endswith("RMSNorm"):
                    # NOTE(review): confirm replace_rms_norm_class preserves
                    # variant semantics (e.g. HF Gemma RMSNorm uses 1 + weight).
                    new_module = replace_rms_norm_class(child, _norm_size(child))
                else:
                    if class_name.endswith(("LayerNorm", "GroupNorm")):
                        logger.debug("Found normalization layer %s: %s",
                                     qual_name, type(child).__name__)
                    elif "Attention" in class_name:
                        # The attention module itself is kept; vLLM Attention
                        # instances are created separately and used via the
                        # registered "vllm" attention function.
                        logger.debug("Found attention layer %s: %s",
                                     qual_name, type(child).__name__)
                    _recursive_replace(child, qual_name)
                    continue

                setattr(module, name, new_module)
                log_replacement(qual_name, child, new_module)
                replaced_count += 1

        _recursive_replace(self.model, "model")
        logger.info("Replaced %d modules with vLLM optimized versions", replaced_count)

    def _replace_input_embeddings(self) -> None:
        """Replace the HF input embedding with a ``VocabParallelEmbedding``.

        Records the original embedding's ``embed_scale`` (if any) in
        ``self.embed_scale`` so ``forward`` can re-apply it.
        """
        try:
            input_embeddings = self.model.get_input_embeddings()
        except NotImplementedError:
            # PreTrainedModel raises this when a model does not expose one.
            input_embeddings = None
        if input_embeddings is None:
            logger.warning("Could not find input embeddings to replace")
            return

        if hasattr(input_embeddings, "embedding_dim"):
            embedding_dim = input_embeddings.embedding_dim
        elif hasattr(input_embeddings, "weight"):
            embedding_dim = input_embeddings.weight.shape[1]
        else:
            embedding_dim = self.hidden_size

        # Some models scale embeddings inside the embedding module; that
        # module is about to be replaced, so keep the scale.
        self.embed_scale = getattr(input_embeddings, "embed_scale", None)

        logger.info("Replacing input embeddings with VocabParallelEmbedding "
                    "(vocab_size=%d, embedding_dim=%d)", self.vocab_size, embedding_dim)

        new_embeddings = VocabParallelEmbedding(
            self.vocab_size,
            embedding_dim,
            org_num_embeddings=self.vocab_size,
            quant_config=self.quant_config,
        )
        self.model.set_input_embeddings(new_embeddings)

    def _init_parameters(self) -> None:
        """Give every meta-device parameter real (uninitialised) storage.

        Storage is allocated on ``device_config.device`` (CPU if unset) with
        the model dtype. Values are filled in later by ``load_weights``.
        """
        logger.info("Initializing parameters on target device...")

        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype
        meta = torch.device("meta")

        def _init_params(module: nn.Module):
            for name, param in list(module.named_parameters(recurse=False)):
                if param.device == meta:
                    new_param = nn.Parameter(
                        torch.empty_like(param.data, dtype=dtype, device=device),
                        requires_grad=False,
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_params(child)

        _init_params(self.model)
        logger.info("Parameters initialized on %s", device)

    def _create_attention_instances(self) -> Dict[int, Attention]:
        """Create one vLLM ``Attention`` per decoder layer.

        Returns:
            Map from layer index to the ``Attention`` used by
            ``vllm_flash_attention_forward`` for that layer's KV cache.
        """
        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))
        num_heads = getattr(self.text_config, "num_attention_heads", 32)
        # Some architectures decouple head_dim from hidden_size / num_heads.
        head_size = (getattr(self.text_config, "head_dim", None)
                     or self.hidden_size // num_heads)
        scale = 1.0 / (head_size ** 0.5)
        # Configs may store None to mean "same as num_attention_heads".
        num_kv_heads = getattr(self.text_config, "num_key_value_heads", None) or num_heads

        logger.info("Creating %d attention instances for KV cache "
                    "(num_heads=%d, head_size=%d, num_kv_heads=%d)",
                    num_layers, num_heads, head_size, num_kv_heads)

        return {
            layer_idx: Attention(
                num_heads=num_heads,
                head_size=head_size,
                scale=scale,
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=f"model.layers.{layer_idx}.self_attn",
            )
            for layer_idx in range(num_layers)
        }

    def _configure_attention_backend(self) -> None:
        """Select the "vllm" attention implementation when it was registered.

        Must run before ``_load_hf_model_on_meta`` because ``from_config``
        reads ``_attn_implementation`` from the config.
        """
        if _vllm_attention_registered:
            self.text_config._attn_implementation = "vllm"
            logger.info("Set attention implementation to 'vllm' in text_config")
        else:
            logger.info("Using default HuggingFace attention (vLLM attention not registered)")

    def _build_model_inputs(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """Turn vLLM's flat token inputs into batch-size-1 HF model kwargs."""
        if inputs_embeds is not None:
            if inputs_embeds.dim() == 2:
                inputs_embeds = inputs_embeds.unsqueeze(0)
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            if self.embed_scale is not None:
                # The replaced embedding no longer applies the model's scale.
                embeds = self.model.get_input_embeddings()(input_ids)
                model_inputs = {"inputs_embeds": embeds * self.embed_scale}
            else:
                model_inputs = {"input_ids": input_ids}

        if positions is not None:
            if positions.dim() == 1:
                positions = positions.unsqueeze(0)
            model_inputs["position_ids"] = positions
        return model_inputs

    def _run_hf_model(self, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the HF model and return the final hidden states.

        The base model is run when available, so the full-sequence logits and
        the per-layer hidden states are never materialised. Otherwise the whole
        causal LM is run and its last hidden state (or, failing that, its
        logits) is returned.
        """
        base_model = getattr(self.model, "base_model", self.model)
        if base_model is not self.model:
            outputs = base_model(
                **model_inputs,
                use_cache=False,
                return_dict=True,
                attention_instances=self.attention_instances,
            )
            last_hidden = getattr(outputs, "last_hidden_state", None)
            return last_hidden if last_hidden is not None else outputs[0]

        outputs = self.model(
            **model_inputs,
            use_cache=False,
            return_dict=True,
            output_hidden_states=True,
            attention_instances=self.attention_instances,
        )
        if outputs.hidden_states is not None:
            return outputs.hidden_states[-1]
        # Fallback: compute_logits() recognises vocab-sized outputs as logits.
        return outputs.logits

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run one model step (VllmModel protocol).

        Args:
            input_ids: Flattened token IDs for all sequences in the step.
            positions: Position IDs aligned with ``input_ids``.
            kv_caches: Per-layer KV caches.
            attn_metadata: vLLM attention metadata for this step.
            intermediate_tensors: Unused (no pipeline parallelism support).
            inputs_embeds: Optional precomputed embeddings; used instead of
                ``input_ids`` when given.

        Returns:
            Final hidden states with the batch dimension removed, i.e.
            ``[num_tokens, hidden]`` (or logits in the fallback path).
        """
        # The registered attention function reads these during the call.
        set_attention_context(attn_metadata, kv_caches)
        try:
            with torch.no_grad():
                model_inputs = self._build_model_inputs(input_ids, positions, inputs_embeds)
                hidden_states = self._run_hf_model(model_inputs)

            if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
                hidden_states = hidden_states.squeeze(0)
            return hidden_states
        finally:
            clear_attention_context()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project the sampled positions' hidden states to vocabulary logits.

        Only rows listed in ``sampling_metadata.selected_token_indices`` are
        kept, then the HF ``lm_head`` (or ``embed_out``) is applied and the
        result goes through the logits processor (which is built with
        ``logits_as_input=True``).
        """
        selected = getattr(sampling_metadata, "selected_token_indices", None)
        if selected is not None:
            hidden_states = hidden_states.index_select(0, selected)

        if hidden_states.shape[-1] == self.vocab_size:
            # forward() fell back to returning logits.
            logits = hidden_states
        else:
            lm_head = getattr(self.model, "lm_head", None)
            if lm_head is None:
                lm_head = getattr(self.model, "embed_out", None)
            if lm_head is None:
                logits = hidden_states
            else:
                output = lm_head(hidden_states)
                # vLLM linear layers may return (output, bias).
                logits = output[0] if isinstance(output, tuple) else output

        return self.logits_processor(None, logits, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` (VllmModelForTextGeneration protocol)."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _maybe_tie_lm_head(self, loaded: Set[str]) -> None:
        """Fill ``lm_head`` from the input embeddings for tied-embedding models.

        The HF weight tie is lost when both modules are replaced. If the
        checkpoint did not provide ``lm_head.weight`` and the config asks for
        tied embeddings, copy the embedding rows into ``lm_head`` and record
        the parameter as loaded in ``loaded`` (mutated in place).
        """
        tie = getattr(self.text_config, "tie_word_embeddings",
                      getattr(self.config, "tie_word_embeddings", False))
        if not tie:
            return

        lm_weight = getattr(getattr(self.model, "lm_head", None), "weight", None)
        embeddings = self.model.get_input_embeddings()
        emb_weight = getattr(embeddings, "weight", None)
        if lm_weight is None or emb_weight is None:
            return

        lm_name = next((n for n, p in self.named_parameters() if p is lm_weight), None)
        if lm_name is None or lm_name in loaded:
            return

        rows, dim = lm_weight.shape[0], lm_weight.shape[-1]
        # The embedding may be padded (more rows) or sharded (fewer rows).
        if emb_weight.dim() != 2 or emb_weight.shape[0] < rows or emb_weight.shape[1] != dim:
            logger.warning("Cannot tie %s to input embeddings (shapes %s vs %s); "
                           "lm_head is left uninitialized",
                           lm_name, tuple(lm_weight.shape), tuple(emb_weight.shape))
            return

        with torch.no_grad():
            lm_weight.copy_(emb_weight[:rows])
        loaded.add(lm_name)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load checkpoint tensors through ``AutoWeightsLoader``.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Qualified names of the parameters that were populated, including a
            tied ``lm_head`` weight filled from the embeddings.
        """
        filtered = (
            (name, tensor)
            for name, tensor in weights
            if not name.endswith(self._ROTARY_EMBED_WEIGHT_SUFFIXES)
        )
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
        )
        loaded = set(loader.load_weights(filtered, mapper=self.hf_to_vllm_mapper))
        self._maybe_tie_lm_head(loaded)
        logger.info("Loaded %d weight tensors", len(loaded))
        return loaded


def is_backend_compatible() -> bool:
    """Report whether the current model can use the Transformers backend.

    No checks are performed yet; every model is reported as compatible.
    """
    return True