enginex-mlu370-vllm/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend base class for v0.6.2.
This module provides the Base class, following the latest vLLM architecture patterns:
- Meta device initialization for memory efficiency
- Pipeline parallel support (PPMissingLayer)
- Tensor parallel support (tp_plan based module replacement)
- Module replacement (Linear, RMSNorm) with vLLM optimized versions
- VocabParallelEmbedding for input embeddings
- Attention instances for KV cache allocation
- Weight loading with AutoWeightsLoader and WeightsMapper
"""
import re
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
)
from vllm.attention.layer import Attention
from vllm.sequence import IntermediateTensors
from .utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
if TYPE_CHECKING:
from transformers import PreTrainedModel
from vllm.attention import AttentionMetadata
logger = init_logger(__name__)
# ============================================================================
# Attention Context Management (for vLLM attention integration)
# ============================================================================
_current_attn_metadata = None
_current_kv_caches = None
def set_attention_context(attn_metadata, kv_caches):
"""Set the current attention context for vLLM attention functions."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = attn_metadata
_current_kv_caches = kv_caches
def clear_attention_context():
"""Clear the current attention context after forward pass."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = None
_current_kv_caches = None
def get_attention_context():
"""Get the current attention context."""
return _current_attn_metadata, _current_kv_caches
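# Typical usage (see Base.forward below): set the context before running the
# HF model and always clear it afterwards so stale metadata cannot leak into
# the next forward pass:
#
#     set_attention_context(attn_metadata, kv_caches)
#     try:
#         outputs = model(**model_inputs)
#     finally:
#         clear_attention_context()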
# ============================================================================
# vLLM Attention Function for Transformers Integration
# ============================================================================
def vllm_flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
scaling: Optional[float] = None,
attention_instances: Optional[Dict[int, Attention]] = None,
**kwargs,
):
"""
vLLM's optimized attention function for transformers integration.
In v0.6.2, Attention.forward signature is:
(query, key, value, kv_cache, attn_metadata)
"""
layer_idx = getattr(module, 'layer_idx', 0)
if attention_instances is None or layer_idx not in attention_instances:
return _standard_attention(query, key, value, attention_mask, scaling)
self_attn = attention_instances[layer_idx]
attn_metadata, kv_caches = get_attention_context()
if attn_metadata is None or kv_caches is None:
return _standard_attention(query, key, value, attention_mask, scaling)
if scaling is not None:
self_attn.impl.scale = float(scaling)
# Reshape from the transformers layout [batch, heads, seq, head_dim] to
# vLLM's packed layout [num_tokens, heads * head_dim]; batch is 1 here
# because vLLM stacks all tokens along the sequence dimension.
num_tokens = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(num_tokens, -1) for x in (query, key, value))
kv_cache = kv_caches[layer_idx] if layer_idx < len(kv_caches) else None
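# Note: layer_idx is the absolute layer index taken from the HF module, and the
# lookup above assumes kv_caches is indexed the same way; with pipeline
# parallelism the runner may pass only this rank's caches, in which case the
# index would need remapping.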
output = self_attn.forward(query, key, value, kv_cache, attn_metadata)
return output, None
def _standard_attention(query, key, value, attention_mask, scaling):
"""Standard scaled dot-product attention fallback."""
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
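# Note: this fallback assumes query/key/value are [batch, heads, seq, head_dim]
# and that key/value have as many heads as query (no grouped-query head
# repetition), so it is only intended as a last resort when no vLLM attention
# context is available.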
# Register vLLM attention to transformers
_vllm_attention_registered = False
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
_vllm_attention_registered = True
logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
logger.warning("Could not register vLLM attention: %s", e)
# ============================================================================
# Base Class with Pipeline Parallel and Tensor Parallel Support
# ============================================================================
class Base(nn.Module):
"""
Base class for Transformers backend models with full parallel support.
Features:
- Pipeline Parallel: PPMissingLayer for distributed layers
- Tensor Parallel: tp_plan based module replacement
- Meta device initialization
- Module replacement (Linear → vLLM Linear, RMSNorm → vLLM RMSNorm)
- VocabParallelEmbedding for input embeddings
- Attention instances for KV cache allocation
"""
# For vLLM's weight loader
embedding_modules = ["embed_tokens"]
# Weight name mapping following latest vLLM pattern
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints,
# handling the case where it is already present
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model` (pooling included because of adapters)
"model.lm_head.": "lm_head.",
"model.score.": "classifier.",
"model.classifier.": "classifier.",
}
)
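# Illustrative effect of the prefix rules above (applied in insertion order),
# assuming a typical decoder-only checkpoint:
#   "layers.0.self_attn.q_proj.weight"       -> "model.layers.0.self_attn.q_proj.weight"
#   "model.layers.0.self_attn.q_proj.weight" -> "model.layers.0.self_attn.q_proj.weight"
#   "lm_head.weight"                         -> "lm_head.weight"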
# Note: __init_subclass__ with WeightsMapper merging is not supported in v0.6.2
# because WeightsMapper doesn't implement __or__/__ior__ operators.
# Each Mixin should define its own hf_to_vllm_mapper if needed.
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
logger.info("Using Transformers modeling backend.")
# Store configuration
self.config = vllm_config.model_config.hf_config
self.text_config = getattr(self.config, "text_config", self.config)
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.device_config = vllm_config.device_config
self.parallel_config = vllm_config.parallel_config
self.quant_config = vllm_config.quant_config
self.prefix = prefix
# Parallel groups
self.pp_group = get_pp_group()
self.tp_group = get_tp_group()
# Model dimensions
self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
self.vocab_size = getattr(self.text_config, "vocab_size", 32000)
# Weight loading configuration
self.skip_prefixes: List[str] = []
self.ignore_unexpected_prefixes: List[str] = []
# Configure attention backend
self._configure_attention_backend()
# Create model on meta device
self._init_model_on_meta()
# Apply pipeline parallel
self._apply_pipeline_parallel()
# Replace modules (with tensor parallel support)
self._replace_modules()
# Fix attention head_dim in case config was incorrect
self._fix_attention_head_dim()
# Attention debug hook (currently a no-op; see _add_attention_debug_hook)
self._add_attention_debug_hook()
# Replace input embeddings
self._replace_input_embeddings()
# Create attention instances
self.attention_instances = self._create_attention_instances()
# Initialize parameters on target device
self._init_parameters()
# Pipeline parallel intermediate tensors
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], self.hidden_size
)
def _configure_attention_backend(self) -> None:
"""Configure vLLM attention backend."""
# Note: attention implementation is set in _init_model_on_meta
# This method is kept for potential platform-specific configuration
pass
def _init_model_on_meta(self) -> None:
"""Create model structure on meta device."""
from transformers import AutoModel
logger.info("Creating model structure on meta device...")
# Set attention implementation to vLLM's
self.text_config._attn_implementation = "vllm"
# Ensure head_dim is correctly set in BOTH config and text_config
# Transformers models use config.head_dim to compute attention dimensions
# Some models may have incorrect head_dim, so we compute and set it
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
# Check and fix head_dim in text_config
if hasattr(self.text_config, "head_dim"):
if self.text_config.head_dim != correct_head_dim:
logger.warning(
"Correcting head_dim in text_config: %d -> %d",
self.text_config.head_dim, correct_head_dim
)
self.text_config.head_dim = correct_head_dim
else:
self.text_config.head_dim = correct_head_dim
# Also set in self.config (which is passed to AutoModel.from_config)
if hasattr(self.config, "head_dim"):
if self.config.head_dim != correct_head_dim:
logger.warning(
"Correcting head_dim in config: %d -> %d",
self.config.head_dim, correct_head_dim
)
self.config.head_dim = correct_head_dim
else:
self.config.head_dim = correct_head_dim
# Some models also need _attn_implementation in config
self.config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"):
self.model: "PreTrainedModel" = AutoModel.from_config(
self.config,
torch_dtype=self.model_config.dtype,
trust_remote_code=self.model_config.trust_remote_code,
)
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def _apply_pipeline_parallel(self) -> None:
"""
Apply pipeline parallelization plan.
For models that don't explicitly support pp_plan, we do a best-effort
approach by splitting layers based on num_hidden_layers.
"""
if self.pp_group.world_size <= 1:
return
logger.info("Applying pipeline parallel (world_size=%d, rank=%d)",
self.pp_group.world_size, self.pp_group.rank_in_group)
num_layers = getattr(self.text_config, "num_hidden_layers",
getattr(self.text_config, "num_layers", 32))
start_layer, end_layer = get_pp_indices(
num_layers,
self.pp_group.rank_in_group,
self.pp_group.world_size,
)
# Find and process layer modules
layers_module = self._find_layers_module()
if layers_module is not None:
layers = list(layers_module.children())
for i, layer in enumerate(layers):
if not (start_layer <= i < end_layer):
# Replace layers not on this rank with PPMissingLayer
setattr(layers_module, str(i), PPMissingLayer())
# Embeddings live only on the first pipeline rank
self._has_embeddings = self.pp_group.is_first_rank
# Handle final norm and lm_head (only on last rank)
if not self.pp_group.is_last_rank:
# Mark lm_head as missing
if hasattr(self.model, 'lm_head'):
self.model.lm_head = PPMissingLayer()
logger.info("Pipeline parallel applied: layers %d-%d on this rank",
start_layer, end_layer)
def _find_layers_module(self) -> Optional[nn.Module]:
"""Find the ModuleList containing transformer layers
(e.g. `layers` for Llama-style models, `h` for GPT-2-style models)."""
# Common layer container names; dotted paths such as `encoder.layer` are
# reached through the recursion below rather than matched directly.
layer_container_names = ('layers', 'h', 'blocks', 'layer')
def _search_layers(module: nn.Module, prefix: str = "") -> Optional[nn.Module]:
for name, child in module.named_children():
if name in layer_container_names and isinstance(child, nn.ModuleList):
return child
# Recursively search in model backbone
if name in ['model', 'transformer', 'encoder', 'decoder']:
result = _search_layers(child, f"{prefix}.{name}" if prefix else name)
if result is not None:
return result
return None
return _search_layers(self.model)
def _get_tp_plan(self) -> Dict[str, str]:
"""
Get tensor parallel plan for module replacement.
This maps module name patterns to parallelization styles:
- "colwise": Column parallel (split output dim)
- "rowwise": Row parallel (split input dim)
- "replicate": Replicated (no split)
Returns a dict mapping regex patterns to styles.
"""
# Check if model has explicit tp_plan
if hasattr(self.model, 'tp_plan') and self.model.tp_plan:
return {maybe_prefix("model", k): v for k, v in self.model.tp_plan.items()}
# Default tp_plan for common LLM architectures
# Based on typical transformer structure
return {
r".*\.q_proj$": "colwise",
r".*\.k_proj$": "colwise",
r".*\.v_proj$": "colwise",
r".*\.o_proj$": "rowwise",
r".*\.gate_proj$": "colwise",
r".*\.up_proj$": "colwise",
r".*\.down_proj$": "rowwise",
r".*\.query$": "colwise",
r".*\.key$": "colwise",
r".*\.value$": "colwise",
r".*\.dense$": "rowwise",
r".*\.fc1$": "colwise",
r".*\.fc2$": "rowwise",
}
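# Example: with the default plan above, "model.layers.0.self_attn.q_proj"
# matches r".*\.q_proj$" and is replaced via replace_linear_class with
# style="colwise" (column-parallel, output dim split across TP ranks), while
# "model.layers.0.self_attn.o_proj" gets "rowwise" (input dim split).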
def _replace_modules(self) -> None:
"""
Replace modules with vLLM optimized versions.
Uses tp_plan for tensor parallel style selection.
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
"""
logger.info("Replacing modules with vLLM optimized versions...")
replaced_count = 0
# Get tensor parallel plan
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
# Modules to skip replacement (handled at wrapper level)
skip_modules = {"lm_head", "score", "classifier"}
def _recursive_replace(module: nn.Module, prefix: str = ""):
nonlocal replaced_count
for name, child in list(module.named_children()):
# Skip PPMissingLayer
if isinstance(child, PPMissingLayer):
continue
# Skip modules that are handled at wrapper level
if name in skip_modules:
logger.debug("Skipping %s (handled at wrapper level)", name)
continue
qual_name = maybe_prefix(prefix, name)
new_module = None
if isinstance(child, nn.Linear):
# Determine parallelization style from tp_plan
style = "replicate"
for pattern, plan_style in tp_plan.items():
if re.match(pattern, qual_name):
style = plan_style
break
new_module = replace_linear_class(
child,
style=style,
quant_config=self.quant_config,
prefix=qual_name,
)
replaced_count += 1
elif child.__class__.__name__.endswith("RMSNorm") and \
not isinstance(child, RMSNorm):
new_module = replace_rms_norm_class(child, self.hidden_size)
replaced_count += 1
if new_module is not None:
setattr(module, name, new_module)
log_replacement(qual_name, child, new_module)
else:
_recursive_replace(child, qual_name)
_recursive_replace(self.model, "model")
logger.info("Replaced %d modules", replaced_count)
def _add_attention_debug_hook(self) -> None:
"""No-op. Debug hooks removed after root cause identified."""
pass
def _fix_attention_head_dim(self) -> None:
"""
Fix head_dim in attention modules and rotary embeddings after model creation.
Some models may have incorrect head_dim in config, which causes
Transformers attention modules and RoPE to use wrong dimensions.
This method corrects head_dim in all attention modules and recreates
rotary embeddings if needed.
"""
correct_head_dim = self.hidden_size // getattr(
self.text_config, "num_attention_heads", 32
)
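# e.g. hidden_size=4096 with num_attention_heads=32 gives head_dim=128; a stale
# config value of 64 would otherwise leave the rotary inv_freq buffer at half
# the required length.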
fixed_count = 0
for name, module in self.model.named_modules():
module_name = module.__class__.__name__
# Fix head_dim in Attention modules
if "Attention" in module_name:
if hasattr(module, "head_dim"):
if module.head_dim != correct_head_dim:
logger.warning(
"Fixing head_dim in %s: %d -> %d",
name, module.head_dim, correct_head_dim
)
module.head_dim = correct_head_dim
fixed_count += 1
# Fix rotary embeddings - recreate inv_freq buffer if needed
if "RotaryEmbedding" in module_name:
if hasattr(module, "inv_freq"):
current_dim = module.inv_freq.shape[0] * 2
if current_dim != correct_head_dim:
logger.warning(
"Recreating rotary embedding %s: dim %d -> %d",
name, current_dim, correct_head_dim
)
# Guard against rotary-embedding variants that don't carry a config
rope_config = getattr(module, 'config', None)
base = getattr(rope_config, 'rope_theta', 10000.0)
if hasattr(rope_config, 'rope_parameters'):
base = rope_config.rope_parameters.get('rope_theta', base)
device = module.inv_freq.device
inv_freq = 1.0 / (
base ** (
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
.to(device=device, dtype=torch.float) / correct_head_dim
)
)
module.register_buffer("inv_freq", inv_freq, persistent=False)
if hasattr(module, "original_inv_freq"):
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
if fixed_count > 0:
logger.info("Fixed head_dim in %d attention modules", fixed_count)
def _replace_input_embeddings(self) -> None:
"""Replace input embeddings with VocabParallelEmbedding."""
# Default so forward()/embed_input_ids() can safely test
# `self.embed_scale is not None` even if we return early below.
self.embed_scale = None
input_embeddings = self.model.get_input_embeddings()
if input_embeddings is None or isinstance(input_embeddings, PPMissingLayer):
return
if hasattr(input_embeddings, "embedding_dim"):
embedding_dim = input_embeddings.embedding_dim
elif hasattr(input_embeddings, "weight"):
embedding_dim = input_embeddings.weight.shape[1]
else:
embedding_dim = self.hidden_size
self.embed_scale = getattr(input_embeddings, "embed_scale", None)
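# (HF "scaled word embedding" classes, e.g. BART-style models, expose
# embed_scale, typically sqrt(hidden_size); it is applied in embed_input_ids()
# and forward() below.)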
logger.info("Replacing input embeddings (vocab=%d, dim=%d)",
self.vocab_size, embedding_dim)
new_embeddings = VocabParallelEmbedding(
self.vocab_size,
embedding_dim,
org_num_embeddings=self.vocab_size,
quant_config=self.quant_config,
)
self.model.set_input_embeddings(new_embeddings)
def _create_attention_instances(self) -> Dict[int, Attention]:
"""Create Attention instances for KV cache allocation."""
num_layers = getattr(self.text_config, "num_hidden_layers",
getattr(self.text_config, "num_layers", 32))
num_heads = getattr(self.text_config, "num_attention_heads", 32)
head_size = self.hidden_size // num_heads
num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)
# Get PP layer range
pp_rank = self.pp_group.rank_in_group
pp_size = self.pp_group.world_size
start_layer, end_layer = get_pp_indices(num_layers, pp_rank, pp_size)
logger.info("Creating attention instances for layers %d-%d "
"(heads=%d, head_size=%d, kv_heads=%d)",
start_layer, end_layer, num_heads, head_size, num_kv_heads)
attention_instances: Dict[int, Attention] = {}
for layer_idx in range(start_layer, end_layer):
per_layer_sliding_window = None
if hasattr(self.config, "layer_types"):
layer_types = self.config.layer_types
if layer_idx < len(layer_types) and layer_types[layer_idx] == "sliding_attention":
per_layer_sliding_window = getattr(self.config, "sliding_window", None)
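# NOTE: per_layer_sliding_window is derived here but not currently forwarded to
# the Attention constructor below, so layer-specific sliding windows are
# effectively ignored and layers use whatever cache_config provides.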
attention = Attention(
num_heads=num_heads,
head_size=head_size,
scale=1.0 / (head_size ** 0.5),
num_kv_heads=num_kv_heads,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"model.layers.{layer_idx}.self_attn",
)
attention_instances[layer_idx] = attention
return attention_instances
def _init_parameters(self) -> None:
"""Initialize parameters from meta device to target device."""
device = self.device_config.device
if device is None:
device = torch.device("cpu")
dtype = self.model_config.dtype
def _init_params(module: nn.Module):
if isinstance(module, PPMissingLayer):
return
for name, param in list(module.named_parameters(recurse=False)):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data, dtype=dtype, device=device),
requires_grad=False,
)
setattr(module, name, new_param)
for child in module.children():
_init_params(child)
_init_params(self.model)
logger.info("Parameters initialized on %s", device)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Get embeddings for input IDs."""
inputs_embeds = self.model.get_input_embeddings()(input_ids)
if self.embed_scale is not None:
inputs_embeds = inputs_embeds * self.embed_scale
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Forward pass with pipeline parallel support."""
# Handle intermediate tensors for PP
if not self.pp_group.is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
set_attention_context(attn_metadata, kv_caches)
try:
# Prepare inputs
if inputs_embeds is not None:
if inputs_embeds.dim() == 2:
inputs_embeds = inputs_embeds.unsqueeze(0)
model_inputs = {"inputs_embeds": inputs_embeds}
else:
if input_ids is not None and input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
model_inputs = {"input_ids": input_ids}
if positions is not None:
if positions.dim() == 1:
positions = positions.unsqueeze(0)
model_inputs["position_ids"] = positions
# Apply embed_scale if needed
if (
self.embed_scale is not None
and input_ids is not None
and inputs_embeds is None
):
inputs_embeds = self.embed_input_ids(model_inputs["input_ids"])
model_inputs = {"inputs_embeds": inputs_embeds}
if positions is not None:
model_inputs["position_ids"] = positions
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
with torch.no_grad():
outputs = self.model(
**model_inputs,
use_cache=False,
return_dict=False,
attention_instances=self.attention_instances,
)
# Get hidden states from model output
# For models using return_dict=False, outputs is a tuple
# outputs[0] is usually the last hidden state
if isinstance(outputs, tuple):
hidden_states = outputs[0]
else:
hidden_states = outputs
# Remove batch dimension
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
hidden_states = hidden_states.squeeze(0)
# Return intermediate tensors for PP
if not self.pp_group.is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
finally:
clear_attention_context()
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Set[str]:
"""Load weights using AutoWeightsLoader with name mapping."""
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
logger.info("Loaded %d weight tensors", len(loaded))
return set(loaded)