# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend base class for v0.6.2.

This module provides the Base class following latest vLLM architecture:
- Meta device initialization for memory efficiency
- Pipeline parallel support (PPMissingLayer)
- Tensor parallel support (tp_plan based module replacement)
- Module replacement (Linear, RMSNorm) with vLLM optimized versions
- VocabParallelEmbedding for input embeddings
- Attention instances for KV cache allocation
- Weight loading with AutoWeightsLoader and WeightsMapper
"""

import re
from typing import (
    TYPE_CHECKING,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    WeightsMapper,
    make_empty_intermediate_tensors_factory,
)
from vllm.attention.layer import Attention
from vllm.sequence import IntermediateTensors

from .utils import (
    init_on_device_without_buffers,
    replace_linear_class,
    replace_rms_norm_class,
    log_replacement,
    maybe_prefix,
)

if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from vllm.attention import AttentionMetadata

logger = init_logger(__name__)


# ============================================================================
# Attention Context Management (for vLLM attention integration)
# ============================================================================

# Module-level slots read by `vllm_flash_attention_forward`. They are set by
# `Base.forward` right before the wrapped model runs and cleared afterwards.
# NOTE(review): plain module globals, so concurrent forwards in one process
# would share them.
_current_attn_metadata = None
_current_kv_caches = None


def set_attention_context(attn_metadata, kv_caches):
    """Set the current attention context for vLLM attention functions."""
    global _current_attn_metadata, _current_kv_caches
    _current_attn_metadata = attn_metadata
    _current_kv_caches = kv_caches


def clear_attention_context():
    """Clear the current attention context after forward pass."""
    global _current_attn_metadata, _current_kv_caches
    _current_attn_metadata = None
    _current_kv_caches = None


def get_attention_context():
    """Return the current ``(attn_metadata, kv_caches)`` pair.

    Both entries are ``None`` when no context is active.
    """
    return _current_attn_metadata, _current_kv_caches


# ============================================================================
# vLLM Attention Function for Transformers Integration
# ============================================================================

def vllm_flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    attention_instances: Optional[Dict[int, Attention]] = None,
    **kwargs,
):
    """
    vLLM's optimized attention function for transformers integration.

    In v0.6.2, Attention.forward signature is:
    (query, key, value, kv_cache, attn_metadata)

    Args:
        module: The calling attention module; its ``layer_idx`` attribute
            (default 0) selects the vLLM ``Attention`` instance.
        query, key, value: Projected states laid out as
            ``[batch, heads, seq, head_dim]``.
        attention_mask: Only used by the fallback path.
        scaling: Softmax scale; when given it overrides the scale of the
            selected ``Attention`` implementation.
        attention_instances: Mapping from global layer index to vLLM
            ``Attention``; normally supplied by ``Base.forward``.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, None)``. Falls back to ``_standard_attention`` when
        no ``Attention`` instance or no attention context is available.
    """
    layer_idx = getattr(module, 'layer_idx', 0)

    if attention_instances is None or layer_idx not in attention_instances:
        return _standard_attention(query, key, value, attention_mask, scaling)

    self_attn = attention_instances[layer_idx]
    attn_metadata, kv_caches = get_attention_context()

    if attn_metadata is None or kv_caches is None:
        return _standard_attention(query, key, value, attention_mask, scaling)

    if scaling is not None:
        self_attn.impl.scale = float(scaling)

    # Reshape: [batch, heads, seq, head_dim] -> [seq, heads * head_dim]
    # `Base.forward` always feeds the wrapped model a batch dimension of 1,
    # so collapsing to `seq` rows does not mix different sequences.
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))

    # `attention_instances` is keyed by *global* layer index (it starts at
    # this pipeline rank's first layer), so rebase onto the local position
    # before indexing the cache list. On rank 0 / without PP the offset is 0.
    # NOTE(review): assumes `kv_caches` holds only this rank's layers, in
    # order — confirm against the model runner.
    cache_idx = layer_idx - min(attention_instances)
    kv_cache = kv_caches[cache_idx] if 0 <= cache_idx < len(kv_caches) else None

    output = self_attn.forward(query, key, value, kv_cache, attn_metadata)
    return output, None


def _standard_attention(query, key, value, attention_mask, scaling):
    """Standard scaled dot-product attention fallback.

    Args:
        query: ``[batch, heads, q_len, head_dim]``.
        key, value: ``[batch, kv_heads, kv_len, head_dim]``; ``kv_heads`` may
            be a divisor of ``heads`` (grouped-query attention).
        attention_mask: Optional additive mask broadcastable to the scores.
        scaling: Softmax scale; defaults to ``head_dim ** -0.5`` when ``None``.

    Returns:
        ``(attn_output, None)`` with ``attn_output`` laid out as
        ``[batch, q_len, heads, head_dim]`` so that the caller's
        ``reshape(batch, q_len, -1)`` yields correctly ordered features.
    """
    if scaling is None:
        scaling = query.shape[-1] ** -0.5

    # Expand grouped KV heads so the matmul shapes line up with the queries.
    if query.dim() == 4 and key.dim() == 4:
        num_q_heads, num_kv_heads = query.shape[1], key.shape[1]
        if num_kv_heads != num_q_heads and num_q_heads % num_kv_heads == 0:
            n_rep = num_q_heads // num_kv_heads
            key = key.repeat_interleave(n_rep, dim=1)
            value = value.repeat_interleave(n_rep, dim=1)

    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)

    # Return sequence-major output, matching the layout the vLLM path yields
    # after the caller's reshape.
    if attn_output.dim() == 4:
        attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None


# Register vLLM attention to transformers
_vllm_attention_registered = False
try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
    _vllm_attention_registered = True
    logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
    logger.warning("Could not register vLLM attention: %s", e)


# ============================================================================
# Base Class with Pipeline Parallel and Tensor Parallel Support
# ============================================================================

class Base(nn.Module):
    """
    Base class for Transformers backend models with full parallel support.

    Features:
    - Pipeline Parallel: PPMissingLayer for distributed layers
    - Tensor Parallel: tp_plan based module replacement
    - Meta device initialization
    - Module replacement (Linear → vLLM Linear, RMSNorm → vLLM RMSNorm)
    - VocabParallelEmbedding for input embeddings
    - Attention instances for KV cache allocation
    """

    # For vLLM's weight loader
    embedding_modules = ["embed_tokens"]

    # Weight name mapping following latest vLLM pattern
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Add `model.` prefix for base model checkpoints,
            # handling the case where it is already present
            "": "model.",
            "model.model.": "model.",
            # Heads will be adjacent to `model` (pooling included because of adapters)
            "model.lm_head.": "lm_head.",
            "model.score.": "classifier.",
            "model.classifier.": "classifier.",
        }
    )

    # Note: __init_subclass__ with WeightsMapper merging is not supported in v0.6.2
    # because WeightsMapper doesn't implement __or__/__ior__ operators.
    # Each Mixin should define its own hf_to_vllm_mapper if needed.

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the wrapped Transformers model and adapt it for vLLM.

        Args:
            vllm_config: Source of the model, cache, device, parallel and
                quantization configuration.
            prefix: Stored on the instance as ``self.prefix``.
        """
        super().__init__()
        logger.info("Using Transformers modeling backend.")

        # Store configuration
        self.config = vllm_config.model_config.hf_config
        self.text_config = getattr(self.config, "text_config", self.config)
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config
        self.prefix = prefix

        # Parallel groups
        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Model dimensions
        self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
        self.vocab_size = getattr(self.text_config, "vocab_size", 32000)

        # Weight loading configuration
        self.skip_prefixes: List[str] = []
        self.ignore_unexpected_prefixes: List[str] = []

        # The steps below are order-dependent: the structure must exist
        # before layers are pruned / replaced, and parameters are only
        # materialized once every replacement has happened.
        self._configure_attention_backend()
        self._init_model_on_meta()
        self._apply_pipeline_parallel()
        self._replace_modules()
        self._fix_attention_head_dim()
        self._add_attention_debug_hook()
        self._replace_input_embeddings()
        self.attention_instances = self._create_attention_instances()
        self._init_parameters()

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.hidden_size
        )

    def _configure_attention_backend(self) -> None:
        """Configure vLLM attention backend."""
        # Note: attention implementation is set in _init_model_on_meta
        # This method is kept for potential platform-specific configuration
        pass

    def _init_model_on_meta(self) -> None:
        """Create model structure on meta device.

        Selects the ``"vllm"`` attention implementation, forces ``head_dim``
        to ``hidden_size // num_attention_heads`` on both ``text_config`` and
        ``config``, instantiates the model via ``AutoModel.from_config`` and
        freezes it in eval mode.
        """
        from transformers import AutoModel

        logger.info("Creating model structure on meta device...")

        # Set attention implementation to vLLM's
        self.text_config._attn_implementation = "vllm"

        # Ensure head_dim is correctly set in BOTH config and text_config
        # Transformers models use config.head_dim to compute attention dimensions
        # Some models may have incorrect head_dim, so we compute and set it
        if hasattr(self.text_config, "num_attention_heads") and \
                hasattr(self.text_config, "hidden_size"):
            correct_head_dim = (
                self.text_config.hidden_size
                // self.text_config.num_attention_heads
            )
            # `self.config` is what gets passed to AutoModel.from_config, so
            # it must agree with `text_config`.
            for label, cfg in (("text_config", self.text_config),
                               ("config", self.config)):
                current = getattr(cfg, "head_dim", correct_head_dim)
                if current != correct_head_dim:
                    # %s rather than %d: the existing value may be None.
                    logger.warning(
                        "Correcting head_dim in %s: %s -> %d",
                        label, current, correct_head_dim
                    )
                cfg.head_dim = correct_head_dim

        # Some models also need _attn_implementation in config
        self.config._attn_implementation = "vllm"

        with init_on_device_without_buffers("meta"):
            self.model: "PreTrainedModel" = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def _apply_pipeline_parallel(self) -> None:
        """
        Apply pipeline parallelization plan.

        For models that don't explicitly support pp_plan, we do a best-effort
        approach by splitting layers based on num_hidden_layers.

        Sets ``self._has_embeddings`` (always defined after this call) and
        replaces layers outside this rank's range with ``PPMissingLayer``.
        """
        # Define the flag up front so it also exists when PP is disabled.
        self._has_embeddings = True

        if self.pp_group.world_size <= 1:
            return

        logger.info("Applying pipeline parallel (world_size=%d, rank=%d)",
                    self.pp_group.world_size, self.pp_group.rank_in_group)

        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))

        start_layer, end_layer = get_pp_indices(
            num_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )

        # Find and process layer modules
        layers_module = self._find_layers_module()
        if layers_module is not None:
            for i in range(len(layers_module)):
                if not (start_layer <= i < end_layer):
                    # Replace layers not on this rank with PPMissingLayer
                    setattr(layers_module, str(i), PPMissingLayer())

        # Handle embeddings (only on first rank)
        if not self.pp_group.is_first_rank:
            if self.model.get_input_embeddings() is not None:
                # Keep a reference but mark as missing for forward
                self._has_embeddings = False

        # Handle final norm and lm_head (only on last rank)
        if not self.pp_group.is_last_rank:
            # Mark lm_head as missing
            if hasattr(self.model, 'lm_head'):
                self.model.lm_head = PPMissingLayer()

        logger.info("Pipeline parallel applied: layers %d-%d on this rank",
                    start_layer, end_layer)

    def _find_layers_module(self) -> Optional[nn.Module]:
        """Find the ModuleList containing transformer layers.

        Searches depth-first, descending only into children named ``model``,
        ``transformer``, ``encoder`` or ``decoder``.

        Returns:
            The first ``nn.ModuleList`` child named ``layers``, ``h``,
            ``blocks`` or ``layer``, or ``None`` if there is none.
        """
        container_names = ('layers', 'h', 'blocks', 'layer')
        backbone_names = ('model', 'transformer', 'encoder', 'decoder')

        def _search_layers(module: nn.Module) -> Optional[nn.Module]:
            for name, child in module.named_children():
                if name in container_names and isinstance(child, nn.ModuleList):
                    return child
                # Recursively search in model backbone
                if name in backbone_names:
                    result = _search_layers(child)
                    if result is not None:
                        return result
            return None

        return _search_layers(self.model)

    def _get_tp_plan(self) -> Dict[str, str]:
        """
        Get tensor parallel plan for module replacement.

        This maps module name patterns to parallelization styles:
        - "colwise": Column parallel (split output dim)
        - "rowwise": Row parallel (split input dim)
        - "replicate": Replicated (no split)

        Returns a dict mapping regex patterns to styles.
        """
        # Check if model has explicit tp_plan
        model_plan = getattr(self.model, 'tp_plan', None)
        if model_plan:
            return {maybe_prefix("model", k): v for k, v in model_plan.items()}

        # Default tp_plan for common LLM architectures
        # Based on typical transformer structure
        return {
            r".*\.q_proj$": "colwise",
            r".*\.k_proj$": "colwise",
            r".*\.v_proj$": "colwise",
            r".*\.o_proj$": "rowwise",
            r".*\.gate_proj$": "colwise",
            r".*\.up_proj$": "colwise",
            r".*\.down_proj$": "rowwise",
            r".*\.query$": "colwise",
            r".*\.key$": "colwise",
            r".*\.value$": "colwise",
            r".*\.dense$": "rowwise",
            r".*\.fc1$": "colwise",
            r".*\.fc2$": "rowwise",
        }

    def _replace_modules(self) -> None:
        """
        Replace modules with vLLM optimized versions.

        Uses tp_plan for tensor parallel style selection.
        Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
        """
        logger.info("Replacing modules with vLLM optimized versions...")
        replaced_count = 0

        # Get tensor parallel plan; compile each pattern once instead of
        # once per visited module. The first matching pattern wins.
        tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
        compiled_plan = [(re.compile(pattern), plan_style)
                         for pattern, plan_style in tp_plan.items()]

        # Modules to skip replacement (handled at wrapper level)
        skip_modules = {"lm_head", "score", "classifier"}

        def _recursive_replace(module: nn.Module, prefix: str = ""):
            nonlocal replaced_count

            # Snapshot the children: the loop replaces entries in place.
            for name, child in list(module.named_children()):
                # Skip PPMissingLayer
                if isinstance(child, PPMissingLayer):
                    continue

                # Skip modules that are handled at wrapper level
                if name in skip_modules:
                    logger.debug("Skipping %s (handled at wrapper level)", name)
                    continue

                qual_name = maybe_prefix(prefix, name)
                new_module = None

                if isinstance(child, nn.Linear):
                    # Determine parallelization style from tp_plan
                    style = "replicate"
                    for pattern, plan_style in compiled_plan:
                        if pattern.match(qual_name):
                            style = plan_style
                            break

                    new_module = replace_linear_class(
                        child,
                        style=style,
                        quant_config=self.quant_config,
                        prefix=qual_name,
                    )
                    replaced_count += 1

                elif child.__class__.__name__.endswith("RMSNorm") and \
                        not isinstance(child, RMSNorm):
                    new_module = replace_rms_norm_class(child, self.hidden_size)
                    replaced_count += 1

                if new_module is not None:
                    setattr(module, name, new_module)
                    log_replacement(qual_name, child, new_module)
                else:
                    _recursive_replace(child, qual_name)

        _recursive_replace(self.model, "model")
        logger.info("Replaced %d modules", replaced_count)

    def _add_attention_debug_hook(self) -> None:
        """No-op. Debug hooks removed after root cause identified."""
        pass

    def _fix_attention_head_dim(self) -> None:
        """
        Fix head_dim in attention modules and rotary embeddings after model creation.

        Some models may have incorrect head_dim in config, which causes
        Transformers attention modules and RoPE to use wrong dimensions.
        This method corrects head_dim in all attention modules and recreates
        rotary embeddings if needed.

        The target value is ``hidden_size // num_attention_heads`` (32 heads
        assumed when the config does not define the attribute).
        """
        correct_head_dim = self.hidden_size // getattr(
            self.text_config, "num_attention_heads", 32
        )

        fixed_count = 0
        for name, module in self.model.named_modules():
            module_name = module.__class__.__name__

            # Fix head_dim in Attention modules
            if "Attention" in module_name and hasattr(module, "head_dim"):
                if module.head_dim != correct_head_dim:
                    logger.warning(
                        "Fixing head_dim in %s: %d -> %d",
                        name, module.head_dim, correct_head_dim
                    )
                    module.head_dim = correct_head_dim
                    fixed_count += 1

            # Fix rotary embeddings - recreate inv_freq buffer if needed
            if "RotaryEmbedding" in module_name and hasattr(module, "inv_freq"):
                current_dim = module.inv_freq.shape[0] * 2
                if current_dim == correct_head_dim:
                    continue

                logger.warning(
                    "Recreating rotary embedding %s: dim %d -> %d",
                    name, current_dim, correct_head_dim
                )

                # The rotary module may not carry a config, and
                # `rope_parameters` may be absent or None; fall back to the
                # conventional base of 10000 in those cases.
                rope_config = getattr(module, "config", None)
                base = getattr(rope_config, "rope_theta", None)
                rope_parameters = getattr(rope_config, "rope_parameters", None)
                if rope_parameters is not None and hasattr(rope_parameters, "get"):
                    base = rope_parameters.get("rope_theta", base)
                if base is None:
                    base = 10000.0

                device = module.inv_freq.device
                inv_freq = 1.0 / (
                    base ** (
                        torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
                        .to(device=device, dtype=torch.float) / correct_head_dim
                    )
                )
                module.register_buffer("inv_freq", inv_freq, persistent=False)

                if hasattr(module, "original_inv_freq"):
                    # register_buffer raises KeyError when the name already
                    # exists as a plain attribute, so only re-register when
                    # it really is a buffer.
                    if "original_inv_freq" in module._buffers:
                        module.register_buffer(
                            "original_inv_freq", inv_freq.clone(),
                            persistent=False,
                        )
                    else:
                        module.original_inv_freq = inv_freq.clone()

        if fixed_count > 0:
            logger.info("Fixed head_dim in %d attention modules", fixed_count)

    def _replace_input_embeddings(self) -> None:
        """Replace input embeddings with VocabParallelEmbedding.

        Also records ``self.embed_scale`` from the original embedding module
        (``None`` when it has none or when there is nothing to replace).
        """
        # Always define the attribute: `forward` and `embed_input_ids` read
        # it unconditionally.
        self.embed_scale = None

        input_embeddings = self.model.get_input_embeddings()
        if input_embeddings is None or isinstance(input_embeddings, PPMissingLayer):
            return

        if hasattr(input_embeddings, "embedding_dim"):
            embedding_dim = input_embeddings.embedding_dim
        elif hasattr(input_embeddings, "weight"):
            embedding_dim = input_embeddings.weight.shape[1]
        else:
            embedding_dim = self.hidden_size

        self.embed_scale = getattr(input_embeddings, "embed_scale", None)

        logger.info("Replacing input embeddings (vocab=%d, dim=%d)",
                    self.vocab_size, embedding_dim)

        new_embeddings = VocabParallelEmbedding(
            self.vocab_size,
            embedding_dim,
            org_num_embeddings=self.vocab_size,
            quant_config=self.quant_config,
        )
        self.model.set_input_embeddings(new_embeddings)

    def _create_attention_instances(self) -> Dict[int, Attention]:
        """Create Attention instances for KV cache allocation.

        Returns:
            A dict keyed by *global* layer index, covering only the layers
            assigned to this pipeline-parallel rank.
        """
        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))
        num_heads = getattr(self.text_config, "num_attention_heads", 32)
        head_size = self.hidden_size // num_heads
        num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)

        # Get PP layer range
        pp_rank = self.pp_group.rank_in_group
        pp_size = self.pp_group.world_size
        start_layer, end_layer = get_pp_indices(num_layers, pp_rank, pp_size)

        logger.info("Creating attention instances for layers %d-%d "
                    "(heads=%d, head_size=%d, kv_heads=%d)",
                    start_layer, end_layer, num_heads, head_size, num_kv_heads)

        # NOTE(review): `config.layer_types` / `config.sliding_window` are
        # not forwarded to Attention. A per-layer sliding window used to be
        # computed here but was never passed on, so every layer is built
        # without one.
        attention_instances: Dict[int, Attention] = {}
        for layer_idx in range(start_layer, end_layer):
            attention_instances[layer_idx] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                scale=1.0 / (head_size ** 0.5),
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=f"model.layers.{layer_idx}.self_attn",
            )

        return attention_instances

    def _init_parameters(self) -> None:
        """Initialize parameters from meta device to target device.

        Every parameter still on the meta device is replaced by an
        uninitialized, non-trainable parameter of the model dtype on the
        configured device (CPU when none is configured). Values are expected
        to be filled in later by ``load_weights``.
        """
        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype

        def _init_params(module: nn.Module):
            if isinstance(module, PPMissingLayer):
                return

            # Snapshot: parameters are re-assigned while iterating.
            for name, param in list(module.named_parameters(recurse=False)):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(param.data, dtype=dtype, device=device),
                        requires_grad=False,
                    )
                    setattr(module, name, new_param)

            for child in module.children():
                _init_params(child)

        _init_params(self.model)
        logger.info("Parameters initialized on %s", device)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Get embeddings for input IDs, multiplied by ``embed_scale`` if set."""
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.embed_scale is not None:
            inputs_embeds = inputs_embeds * self.embed_scale
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward pass with pipeline parallel support.

        Args:
            input_ids: Token ids; ignored on non-first pipeline ranks.
            positions: Position ids, forwarded as ``position_ids``.
            kv_caches: KV caches made available to the attention function
                for the duration of the call.
            attn_metadata: vLLM attention metadata, exposed the same way.
            intermediate_tensors: Required on non-first pipeline ranks; its
                ``"hidden_states"`` entry replaces the input embeddings.
            inputs_embeds: Optional precomputed embeddings.
            **kwargs: Ignored.

        Returns:
            The last hidden state (a leading batch dimension of 1 removed) on
            the last pipeline rank, otherwise ``IntermediateTensors`` wrapping
            it under ``"hidden_states"``.
        """
        # Handle intermediate tensors for PP
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        set_attention_context(attn_metadata, kv_caches)

        try:
            # The wrapped model expects a batch dimension; add one of size 1.
            if positions is not None and positions.dim() == 1:
                positions = positions.unsqueeze(0)

            if inputs_embeds is not None:
                if inputs_embeds.dim() == 2:
                    inputs_embeds = inputs_embeds.unsqueeze(0)
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                if input_ids is not None and input_ids.dim() == 1:
                    input_ids = input_ids.unsqueeze(0)
                if self.embed_scale is not None and input_ids is not None:
                    # Apply embed_scale ourselves: the replaced embedding
                    # layer no longer does it.
                    inputs_embeds = self.embed_input_ids(input_ids)
                    model_inputs = {"inputs_embeds": inputs_embeds}
                else:
                    model_inputs = {"input_ids": input_ids}

            if positions is not None:
                model_inputs["position_ids"] = positions

            # Forward through model
            # Note: return_dict=False returns tuple, first element is last hidden state
            with torch.no_grad():
                outputs = self.model(
                    **model_inputs,
                    use_cache=False,
                    return_dict=False,
                    attention_instances=self.attention_instances,
                )

            # For models using return_dict=False, outputs is a tuple and
            # outputs[0] is usually the last hidden state.
            hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs

            # Remove batch dimension
            if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
                hidden_states = hidden_states.squeeze(0)

            # Return intermediate tensors for PP
            if not self.pp_group.is_last_rank:
                return IntermediateTensors({"hidden_states": hidden_states})

            return hidden_states

        finally:
            clear_attention_context()

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load weights using AutoWeightsLoader with name mapping.

        Args:
            weights: ``(name, tensor)`` pairs in checkpoint naming; names are
                translated through ``hf_to_vllm_mapper``.

        Returns:
            The set of names reported as loaded by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
        )
        loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        logger.info("Loaded %d weight tensors", len(loaded))
        return set(loaded)