From 332e5f71a6ee7ef185fb7da7c04d612782447f2d Mon Sep 17 00:00:00 2001 From: Chranos <826995883@qq.com> Date: Thu, 5 Feb 2026 18:02:59 +0800 Subject: [PATCH] testing dynamic register --- .../models/transformers/__init__.py | 105 ++- .../models/transformers/base.py | 600 ++++++++++++++++++ .../models/transformers/causal.py | 577 ++--------------- .../models/transformers/legacy.py | 118 ++++ .../models/transformers/pooling.py | 170 +++++ 5 files changed, 1035 insertions(+), 535 deletions(-) create mode 100644 vllm-v0.6.2/vllm/model_executor/models/transformers/base.py create mode 100644 vllm-v0.6.2/vllm/model_executor/models/transformers/legacy.py create mode 100644 vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py index 7b92369..2b22ef8 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/__init__.py @@ -2,23 +2,43 @@ # Copyright 2024 The vLLM team. """Wrapper around `transformers` models for vLLM v0.6.2. -This module provides an advanced Transformers modeling backend that wraps +This module provides the Transformers modeling backend that wraps any HuggingFace model with the vLLM interface, enabling support for custom models that define their implementation via `auto_map` in config.json. 
-Key optimizations and features: +Architecture (following latest vLLM patterns): +- Base: Core functionality (meta init, PP/TP support, module replacement, attention, weight loading) +- CausalMixin: Causal LM specific (lm_head, compute_logits, sample) +- EmbeddingMixin: Embedding/pooling specific (pooler, pooling) +- SequenceClassificationMixin: Classification specific (classifier, pooling) + +Composed model classes: +- TransformersForCausalLM = CausalMixin + Base +- TransformersForEmbedding = EmbeddingMixin + Base +- TransformersForSequenceClassification = SequenceClassificationMixin + Base + +Key optimizations: - Meta device initialization for memory efficiency -- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions -- VocabParallelEmbedding for input embeddings +- Pipeline Parallel support (PPMissingLayer) +- Tensor Parallel support (tp_plan based module replacement) +- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions - vLLM Attention instances for proper KV cache allocation - AutoWeightsLoader for efficient weight loading with name mapping -- vLLM attention backend integration (when supported by upgraded transformers) """ -from vllm.model_executor.models.transformers.causal import ( - TransformersForCausalLM, - is_backend_compatible, +from vllm.model_executor.models.transformers.base import ( + Base, + set_attention_context, + clear_attention_context, + get_attention_context, + vllm_flash_attention_forward, ) +from vllm.model_executor.models.transformers.causal import CausalMixin +from vllm.model_executor.models.transformers.pooling import ( + EmbeddingMixin, + SequenceClassificationMixin, +) +from vllm.model_executor.models.transformers.legacy import LegacyMixin from vllm.model_executor.models.transformers.utils import ( init_on_device_without_buffers, replace_linear_class, @@ -27,10 +47,77 @@ from vllm.model_executor.models.transformers.utils import ( maybe_prefix, ) + +# 
============================================================================ +# Composed Model Classes (Mixin + Base pattern) +# ============================================================================ + +class TransformersForCausalLM(CausalMixin, Base): + """ + Transformers backend wrapper for causal language models. + + Combines CausalMixin (lm_head, compute_logits, sample) with + Base (meta init, PP/TP support, module replacement, attention, weight loading). + + Supports any HuggingFace model with auto_map in config.json. + """ + pass + + +class TransformersForEmbedding(EmbeddingMixin, Base): + """ + Transformers backend wrapper for embedding/sentence similarity models. + + Combines EmbeddingMixin (pooler, pooling) with + Base (meta init, PP/TP support, module replacement, attention, weight loading). + + Supports embedding models like BERT, sentence-transformers, etc. + """ + pass + + +class TransformersForSequenceClassification(SequenceClassificationMixin, Base): + """ + Transformers backend wrapper for sequence classification models. + + Combines SequenceClassificationMixin (classifier, pooling) with + Base (meta init, PP/TP support, module replacement, attention, weight loading). + + Supports cross-encoders and classification models. + """ + pass + + +class TransformersForLegacy(LegacyMixin, EmbeddingMixin, Base): + """ + Transformers backend wrapper for legacy/encoder models. + + Combines LegacyMixin (BERT/RoBERTa weight mapping, position handling) with + EmbeddingMixin (pooler) and Base (core functionality). + + Supports BERT, RoBERTa, and similar encoder models. 
+ """ + pass + + __all__ = [ # Main wrapper classes "TransformersForCausalLM", - "is_backend_compatible", + "TransformersForEmbedding", + "TransformersForSequenceClassification", + "TransformersForLegacy", + # Base class for extension + "Base", + # Mixin classes for custom combinations + "CausalMixin", + "EmbeddingMixin", + "SequenceClassificationMixin", + "LegacyMixin", + # Attention context management + "set_attention_context", + "clear_attention_context", + "get_attention_context", + "vllm_flash_attention_forward", # Utility functions "init_on_device_without_buffers", "replace_linear_class", diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py new file mode 100644 index 0000000..64de039 --- /dev/null +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py @@ -0,0 +1,600 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The vLLM team. +"""Transformers modeling backend base class for v0.6.2. 
+ +This module provides the Base class following latest vLLM architecture: +- Meta device initialization for memory efficiency +- Pipeline parallel support (PPMissingLayer) +- Tensor parallel support (tp_plan based module replacement) +- Module replacement (Linear, RMSNorm) with vLLM optimized versions +- VocabParallelEmbedding for input embeddings +- Attention instances for KV cache allocation +- Weight loading with AutoWeightsLoader and WeightsMapper +""" + +import re +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group, get_tp_group +from vllm.distributed.utils import get_pp_indices +from vllm.logger import init_logger +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + PPMissingLayer, + WeightsMapper, + make_empty_intermediate_tensors_factory, +) +from vllm.attention.layer import Attention +from vllm.sequence import IntermediateTensors + +from .utils import ( + init_on_device_without_buffers, + replace_linear_class, + replace_rms_norm_class, + log_replacement, + maybe_prefix, +) + +if TYPE_CHECKING: + from transformers import PreTrainedModel + from vllm.attention import AttentionMetadata + +logger = init_logger(__name__) + + +# ============================================================================ +# Attention Context Management (for vLLM attention integration) +# ============================================================================ + +_current_attn_metadata = None +_current_kv_caches = None + + +def set_attention_context(attn_metadata, kv_caches): + """Set the current attention context for vLLM attention functions.""" + global _current_attn_metadata, _current_kv_caches + _current_attn_metadata = attn_metadata + _current_kv_caches = kv_caches + + +def clear_attention_context(): + """Clear the current 
attention context after forward pass.""" + global _current_attn_metadata, _current_kv_caches + _current_attn_metadata = None + _current_kv_caches = None + + +def get_attention_context(): + """Get the current attention context.""" + return _current_attn_metadata, _current_kv_caches + + +# ============================================================================ +# vLLM Attention Function for Transformers Integration +# ============================================================================ + +def vllm_flash_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + scaling: float = None, + attention_instances: Dict[int, Attention] = None, + **kwargs, +): + """ + vLLM's optimized attention function for transformers integration. + + In v0.6.2, Attention.forward signature is: + (query, key, value, kv_cache, attn_metadata) + """ + layer_idx = getattr(module, 'layer_idx', 0) + + if attention_instances is None or layer_idx not in attention_instances: + return _standard_attention(query, key, value, attention_mask, scaling) + + self_attn = attention_instances[layer_idx] + attn_metadata, kv_caches = get_attention_context() + + if attn_metadata is None or kv_caches is None: + return _standard_attention(query, key, value, attention_mask, scaling) + + if scaling is not None: + self_attn.impl.scale = float(scaling) + + # Reshape: [batch, heads, seq, head_dim] -> [seq, heads * head_dim] + hidden = query.shape[-2] + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) + + kv_cache = kv_caches[layer_idx] if layer_idx < len(kv_caches) else None + output = self_attn.forward(query, key, value, kv_cache, attn_metadata) + + return output, None + + +def _standard_attention(query, key, value, attention_mask, scaling): + """Standard scaled dot-product attention fallback.""" + attn_weights = 
# Register vLLM attention to transformers
_vllm_attention_registered = False
try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
    _vllm_attention_registered = True
    logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
    logger.warning("Could not register vLLM attention: %s", e)


# ============================================================================
# Base Class with Pipeline Parallel and Tensor Parallel Support
# ============================================================================

class Base(nn.Module):
    """
    Base class for Transformers backend models with full parallel support.

    Features:
    - Pipeline Parallel: PPMissingLayer for distributed layers
    - Tensor Parallel: tp_plan based module replacement
    - Meta device initialization
    - Module replacement (Linear -> vLLM Linear, RMSNorm -> vLLM RMSNorm)
    - VocabParallelEmbedding for input embeddings
    - Attention instances for KV cache allocation
    """

    # For vLLM's weight loader
    embedding_modules = ["embed_tokens"]

    # Weight name mapping following latest vLLM pattern
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Add `model.` prefix for base model checkpoints,
            # handling the case where it is already present
            "": "model.",
            "model.model.": "model.",
            # Heads will be adjacent to `model` (pooling included because of adapters)
            "model.lm_head.": "lm_head.",
            "model.score.": "classifier.",
            "model.classifier.": "classifier.",
        }
    )

    def __init_subclass__(cls, *args, **kwargs):
        """Merge hf_to_vllm_mapper across the MRO.

        NOTE(review): iteration follows ``cls.__mro__`` (most specific first);
        which side of ``|`` wins on key conflicts depends on WeightsMapper's
        union semantics — confirm against the WeightsMapper implementation.
        """
        super().__init_subclass__(*args, **kwargs)
        hf_to_vllm_mapper = WeightsMapper()
        for base in cls.__mro__:
            if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
                hf_to_vllm_mapper |= base_hf_to_vllm_mapper
        cls.hf_to_vllm_mapper = hf_to_vllm_mapper

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        logger.info("Using Transformers modeling backend.")

        # Store configuration
        self.config = vllm_config.model_config.hf_config
        self.text_config = getattr(self.config, "text_config", self.config)
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config
        self.prefix = prefix

        # Parallel groups
        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()

        # Model dimensions
        self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
        self.vocab_size = getattr(self.text_config, "vocab_size", 32000)

        # Weight loading configuration
        self.skip_prefixes: List[str] = []
        self.ignore_unexpected_prefixes: List[str] = []

        # FIX: default these attributes before the setup methods below.
        # Previously `embed_scale` was only assigned inside
        # _replace_input_embeddings (which early-returns on PPMissingLayer /
        # missing embeddings) while forward()/embed_input_ids() read it
        # unconditionally -> AttributeError on non-first PP ranks. Likewise
        # `_has_embeddings` was unset when pp world_size <= 1.
        self.embed_scale = None
        self._has_embeddings = True

        # Configure attention backend
        self._configure_attention_backend()

        # Create model on meta device
        self._init_model_on_meta()

        # Apply pipeline parallel
        self._apply_pipeline_parallel()

        # Replace modules (with tensor parallel support)
        self._replace_modules()

        # Replace input embeddings
        self._replace_input_embeddings()

        # Create attention instances
        self.attention_instances = self._create_attention_instances()

        # Initialize parameters on target device
        self._init_parameters()

        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.hidden_size
        )

    def _configure_attention_backend(self) -> None:
        """Configure vLLM attention backend.

        Note: the attention implementation is actually selected in
        _init_model_on_meta; this hook exists for platform-specific overrides.
        """
        pass

    def _init_model_on_meta(self) -> None:
        """Create the HF model structure on the meta device (no memory)."""
        from transformers import AutoModel

        logger.info("Creating model structure on meta device...")

        # Route attention through the registered "vllm" implementation
        self.text_config._attn_implementation = "vllm"

        with init_on_device_without_buffers("meta"):
            self.model: "PreTrainedModel" = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )

        # Inference only: freeze everything
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def _apply_pipeline_parallel(self) -> None:
        """
        Apply pipeline parallelization plan.

        For models that don't explicitly support pp_plan, we do a best-effort
        approach by splitting layers based on num_hidden_layers.
        """
        if self.pp_group.world_size <= 1:
            return

        logger.info("Applying pipeline parallel (world_size=%d, rank=%d)",
                    self.pp_group.world_size, self.pp_group.rank_in_group)

        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))

        start_layer, end_layer = get_pp_indices(
            num_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )

        # Find and process layer modules
        layers_module = self._find_layers_module()
        if layers_module is not None:
            layers = list(layers_module.children())
            for i, layer in enumerate(layers):
                if not (start_layer <= i < end_layer):
                    # Replace layers not on this rank with PPMissingLayer
                    setattr(layers_module, str(i), PPMissingLayer())

        # FIX: embeddings live only on the first rank. Previously the flag
        # was left unset when get_input_embeddings() returned None on a
        # non-first rank; now it is assigned unconditionally.
        self._has_embeddings = self.pp_group.is_first_rank

        # Handle final norm and lm_head (only on last rank)
        if not self.pp_group.is_last_rank:
            # Mark lm_head as missing
            if hasattr(self.model, 'lm_head'):
                self.model.lm_head = PPMissingLayer()

        logger.info("Pipeline parallel applied: layers %d-%d on this rank",
                    start_layer, end_layer)

    def _find_layers_module(self) -> Optional[nn.Module]:
        """Find the ModuleList containing transformer layers.

        Searches common container names ('layers', 'h', 'blocks', 'layer')
        under common backbone names ('model', 'transformer', 'encoder',
        'decoder'). Returns None if nothing matches.
        """
        def _search_layers(module: nn.Module, prefix: str = "") -> Optional[nn.Module]:
            for name, child in module.named_children():
                if name in ['layers', 'h', 'blocks', 'layer'] and isinstance(child, nn.ModuleList):
                    return child
                # Recursively search in model backbone
                if name in ['model', 'transformer', 'encoder', 'decoder']:
                    result = _search_layers(child, f"{prefix}.{name}" if prefix else name)
                    if result is not None:
                        return result
            return None

        return _search_layers(self.model)

    def _get_tp_plan(self) -> Dict[str, str]:
        """
        Get tensor parallel plan for module replacement.

        This maps module name patterns to parallelization styles:
        - "colwise": Column parallel (split output dim)
        - "rowwise": Row parallel (split input dim)
        - "replicate": Replicated (no split)

        Returns a dict mapping regex patterns to styles.
        """
        # Check if model has explicit tp_plan
        if hasattr(self.model, 'tp_plan') and self.model.tp_plan:
            return {maybe_prefix("model", k): v for k, v in self.model.tp_plan.items()}

        # Default tp_plan for common LLM architectures
        # Based on typical transformer structure
        return {
            r".*\.q_proj$": "colwise",
            r".*\.k_proj$": "colwise",
            r".*\.v_proj$": "colwise",
            r".*\.o_proj$": "rowwise",
            r".*\.gate_proj$": "colwise",
            r".*\.up_proj$": "colwise",
            r".*\.down_proj$": "rowwise",
            r".*\.query$": "colwise",
            r".*\.key$": "colwise",
            r".*\.value$": "colwise",
            r".*\.dense$": "rowwise",
            r".*\.fc1$": "colwise",
            r".*\.fc2$": "rowwise",
        }

    def _replace_modules(self) -> None:
        """
        Replace modules with vLLM optimized versions.

        Uses tp_plan for tensor parallel style selection.
        """
        logger.info("Replacing modules with vLLM optimized versions...")
        replaced_count = 0

        # Get tensor parallel plan (only needed when TP is actually on)
        tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}

        def _recursive_replace(module: nn.Module, prefix: str = ""):
            nonlocal replaced_count

            for name, child in list(module.named_children()):
                # Skip PPMissingLayer
                if isinstance(child, PPMissingLayer):
                    continue

                qual_name = maybe_prefix(prefix, name)
                new_module = None

                if isinstance(child, nn.Linear):
                    # Determine parallelization style from tp_plan
                    style = "replicate"
                    for pattern, plan_style in tp_plan.items():
                        if re.match(pattern, qual_name):
                            style = plan_style
                            break

                    new_module = replace_linear_class(
                        child,
                        style=style,
                        quant_config=self.quant_config,
                        prefix=qual_name,
                    )
                    replaced_count += 1

                elif child.__class__.__name__.endswith("RMSNorm"):
                    new_module = replace_rms_norm_class(child, self.hidden_size)
                    replaced_count += 1

                if new_module is not None:
                    setattr(module, name, new_module)
                    log_replacement(qual_name, child, new_module)
                else:
                    _recursive_replace(child, qual_name)

        _recursive_replace(self.model, "model")
        logger.info("Replaced %d modules", replaced_count)

    def _replace_input_embeddings(self) -> None:
        """Replace input embeddings with VocabParallelEmbedding."""
        input_embeddings = self.model.get_input_embeddings()
        if input_embeddings is None or isinstance(input_embeddings, PPMissingLayer):
            # embed_scale keeps its __init__ default (None) on this rank
            return

        if hasattr(input_embeddings, "embedding_dim"):
            embedding_dim = input_embeddings.embedding_dim
        elif hasattr(input_embeddings, "weight"):
            embedding_dim = input_embeddings.weight.shape[1]
        else:
            embedding_dim = self.hidden_size

        # Some HF models (e.g. Gemma-style) scale embeddings; preserve that.
        self.embed_scale = getattr(input_embeddings, "embed_scale", None)

        logger.info("Replacing input embeddings (vocab=%d, dim=%d)",
                    self.vocab_size, embedding_dim)

        new_embeddings = VocabParallelEmbedding(
            self.vocab_size,
            embedding_dim,
            org_num_embeddings=self.vocab_size,
            quant_config=self.quant_config,
        )
        self.model.set_input_embeddings(new_embeddings)

    def _create_attention_instances(self) -> Dict[int, Attention]:
        """Create Attention instances for KV cache allocation."""
        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))
        num_heads = getattr(self.text_config, "num_attention_heads", 32)
        head_size = self.hidden_size // num_heads
        num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)

        # Get PP layer range
        pp_rank = self.pp_group.rank_in_group
        pp_size = self.pp_group.world_size
        start_layer, end_layer = get_pp_indices(num_layers, pp_rank, pp_size)

        logger.info("Creating attention instances for layers %d-%d "
                    "(heads=%d, head_size=%d, kv_heads=%d)",
                    start_layer, end_layer, num_heads, head_size, num_kv_heads)

        attention_instances: Dict[int, Attention] = {}
        for layer_idx in range(start_layer, end_layer):
            per_layer_sliding_window = None
            if hasattr(self.config, "layer_types"):
                layer_types = self.config.layer_types
                if layer_idx < len(layer_types) and layer_types[layer_idx] == "sliding_attention":
                    per_layer_sliding_window = getattr(self.config, "sliding_window", None)

            attention = Attention(
                num_heads=num_heads,
                head_size=head_size,
                scale=1.0 / (head_size ** 0.5),
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=f"model.layers.{layer_idx}.self_attn",
            )
            attention_instances[layer_idx] = attention

        return attention_instances

    def _init_parameters(self) -> None:
        """Materialize meta-device parameters on the target device."""
        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype

        def _init_params(module: nn.Module):
            if isinstance(module, PPMissingLayer):
                return
            for name, param in list(module.named_parameters(recurse=False)):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(param.data, dtype=dtype, device=device),
                        requires_grad=False,
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_params(child)

        _init_params(self.model)
        logger.info("Parameters initialized on %s", device)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Get embeddings for input IDs, applying embed_scale when present."""
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.embed_scale is not None:
            inputs_embeds = inputs_embeds * self.embed_scale
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
        # FIX: annotation previously claimed torch.Tensor only, but
        # non-last PP ranks return IntermediateTensors (quoted to avoid
        # importing Union at def time).
    ) -> "Union[torch.Tensor, IntermediateTensors]":
        """Forward pass with pipeline parallel support."""
        # Handle intermediate tensors for PP
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]

        set_attention_context(attn_metadata, kv_caches)

        try:
            # Prepare inputs (HF models expect a batch dimension)
            if inputs_embeds is not None:
                if inputs_embeds.dim() == 2:
                    inputs_embeds = inputs_embeds.unsqueeze(0)
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                if input_ids is not None and input_ids.dim() == 1:
                    input_ids = input_ids.unsqueeze(0)
                model_inputs = {"input_ids": input_ids}

            if positions is not None:
                if positions.dim() == 1:
                    positions = positions.unsqueeze(0)
                model_inputs["position_ids"] = positions

            # Apply embed_scale if needed
            if (
                self.embed_scale is not None
                and input_ids is not None
                and inputs_embeds is None
            ):
                inputs_embeds = self.embed_input_ids(model_inputs["input_ids"])
                model_inputs = {"inputs_embeds": inputs_embeds}
                if positions is not None:
                    model_inputs["position_ids"] = positions

            # Forward through model
            with torch.no_grad():
                outputs = self.model(
                    **model_inputs,
                    use_cache=False,
                    return_dict=True,
                    output_hidden_states=True,
                    attention_instances=self.attention_instances,
                )

            # Get hidden states
            if outputs.hidden_states is not None:
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = outputs.logits

            # Remove batch dimension
            if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
                hidden_states = hidden_states.squeeze(0)

            # Return intermediate tensors for PP
            if not self.pp_group.is_last_rank:
                return IntermediateTensors({"hidden_states": hidden_states})

            return hidden_states

        finally:
            clear_attention_context()

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load weights using AutoWeightsLoader with name mapping."""
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
        )
        loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        logger.info("Loaded %d weight tensors", len(loaded))
        return set(loaded)
-Key optimizations: -1. Use meta device for delayed memory allocation -2. Replace nn.Linear with vLLM's optimized Linear classes -3. Replace RMSNorm with vLLM's fused RMSNorm -4. Replace input embeddings with VocabParallelEmbedding -5. Use vLLM's weight loading infrastructure (AutoWeightsLoader) +Following latest vLLM architecture: +- TransformersForCausalLM = CausalMixin + Base """ -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Optional import torch -import torch.nn as nn -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, -) -from vllm.attention.layer import Attention from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, - WeightsMapper, -) -from vllm.model_executor.models.transformers.utils import ( - init_on_device_without_buffers, - replace_linear_class, - replace_rms_norm_class, - log_replacement, - maybe_prefix, -) -from vllm.sequence import IntermediateTensors if TYPE_CHECKING: - from transformers import PreTrainedModel - from vllm.attention import AttentionMetadata + from vllm.config import VllmConfig logger = init_logger(__name__) -# Note: In v0.6.2, the vLLM Attention.forward requires (query, key, value, kv_cache, attn_metadata). -# The transformers backend integration works differently than in latest vLLM. -# We keep the vllm_flash_attention_forward for reference, but it may not be compatible -# with all transformers versions or MLU backends. 
- -# Global variable to store current attention metadata (set during forward pass) -_current_attn_metadata = None -_current_kv_caches = None - - -def set_attention_context(attn_metadata, kv_caches): - """Set the current attention context for vLLM attention functions.""" - global _current_attn_metadata, _current_kv_caches - _current_attn_metadata = attn_metadata - _current_kv_caches = kv_caches - - -def clear_attention_context(): - """Clear the current attention context.""" - global _current_attn_metadata, _current_kv_caches - _current_attn_metadata = None - _current_kv_caches = None - - -def vllm_flash_attention_forward( - # Transformers args - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor, - # Transformers kwargs - scaling: float = None, - # vLLM kwargs - attention_instances: Dict[int, Attention] = None, - **kwargs, -): +class CausalMixin: """ - vLLM's optimized attention function that replaces HuggingFace's attention. - This function is registered to transformers' ALL_ATTENTION_FUNCTIONS. + Mixin class that adds causal language model functionality. - Note: In v0.6.2, this function may have limited functionality due to - API differences in the Attention layer. For full functionality, the model - should fall back to HuggingFace's native attention when vLLM attention - is not properly configured. 
- """ - # Get the attention instance for this layer - layer_idx = getattr(module, 'layer_idx', 0) + This mixin provides: + - LogitsProcessor for logits computation + - Sampler for token sampling + - compute_logits method for VllmModelForTextGeneration protocol + - sample method for VllmModelForTextGeneration protocol - if attention_instances is None or layer_idx not in attention_instances: - # Fall back to standard attention computation - logger.debug("No attention instance for layer %d, using standard attention", layer_idx) - # Standard scaled dot-product attention - attn_weights = torch.matmul(query, key.transpose(-2, -1)) - if scaling is not None: - attn_weights = attn_weights * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value) - return attn_output, None - - self_attn = attention_instances[layer_idx] - - # v0.6.2 Attention.forward requires: (query, key, value, kv_cache, attn_metadata) - # We need to get these from the global context - global _current_attn_metadata, _current_kv_caches - - if _current_attn_metadata is None or _current_kv_caches is None: - # No context set, fall back to standard attention - logger.debug("No attention context, using standard attention for layer %d", layer_idx) - attn_weights = torch.matmul(query, key.transpose(-2, -1)) - if scaling is not None: - attn_weights = attn_weights * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) - attn_output = torch.matmul(attn_weights, value) - return attn_output, None - - # Update scale if provided - if scaling is not None: - self_attn.impl.scale = float(scaling) - - # Reshape tensors for vLLM: [batch, heads, seq, head_dim] -> [seq, heads * head_dim] - hidden = query.shape[-2] - query, key, value = (x.transpose(1, 2) for x in (query, key, 
value)) - query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) - - # Get KV cache for this layer - kv_cache = _current_kv_caches[layer_idx] if layer_idx < len(_current_kv_caches) else None - - # Call vLLM attention - output = self_attn.forward(query, key, value, kv_cache, _current_attn_metadata) - - return output, None - - -# Try to register vLLM attention to transformers -_vllm_attention_registered = False -try: - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward - _vllm_attention_registered = True - logger.info("Registered vLLM attention function to transformers") -except (ImportError, AttributeError) as e: - logger.warning("Could not register vLLM attention function - " - "transformers version may not support custom attention: %s", e) - - -class TransformersForCausalLM(nn.Module): - """ - A wrapper class that adapts any HuggingFace causal language model - to the vLLM interface with memory optimizations. - - Key optimizations (following latest vLLM): - 1. Meta device initialization - no GPU memory until weights are loaded - 2. Module replacement - Linear/RMSNorm replaced with vLLM optimized versions - 3. VocabParallelEmbedding for input embeddings - 4. AutoWeightsLoader for efficient weight loading with name mapping - - Interface compliance: - - Implements VllmModel protocol (vllm_config init, forward with required args) - - Implements VllmModelForTextGeneration protocol (compute_logits, sample) + Should be used with Base class: + class TransformersForCausalLM(CausalMixin, Base): ... 
""" - # Weight name mapping from HuggingFace to vLLM format - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # Add `model.` prefix for base model checkpoints - "": "model.", - "model.model.": "model.", - # Heads will be adjacent to `model` - "model.lm_head.": "lm_head.", - "lm_head.": "lm_head.", - # Handle different model architectures - "transformer.": "model.", - "model.transformer.": "model.", - # Handle embeddings - "embed_tokens.": "model.embed_tokens.", - "model.embed_tokens.": "model.embed_tokens.", - # Handle attention weights - "self_attn.": "self_attn.", - "attention.": "self_attn.", - } - ) - - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: - super().__init__() + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None: + # Call next class in MRO (should be Base) + super().__init__(vllm_config=vllm_config, prefix=prefix) - config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config + # Handle tied word embeddings - skip loading lm_head weights + tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) + if tie_word_embeddings: + self.skip_prefixes.append("lm_head.") + logger.info("Model has tied word embeddings, skipping lm_head weight loading") - self.config = config.hf_config - self.text_config = getattr(self.config, "text_config", self.config) - self.model_config = config - self.cache_config = cache_config - self.device_config = vllm_config.device_config - self.quant_config = quant_config - self.prefix = prefix - - # Get model dimensions from config - self.hidden_size = getattr(self.text_config, "hidden_size", 4096) - self.vocab_size = getattr(self.text_config, "vocab_size", 32000) - - # Weight loading configuration - self.skip_prefixes: List[str] = [] - self.ignore_unexpected_prefixes: List[str] = [ - "model.layers.*.self_attn.rotary_emb.inv_freq", # Skip RoPE weights - "model.norm.bias", # Some models 
don't have bias in final norm - ] - - logger.info("Using Transformers modeling backend for %s", - config.hf_config.architectures) - - # Configure vLLM attention backend - self._configure_attention_backend() - - # Load the HuggingFace model structure on meta device (no memory allocation) - self._load_hf_model_on_meta() - - # Replace modules with vLLM optimized versions - self._replace_modules() - - # Replace input embeddings with VocabParallelEmbedding - self._replace_input_embeddings() - - # Create attention instances for KV cache - self.attention_instances = self._create_attention_instances() - - # Initialize parameters (allocate memory on target device) - self._init_parameters() - - # Setup logits processor and sampler + # Setup logits processor + logit_scale = getattr(self.text_config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor( self.vocab_size, logits_as_input=False, + scale=logit_scale, ) + + # Setup sampler self.sampler = Sampler() - - def _load_hf_model_on_meta(self) -> None: - """ - Load the HuggingFace model structure on meta device. - This creates the model structure without allocating GPU memory. - Memory will be allocated later during weight loading. - """ - from transformers import AutoModelForCausalLM - - logger.info("Creating model structure on meta device...") - - # Create model on meta device - no GPU memory allocated - with init_on_device_without_buffers("meta"): - self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config( - self.config, - torch_dtype=self.model_config.dtype, - trust_remote_code=self.model_config.trust_remote_code, - ) - - # Disable gradient computation for inference - self.model.eval() - for param in self.model.parameters(): - param.requires_grad = False - - logger.info("Model structure created on meta device") - - def _replace_modules(self) -> None: - """ - Replace HuggingFace modules with vLLM optimized versions. 
- - This replaces: - - nn.Linear with ReplicatedLinear (memory efficient, supports quantization) - - RMSNorm variants with vLLM's fused RMSNorm - """ - logger.info("Replacing modules with vLLM optimized versions...") - replaced_count = 0 - - def _recursive_replace(module: nn.Module, prefix: str = ""): - nonlocal replaced_count - for name, child in list(module.named_children()): - qual_name = maybe_prefix(prefix, name) - - if isinstance(child, nn.Linear): - # Replace Linear with vLLM's ReplicatedLinear - new_module = replace_linear_class( - child, - style="replicate", - quant_config=self.quant_config, - prefix=qual_name, - ) - setattr(module, name, new_module) - log_replacement(qual_name, child, new_module) - replaced_count += 1 - - elif child.__class__.__name__.endswith("RMSNorm"): - # Replace RMSNorm with vLLM's optimized version - new_module = replace_rms_norm_class(child, self.hidden_size) - setattr(module, name, new_module) - log_replacement(qual_name, child, new_module) - replaced_count += 1 - - elif child.__class__.__name__.endswith(("LayerNorm", "GroupNorm")): - # Also handle other normalization layers - logger.debug("Found normalization layer %s: %s", qual_name, type(child).__name__) - # Could add specialized replacement here if needed - _recursive_replace(child, qual_name) - - elif "Attention" in child.__class__.__name__: - # Mark attention layers for potential replacement - logger.debug("Found attention layer %s: %s", qual_name, type(child).__name__) - # Note: We don't replace the attention module itself, - # but we create separate vLLM Attention instances - _recursive_replace(child, qual_name) - - else: - # Recursively process children - _recursive_replace(child, qual_name) - - _recursive_replace(self.model, "model") - logger.info("Replaced %d modules with vLLM optimized versions", replaced_count) - - def _replace_input_embeddings(self) -> None: - """ - Replace the input embeddings with VocabParallelEmbedding. 
- - This provides memory efficiency for large vocabularies. - """ - input_embeddings = self.model.get_input_embeddings() - if input_embeddings is None: - logger.warning("Could not find input embeddings to replace") - return - - # Get embedding dimension - if hasattr(input_embeddings, "embedding_dim"): - embedding_dim = input_embeddings.embedding_dim - elif hasattr(input_embeddings, "weight"): - embedding_dim = input_embeddings.weight.shape[1] - else: - embedding_dim = self.hidden_size - - # Store embed_scale if present (some models scale embeddings) - self.embed_scale = getattr(input_embeddings, "embed_scale", None) - - logger.info("Replacing input embeddings with VocabParallelEmbedding " - "(vocab_size=%d, embedding_dim=%d)", self.vocab_size, embedding_dim) - - new_embeddings = VocabParallelEmbedding( - self.vocab_size, - embedding_dim, - org_num_embeddings=self.vocab_size, - quant_config=self.quant_config, - ) - self.model.set_input_embeddings(new_embeddings) - - def _init_parameters(self) -> None: - """ - Initialize parameters from meta device to target device. - - This allocates the actual GPU memory for all parameters. - """ - logger.info("Initializing parameters on target device...") - - # Use device_config to get the correct device (supports MLU, CUDA, etc.) 
- device = self.device_config.device - if device is None: - # Fallback for TPU or other special cases - device = torch.device("cpu") - dtype = self.model_config.dtype - - def _init_params(module: nn.Module): - for name, param in list(module.named_parameters(recurse=False)): - if param.device == torch.device("meta"): - new_param = nn.Parameter( - torch.empty_like( - param.data, - dtype=dtype, - device=device, - ), - requires_grad=False, - ) - setattr(module, name, new_param) - - for child in module.children(): - _init_params(child) - - _init_params(self.model) - logger.info("Parameters initialized on %s", device) - - def _create_attention_instances(self) -> Dict[int, Attention]: - """ - Create vLLM Attention instances for each layer. - - This enables proper KV cache allocation and vLLM's optimized attention. - Returns a dict mapping layer_idx to Attention instance. - """ - attention_instances: Dict[int, Attention] = {} - num_layers = getattr(self.text_config, "num_hidden_layers", - getattr(self.text_config, "num_layers", 32)) - - num_heads = getattr(self.text_config, "num_attention_heads", 32) - head_size = self.hidden_size // num_heads - scale = 1.0 / (head_size ** 0.5) - num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads) - - logger.info("Creating %d attention instances for KV cache " - "(num_heads=%d, head_size=%d, num_kv_heads=%d)", - num_layers, num_heads, head_size, num_kv_heads) - - for layer_idx in range(num_layers): - attention = Attention( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - cache_config=self.cache_config, - quant_config=self.quant_config, - prefix=f"model.layers.{layer_idx}.self_attn", - ) - attention_instances[layer_idx] = attention - - return attention_instances - - def _configure_attention_backend(self) -> None: - """ - Configure vLLM attention backend for the model. - - This sets up the attention implementation BEFORE model creation. 
- Only sets 'vllm' implementation if the attention function was registered. - """ - global _vllm_attention_registered - - if _vllm_attention_registered: - # Set vLLM attention implementation in config (must be before from_config) - self.text_config._attn_implementation = "vllm" - logger.info("Set attention implementation to 'vllm' in text_config") - else: - # Use default eager attention if vLLM attention is not available - logger.info("Using default HuggingFace attention (vLLM attention not registered)") - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - Forward pass through the model. - - This method conforms to the VllmModel protocol by accepting: - - input_ids: Token IDs - - positions: Position IDs - - kv_caches: KV cache tensors - - attn_metadata: Attention metadata - """ - # Set attention context for vLLM attention function - set_attention_context(attn_metadata, kv_caches) - - try: - # Prepare inputs - add batch dimension if needed - if inputs_embeds is not None: - if inputs_embeds.dim() == 2: - inputs_embeds = inputs_embeds.unsqueeze(0) - model_inputs = {"inputs_embeds": inputs_embeds} - else: - if input_ids.dim() == 1: - input_ids = input_ids.unsqueeze(0) - model_inputs = {"input_ids": input_ids} - - # Position IDs - add batch dimension - if positions is not None: - if positions.dim() == 1: - positions = positions.unsqueeze(0) - model_inputs["position_ids"] = positions - - # Apply embed_scale if needed - if ( - self.embed_scale is not None - and "input_ids" in model_inputs - and "inputs_embeds" not in model_inputs - ): - inputs_embeds = self.model.get_input_embeddings()(model_inputs["input_ids"]) - inputs_embeds = inputs_embeds * self.embed_scale - model_inputs = {"inputs_embeds": inputs_embeds} - if positions is 
not None: - model_inputs["position_ids"] = positions - - # Run the model with vLLM attention instances - with torch.no_grad(): - outputs = self.model( - **model_inputs, - use_cache=False, - return_dict=True, - output_hidden_states=True, - attention_instances=self.attention_instances, - ) - - # Get hidden states from the last layer - if outputs.hidden_states is not None: - hidden_states = outputs.hidden_states[-1] - else: - # Fallback: use logits directly - hidden_states = outputs.logits - - # Remove batch dimension - if hidden_states.dim() == 3 and hidden_states.size(0) == 1: - hidden_states = hidden_states.squeeze(0) - - return hidden_states - finally: - # Clear attention context - clear_attention_context() + logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)", + self.vocab_size, logit_scale) def compute_logits( self, @@ -538,24 +71,35 @@ class TransformersForCausalLM(nn.Module): Compute logits from hidden states. This method conforms to the VllmModelForTextGeneration protocol. 
+ + Args: + hidden_states: Hidden states from the model [seq_len, hidden_size] + sampling_metadata: Sampling metadata + + Returns: + Logits tensor or None """ - # Check if hidden_states are already logits + # Check if hidden_states are already logits (some models output logits directly) if hidden_states.shape[-1] == self.vocab_size: logits = hidden_states else: # Apply the LM head lm_head = getattr(self.model, "lm_head", None) if lm_head is None: + # Some models use different names lm_head = getattr(self.model, "embed_out", None) + if lm_head is None: + lm_head = getattr(self.model, "output", None) if lm_head is not None: output = lm_head(hidden_states) - # Handle tuple output from vLLM Linear layers + # Handle tuple output from vLLM Linear layers (output, bias) if isinstance(output, tuple): logits = output[0] else: logits = output else: + logger.warning("Could not find lm_head, using hidden_states as logits") logits = hidden_states return self.logits_processor(None, logits, sampling_metadata) @@ -569,32 +113,13 @@ class TransformersForCausalLM(nn.Module): Sample tokens from logits. This method conforms to the VllmModelForTextGeneration protocol. + + Args: + logits: Logits tensor + sampling_metadata: Sampling metadata + + Returns: + SamplerOutput with sampled tokens """ next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - - def load_weights( - self, - weights: Iterable[Tuple[str, torch.Tensor]], - ) -> Set[str]: - """ - Load weights into the model using AutoWeightsLoader. - - This uses vLLM's efficient weight loading infrastructure with - automatic name mapping. 
- """ - loader = AutoWeightsLoader( - self, - skip_prefixes=self.skip_prefixes, - ignore_unexpected_prefixes=self.ignore_unexpected_prefixes, - ) - loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - logger.info("Loaded %d weight tensors", len(loaded)) - return set(loaded) - - -def is_backend_compatible() -> bool: - """ - Check if the current model is compatible with the Transformers backend. - """ - return True diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/legacy.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/legacy.py new file mode 100644 index 0000000..34d6dbb --- /dev/null +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/legacy.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The vLLM team. +"""Transformers modeling backend mixin for legacy models. + +This module provides LegacyMixin for BERT-like encoder models that have +different weight naming conventions and special position handling. + +Following latest vLLM architecture patterns adapted for v0.6.2. +""" + +from typing import TYPE_CHECKING, List, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.models.utils import WeightsMapper +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class LegacyMixin: + """ + Mixin class for legacy/encoder models like BERT, RoBERTa. + + This mixin provides: + - Weight name mapping for legacy suffix conventions (.gamma/.beta) + - Prefix mapping for BERT-like model structures + - RoBERTa-specific position handling + - Skip prefixes for unsupported output layers + + Should be used with Base class: + class TransformersForLegacy(LegacyMixin, Base): ... + """ + + # Weight name mapping for legacy models + hf_to_vllm_mapper = WeightsMapper( + # These are applied in order, so the order matters! 
+ orig_to_new_prefix={ + # Handle BERT-like models + "roberta": "model", + "bert": "model", + }, + orig_to_new_suffix={ + # Replace legacy suffixes used for norms + ".gamma": ".weight", + ".beta": ".bias", + }, + ) + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None: + # Call next class in MRO (should be Base) + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Skip unsupported/unwanted output embeddings layers + self.skip_prefixes.extend([ + "model.lm_head.", + "model.predictions.", + "model.qa_outputs.", + "model.embeddings_project.", + "model.discriminator_predictions.", + ]) + + # v0.6.2 doesn't have skip_substrs, so we handle it differently + # Store patterns to skip during weight loading + self._legacy_skip_patterns: List[str] = [ + "position_ids", # Some encoder models have position_ids buffer + "score.bias", # Final classifier bias not used by vLLM + ] + + # RoBERTa-like models have extra padding in positions + model_type = getattr(self.text_config, "model_type", "").lower() + self.is_roberta = "roberta" in model_type + self.padding_idx = getattr(self.text_config, "pad_token_id", 1) + + if self.is_roberta: + logger.info("LegacyMixin detected RoBERTa model, enabling position padding") + + logger.info("LegacyMixin initialized for legacy/encoder model") + + def _should_skip_weight(self, name: str) -> bool: + """Check if a weight should be skipped during loading.""" + for pattern in self._legacy_skip_patterns: + if pattern in name: + return True + return False + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with RoBERTa position handling. + + RoBERTa models require positions to be offset by padding_idx + 1. 
+ """ + if self.is_roberta and positions is not None: + # RoBERTa-specific positions padding + positions = positions + self.padding_idx + 1 + + return super().forward( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py new file mode 100644 index 0000000..6b34a03 --- /dev/null +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/pooling.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2024 The vLLM team. +"""Transformers modeling backend mixins for pooling/embedding models. + +This module provides mixins for embedding and sequence classification models: +- EmbeddingMixin: For embedding/sentence similarity models +- SequenceClassificationMixin: For sequence classification/cross-encoding + +Following latest vLLM architecture patterns adapted for v0.6.2. +""" + +from typing import TYPE_CHECKING, List, Optional + +import torch +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import PoolerOutput + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class EmbeddingMixin: + """ + Mixin class that adds embedding/pooling functionality. + + This mixin provides: + - Pooler layer for extracting embeddings + - pooling method for VllmModelForPooling protocol + + Should be used with Base class: + class TransformersForEmbedding(EmbeddingMixin, Base): ... 
+ """ + + # Default pooling configuration + default_pooling_type: PoolingType = PoolingType.CLS + default_normalize: bool = True + default_softmax: bool = False + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None: + # Call next class in MRO (should be Base) + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Get pooler config from model config + pooler_config = vllm_config.model_config.pooler_config + + # Setup pooler + self.pooler = Pooler.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=self.default_pooling_type, + normalize=self.default_normalize, + softmax=self.default_softmax, + ) + + if self.pooler is None: + # Create default pooler if config doesn't specify + self.pooler = Pooler( + pooling_type=self.default_pooling_type, + normalize=self.default_normalize, + softmax=self.default_softmax, + ) + + logger.info("EmbeddingMixin initialized (pooling_type=%s, normalize=%s)", + self.pooler.pooling_type.name, self.pooler.normalize) + + def pooling( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + """ + Apply pooling to hidden states. + + Args: + hidden_states: Hidden states from the model [seq_len, hidden_size] + pooling_metadata: Pooling metadata + + Returns: + PoolerOutput with pooled embeddings + """ + return self.pooler(hidden_states, pooling_metadata) + + +class SequenceClassificationMixin(EmbeddingMixin): + """ + Mixin class that adds sequence classification functionality. + + This mixin provides: + - Classifier layer for sequence classification + - pooling method with classification logits + + Should be used with Base class: + class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ... 
+ """ + + default_pooling_type: PoolingType = PoolingType.CLS + default_normalize: bool = False + default_softmax: bool = True + + def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None: + # Call EmbeddingMixin.__init__ -> Base.__init__ + super().__init__(vllm_config=vllm_config, prefix=prefix) + + # Find and setup classifier layer + self.classifier = self._find_classifier() + + if self.classifier is not None: + # Initialize classifier parameters on device + self._init_classifier_params() + logger.info("SequenceClassificationMixin initialized with classifier") + else: + logger.warning("Could not find classifier layer") + + def _find_classifier(self) -> Optional[nn.Module]: + """Find the classifier layer in the model.""" + # Common classifier layer names + classifier_names = ['classifier', 'score', 'fc', 'head'] + + for name in classifier_names: + if hasattr(self.model, name): + return getattr(self.model, name) + + return None + + def _init_classifier_params(self) -> None: + """Initialize classifier parameters on target device.""" + device = self.device_config.device + if device is None: + device = torch.device("cpu") + dtype = self.model_config.dtype + + for name, param in list(self.classifier.named_parameters()): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like(param.data, dtype=dtype, device=device), + requires_grad=False, + ) + setattr(self.classifier, name.split('.')[-1], new_param) + + def pooling( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + """ + Apply pooling and classification to hidden states. 
+ + Args: + hidden_states: Hidden states from the model [seq_len, hidden_size] + pooling_metadata: Pooling metadata + + Returns: + PoolerOutput with classification logits + """ + # First apply base pooling + pooled = self.pooler(hidden_states, pooling_metadata) + + # Apply classifier if available + if self.classifier is not None and pooled is not None: + # Apply classifier to each pooled output + for i, output in enumerate(pooled.outputs): + if hasattr(output, 'data'): + output.data = self.classifier(output.data) + + return pooled