Testing dynamic registration

This commit is contained in:
Chranos
2026-02-05 18:02:59 +08:00
parent 6e38461af6
commit 332e5f71a6
5 changed files with 1035 additions and 535 deletions

View File

@@ -2,23 +2,43 @@
# Copyright 2024 The vLLM team.
"""Wrapper around `transformers` models for vLLM v0.6.2.
This module provides an advanced Transformers modeling backend that wraps
This module provides the Transformers modeling backend that wraps
any HuggingFace model with the vLLM interface, enabling support for custom
models that define their implementation via `auto_map` in config.json.
Key optimizations and features:
Architecture (following latest vLLM patterns):
- Base: Core functionality (meta init, PP/TP support, module replacement, attention, weight loading)
- CausalMixin: Causal LM specific (lm_head, compute_logits, sample)
- EmbeddingMixin: Embedding/pooling specific (pooler, pooling)
- SequenceClassificationMixin: Classification specific (classifier, pooling)
Composed model classes:
- TransformersForCausalLM = CausalMixin + Base
- TransformersForEmbedding = EmbeddingMixin + Base
- TransformersForSequenceClassification = SequenceClassificationMixin + Base
Key optimizations:
- Meta device initialization for memory efficiency
- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions
- VocabParallelEmbedding for input embeddings
- Pipeline Parallel support (PPMissingLayer)
- Tensor Parallel support (tp_plan based module replacement)
- Module replacement (Linear, RMSNorm, Embedding) with vLLM optimized versions
- vLLM Attention instances for proper KV cache allocation
- AutoWeightsLoader for efficient weight loading with name mapping
- vLLM attention backend integration (when supported by upgraded transformers)
"""
from vllm.model_executor.models.transformers.causal import (
TransformersForCausalLM,
is_backend_compatible,
from vllm.model_executor.models.transformers.base import (
Base,
set_attention_context,
clear_attention_context,
get_attention_context,
vllm_flash_attention_forward,
)
from vllm.model_executor.models.transformers.causal import CausalMixin
from vllm.model_executor.models.transformers.pooling import (
EmbeddingMixin,
SequenceClassificationMixin,
)
from vllm.model_executor.models.transformers.legacy import LegacyMixin
from vllm.model_executor.models.transformers.utils import (
init_on_device_without_buffers,
replace_linear_class,
@@ -27,10 +47,77 @@ from vllm.model_executor.models.transformers.utils import (
maybe_prefix,
)
# ============================================================================
# Composed Model Classes (Mixin + Base pattern)
# ============================================================================
class TransformersForCausalLM(CausalMixin, Base):
    """Transformers backend wrapper for causal language models.

    CausalMixin contributes the LM head, ``compute_logits`` and ``sample``;
    Base contributes meta-device init, PP/TP support, module replacement,
    attention and weight loading. Works with any HuggingFace model that
    declares its implementation via ``auto_map`` in config.json.
    """
class TransformersForEmbedding(EmbeddingMixin, Base):
    """Transformers backend wrapper for embedding / sentence-similarity models.

    EmbeddingMixin contributes the pooler and pooling logic; Base contributes
    meta-device init, PP/TP support, module replacement, attention and weight
    loading. Suitable for embedding models such as BERT and
    sentence-transformers checkpoints.
    """
class TransformersForSequenceClassification(SequenceClassificationMixin, Base):
    """Transformers backend wrapper for sequence classification models.

    SequenceClassificationMixin contributes the classifier head and pooling;
    Base contributes meta-device init, PP/TP support, module replacement,
    attention and weight loading. Suitable for cross-encoders and other
    classification checkpoints.
    """
class TransformersForLegacy(LegacyMixin, EmbeddingMixin, Base):
    """Transformers backend wrapper for legacy encoder models.

    LegacyMixin contributes BERT/RoBERTa weight-name mapping and position
    handling, EmbeddingMixin the pooler, and Base the core functionality.
    Suitable for BERT, RoBERTa and similar encoder-only checkpoints.
    """
__all__ = [
# Main wrapper classes
"TransformersForCausalLM",
"is_backend_compatible",
"TransformersForEmbedding",
"TransformersForSequenceClassification",
"TransformersForLegacy",
# Base class for extension
"Base",
# Mixin classes for custom combinations
"CausalMixin",
"EmbeddingMixin",
"SequenceClassificationMixin",
"LegacyMixin",
# Attention context management
"set_attention_context",
"clear_attention_context",
"get_attention_context",
"vllm_flash_attention_forward",
# Utility functions
"init_on_device_without_buffers",
"replace_linear_class",

View File

@@ -0,0 +1,600 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend base class for v0.6.2.
This module provides the Base class following latest vLLM architecture:
- Meta device initialization for memory efficiency
- Pipeline parallel support (PPMissingLayer)
- Tensor parallel support (tp_plan based module replacement)
- Module replacement (Linear, RMSNorm) with vLLM optimized versions
- VocabParallelEmbedding for input embeddings
- Attention instances for KV cache allocation
- Weight loading with AutoWeightsLoader and WeightsMapper
"""
import re
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
make_empty_intermediate_tensors_factory,
)
from vllm.attention.layer import Attention
from vllm.sequence import IntermediateTensors
from .utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
if TYPE_CHECKING:
from transformers import PreTrainedModel
from vllm.attention import AttentionMetadata
logger = init_logger(__name__)
# ============================================================================
# Attention Context Management (for vLLM attention integration)
# ============================================================================
# Module-level attention context, installed per forward pass so the
# registered attention function can reach vLLM's metadata and KV caches.
_current_attn_metadata = None
_current_kv_caches = None


def set_attention_context(attn_metadata, kv_caches):
    """Install *attn_metadata*/*kv_caches* as the active attention context."""
    global _current_attn_metadata, _current_kv_caches
    _current_attn_metadata, _current_kv_caches = attn_metadata, kv_caches


def clear_attention_context():
    """Reset the attention context (call after every forward pass)."""
    set_attention_context(None, None)


def get_attention_context():
    """Return the currently installed ``(attn_metadata, kv_caches)`` pair."""
    return (_current_attn_metadata, _current_kv_caches)
# ============================================================================
# vLLM Attention Function for Transformers Integration
# ============================================================================
def vllm_flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor,
    scaling: Optional[float] = None,
    attention_instances: Optional[Dict[int, Attention]] = None,
    **kwargs,
):
    """vLLM's optimized attention function for transformers integration.

    Registered under the ``"vllm"`` key of transformers'
    ``ALL_ATTENTION_FUNCTIONS`` so HF models created with
    ``_attn_implementation = "vllm"`` route their attention here.

    In v0.6.2, ``Attention.forward`` signature is:
        (query, key, value, kv_cache, attn_metadata)

    Falls back to :func:`_standard_attention` when no vLLM ``Attention``
    instance exists for this layer or no attention context is installed.

    Returns:
        ``(attn_output, None)`` — the transformers attention-function
        contract; the second slot (attention weights) is unused here.
    """
    # Layers created without an explicit index default to 0.
    layer_idx = getattr(module, 'layer_idx', 0)
    if attention_instances is None or layer_idx not in attention_instances:
        return _standard_attention(query, key, value, attention_mask, scaling)
    self_attn = attention_instances[layer_idx]
    attn_metadata, kv_caches = get_attention_context()
    if attn_metadata is None or kv_caches is None:
        return _standard_attention(query, key, value, attention_mask, scaling)
    if scaling is not None:
        # Keep the vLLM attention impl's scale in sync with the HF module's.
        self_attn.impl.scale = float(scaling)
    # Reshape: [batch, heads, seq, head_dim] -> [seq, heads * head_dim]
    # NOTE(review): this flattening assumes an effective batch of 1
    # (tokens already packed along the sequence axis) — confirm upstream.
    num_tokens = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(num_tokens, -1) for x in (query, key, value))
    kv_cache = kv_caches[layer_idx] if layer_idx < len(kv_caches) else None
    output = self_attn.forward(query, key, value, kv_cache, attn_metadata)
    return output, None
def _standard_attention(query, key, value, attention_mask, scaling):
"""Standard scaled dot-product attention fallback."""
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
# Register vLLM attention to transformers so models created with
# `_attn_implementation = "vllm"` route through vllm_flash_attention_forward.
_vllm_attention_registered = False
try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
    ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
    _vllm_attention_registered = True
    logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
    # Older transformers versions do not expose ALL_ATTENTION_FUNCTIONS;
    # the wrapped model then falls back to its native attention.
    logger.warning("Could not register vLLM attention: %s", e)
# ============================================================================
# Base Class with Pipeline Parallel and Tensor Parallel Support
# ============================================================================
class Base(nn.Module):
    """
    Base class for Transformers backend models with full parallel support.

    Features:
    - Pipeline Parallel: layers outside this rank's range become PPMissingLayer
    - Tensor Parallel: tp_plan based module replacement
    - Meta device initialization (no real memory until _init_parameters)
    - Module replacement (Linear -> vLLM Linear, RMSNorm -> vLLM RMSNorm)
    - VocabParallelEmbedding for input embeddings
    - vLLM Attention instances for KV cache allocation
    """
    # For vLLM's weight loader
    embedding_modules = ["embed_tokens"]
    # Weight name mapping following latest vLLM pattern
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Add `model.` prefix for base model checkpoints,
            # handling the case where it is already present
            "": "model.",
            "model.model.": "model.",
            # Heads will be adjacent to `model` (pooling included because of adapters)
            "model.lm_head.": "lm_head.",
            "model.score.": "classifier.",
            "model.classifier.": "classifier.",
        }
    )

    def __init_subclass__(cls, *args, **kwargs):
        """Merge hf_to_vllm_mapper in MRO from most specific to least specific."""
        super().__init_subclass__(*args, **kwargs)
        hf_to_vllm_mapper = WeightsMapper()
        for base in cls.__mro__:
            # Walk the MRO so every mixin can contribute its own mapping.
            if base_hf_to_vllm_mapper := getattr(base, "hf_to_vllm_mapper", None):
                hf_to_vllm_mapper |= base_hf_to_vllm_mapper
        cls.hf_to_vllm_mapper = hf_to_vllm_mapper

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the wrapped HF model on meta device and apply vLLM optimizations.

        Args:
            vllm_config: Aggregated vLLM configuration; model/cache/device/
                parallel/quant sub-configs are unpacked onto ``self`` below.
            prefix: Weight-name prefix used when this model is nested.
        """
        super().__init__()
        logger.info("Using Transformers modeling backend.")
        # Store configuration
        self.config = vllm_config.model_config.hf_config
        self.text_config = getattr(self.config, "text_config", self.config)
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.device_config = vllm_config.device_config
        self.parallel_config = vllm_config.parallel_config
        self.quant_config = vllm_config.quant_config
        self.prefix = prefix
        # Parallel groups
        self.pp_group = get_pp_group()
        self.tp_group = get_tp_group()
        # Model dimensions (fallbacks for configs that omit these fields)
        self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
        self.vocab_size = getattr(self.text_config, "vocab_size", 32000)
        # Weight loading configuration
        self.skip_prefixes: List[str] = []
        self.ignore_unexpected_prefixes: List[str] = []
        # FIX: default the embedding scale up-front. Previously it was only
        # assigned inside _replace_input_embeddings() *after* its early
        # return, so embed_input_ids()/forward() raised AttributeError when
        # the embeddings were absent (e.g. PPMissingLayer on non-first ranks).
        self.embed_scale = None
        # Configure attention backend
        self._configure_attention_backend()
        # Create model on meta device
        self._init_model_on_meta()
        # Apply pipeline parallel
        self._apply_pipeline_parallel()
        # Replace modules (with tensor parallel support)
        self._replace_modules()
        # Replace input embeddings
        self._replace_input_embeddings()
        # Create attention instances
        self.attention_instances = self._create_attention_instances()
        # Initialize parameters on target device
        self._init_parameters()
        # Pipeline parallel intermediate tensors
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.hidden_size
        )

    def _configure_attention_backend(self) -> None:
        """Hook for platform-specific attention backend configuration.

        The attention implementation itself is selected in
        _init_model_on_meta via ``_attn_implementation = "vllm"``.
        """
        pass

    def _init_model_on_meta(self) -> None:
        """Create the HF model structure on the meta device (no real memory)."""
        from transformers import AutoModel
        logger.info("Creating model structure on meta device...")
        # Route the HF model's attention through the registered vLLM function
        self.text_config._attn_implementation = "vllm"
        with init_on_device_without_buffers("meta"):
            self.model: "PreTrainedModel" = AutoModel.from_config(
                self.config,
                torch_dtype=self.model_config.dtype,
                trust_remote_code=self.model_config.trust_remote_code,
            )
        # Inference only: freeze all parameters
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def _apply_pipeline_parallel(self) -> None:
        """
        Apply pipeline parallelization plan.

        For models that don't explicitly support pp_plan, we do a best-effort
        approach by splitting layers based on num_hidden_layers.
        """
        if self.pp_group.world_size <= 1:
            return
        logger.info("Applying pipeline parallel (world_size=%d, rank=%d)",
                    self.pp_group.world_size, self.pp_group.rank_in_group)
        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))
        start_layer, end_layer = get_pp_indices(
            num_layers,
            self.pp_group.rank_in_group,
            self.pp_group.world_size,
        )
        # Find and process layer modules
        layers_module = self._find_layers_module()
        if layers_module is not None:
            layers = list(layers_module.children())
            for i, layer in enumerate(layers):
                if not (start_layer <= i < end_layer):
                    # Replace layers not on this rank with PPMissingLayer
                    setattr(layers_module, str(i), PPMissingLayer())
        # Handle embeddings (only on first rank)
        # NOTE(review): _has_embeddings is only assigned on non-first ranks
        # and is not read elsewhere in this class — confirm intended use.
        if not self.pp_group.is_first_rank:
            input_embeddings = self.model.get_input_embeddings()
            if input_embeddings is not None:
                # Keep a reference but mark as missing for forward
                self._has_embeddings = False
            else:
                self._has_embeddings = True
        # Handle final norm and lm_head (only on last rank)
        if not self.pp_group.is_last_rank:
            # Mark lm_head as missing
            if hasattr(self.model, 'lm_head'):
                self.model.lm_head = PPMissingLayer()
        logger.info("Pipeline parallel applied: layers %d-%d on this rank",
                    start_layer, end_layer)

    def _find_layers_module(self) -> Optional[nn.Module]:
        """Find the ModuleList containing the transformer layers.

        Searches the common container names ('layers', 'h', 'blocks',
        'layer') under the common backbone names ('model', 'transformer',
        'encoder', 'decoder'). Returns None if no layer list is found.
        """
        def _search_layers(module: nn.Module, prefix: str = "") -> Optional[nn.Module]:
            for name, child in module.named_children():
                if name in ['layers', 'h', 'blocks', 'layer'] and isinstance(child, nn.ModuleList):
                    return child
                # Recursively search in model backbone
                if name in ['model', 'transformer', 'encoder', 'decoder']:
                    result = _search_layers(child, f"{prefix}.{name}" if prefix else name)
                    if result is not None:
                        return result
            return None
        return _search_layers(self.model)

    def _get_tp_plan(self) -> Dict[str, str]:
        """
        Get tensor parallel plan for module replacement.

        This maps module name patterns to parallelization styles:
        - "colwise": Column parallel (split output dim)
        - "rowwise": Row parallel (split input dim)
        - "replicate": Replicated (no split)

        Returns a dict mapping regex patterns to styles.
        """
        # Check if model has explicit tp_plan
        if hasattr(self.model, 'tp_plan') and self.model.tp_plan:
            return {maybe_prefix("model", k): v for k, v in self.model.tp_plan.items()}
        # Default tp_plan for common LLM architectures
        # Based on typical transformer structure
        return {
            r".*\.q_proj$": "colwise",
            r".*\.k_proj$": "colwise",
            r".*\.v_proj$": "colwise",
            r".*\.o_proj$": "rowwise",
            r".*\.gate_proj$": "colwise",
            r".*\.up_proj$": "colwise",
            r".*\.down_proj$": "rowwise",
            r".*\.query$": "colwise",
            r".*\.key$": "colwise",
            r".*\.value$": "colwise",
            r".*\.dense$": "rowwise",
            r".*\.fc1$": "colwise",
            r".*\.fc2$": "rowwise",
        }

    def _replace_modules(self) -> None:
        """
        Replace modules with vLLM optimized versions.

        Uses tp_plan for tensor parallel style selection; without TP every
        Linear is replaced in "replicate" style.
        """
        logger.info("Replacing modules with vLLM optimized versions...")
        replaced_count = 0
        # Get tensor parallel plan (empty when TP is off)
        tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}

        def _recursive_replace(module: nn.Module, prefix: str = ""):
            nonlocal replaced_count
            for name, child in list(module.named_children()):
                # Skip PPMissingLayer
                if isinstance(child, PPMissingLayer):
                    continue
                qual_name = maybe_prefix(prefix, name)
                new_module = None
                if isinstance(child, nn.Linear):
                    # Determine parallelization style from tp_plan;
                    # first matching pattern wins.
                    style = "replicate"
                    for pattern, plan_style in tp_plan.items():
                        if re.match(pattern, qual_name):
                            style = plan_style
                            break
                    new_module = replace_linear_class(
                        child,
                        style=style,
                        quant_config=self.quant_config,
                        prefix=qual_name,
                    )
                    replaced_count += 1
                elif child.__class__.__name__.endswith("RMSNorm"):
                    new_module = replace_rms_norm_class(child, self.hidden_size)
                    replaced_count += 1
                if new_module is not None:
                    setattr(module, name, new_module)
                    log_replacement(qual_name, child, new_module)
                else:
                    # Leaf was not replaced: descend into its children
                    _recursive_replace(child, qual_name)

        _recursive_replace(self.model, "model")
        logger.info("Replaced %d modules", replaced_count)

    def _replace_input_embeddings(self) -> None:
        """Replace input embeddings with VocabParallelEmbedding.

        No-op on ranks where the embeddings are absent (None or
        PPMissingLayer); self.embed_scale keeps its default of None there.
        """
        input_embeddings = self.model.get_input_embeddings()
        if input_embeddings is None or isinstance(input_embeddings, PPMissingLayer):
            return
        # Infer the embedding dim from the module, falling back to hidden_size
        if hasattr(input_embeddings, "embedding_dim"):
            embedding_dim = input_embeddings.embedding_dim
        elif hasattr(input_embeddings, "weight"):
            embedding_dim = input_embeddings.weight.shape[1]
        else:
            embedding_dim = self.hidden_size
        # Preserve any model-specific embedding scale (e.g. Gemma-style)
        self.embed_scale = getattr(input_embeddings, "embed_scale", None)
        logger.info("Replacing input embeddings (vocab=%d, dim=%d)",
                    self.vocab_size, embedding_dim)
        new_embeddings = VocabParallelEmbedding(
            self.vocab_size,
            embedding_dim,
            org_num_embeddings=self.vocab_size,
            quant_config=self.quant_config,
        )
        self.model.set_input_embeddings(new_embeddings)

    def _create_attention_instances(self) -> Dict[int, Attention]:
        """Create vLLM Attention instances (one per local layer) for KV cache.

        Returns:
            Mapping of global layer index -> Attention, covering only the
            layers hosted on this pipeline-parallel rank.
        """
        num_layers = getattr(self.text_config, "num_hidden_layers",
                             getattr(self.text_config, "num_layers", 32))
        num_heads = getattr(self.text_config, "num_attention_heads", 32)
        head_size = self.hidden_size // num_heads
        num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)
        # Get PP layer range
        pp_rank = self.pp_group.rank_in_group
        pp_size = self.pp_group.world_size
        start_layer, end_layer = get_pp_indices(num_layers, pp_rank, pp_size)
        logger.info("Creating attention instances for layers %d-%d "
                    "(heads=%d, head_size=%d, kv_heads=%d)",
                    start_layer, end_layer, num_heads, head_size, num_kv_heads)
        attention_instances: Dict[int, Attention] = {}
        for layer_idx in range(start_layer, end_layer):
            # Per-layer sliding window for configs with mixed layer types
            per_layer_sliding_window = None
            if hasattr(self.config, "layer_types"):
                layer_types = self.config.layer_types
                if layer_idx < len(layer_types) and layer_types[layer_idx] == "sliding_attention":
                    per_layer_sliding_window = getattr(self.config, "sliding_window", None)
            attention = Attention(
                num_heads=num_heads,
                head_size=head_size,
                scale=1.0 / (head_size ** 0.5),
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=f"model.layers.{layer_idx}.self_attn",
            )
            attention_instances[layer_idx] = attention
        return attention_instances

    def _init_parameters(self) -> None:
        """Materialize meta-device parameters as empty tensors on the target device.

        Values are garbage until load_weights() fills them in.
        """
        device = self.device_config.device
        if device is None:
            device = torch.device("cpu")
        dtype = self.model_config.dtype

        def _init_params(module: nn.Module):
            if isinstance(module, PPMissingLayer):
                return
            for name, param in list(module.named_parameters(recurse=False)):
                if param.device == torch.device("meta"):
                    new_param = nn.Parameter(
                        torch.empty_like(param.data, dtype=dtype, device=device),
                        requires_grad=False,
                    )
                    setattr(module, name, new_param)
            for child in module.children():
                _init_params(child)

        _init_params(self.model)
        logger.info("Parameters initialized on %s", device)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed *input_ids*, applying the model's embed_scale when present."""
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        if self.embed_scale is not None:
            inputs_embeds = inputs_embeds * self.embed_scale
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass with pipeline parallel support.

        Installs the attention context for the registered vLLM attention
        function, runs the wrapped HF model, and returns the final hidden
        states (or IntermediateTensors on non-last PP ranks).
        """
        # Handle intermediate tensors for PP: non-first ranks consume the
        # previous stage's hidden states instead of token ids.
        if not self.pp_group.is_first_rank:
            assert intermediate_tensors is not None
            input_ids = None
            inputs_embeds = intermediate_tensors["hidden_states"]
        set_attention_context(attn_metadata, kv_caches)
        try:
            # Prepare inputs — the HF model expects a leading batch dim.
            if inputs_embeds is not None:
                if inputs_embeds.dim() == 2:
                    inputs_embeds = inputs_embeds.unsqueeze(0)
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                if input_ids is not None and input_ids.dim() == 1:
                    input_ids = input_ids.unsqueeze(0)
                model_inputs = {"input_ids": input_ids}
            if positions is not None:
                if positions.dim() == 1:
                    positions = positions.unsqueeze(0)
                model_inputs["position_ids"] = positions
            # Apply embed_scale if needed (re-embed ids so the scale is applied)
            if (
                self.embed_scale is not None
                and input_ids is not None
                and inputs_embeds is None
            ):
                inputs_embeds = self.embed_input_ids(model_inputs["input_ids"])
                model_inputs = {"inputs_embeds": inputs_embeds}
                if positions is not None:
                    model_inputs["position_ids"] = positions
            # Forward through model
            with torch.no_grad():
                outputs = self.model(
                    **model_inputs,
                    use_cache=False,
                    return_dict=True,
                    output_hidden_states=True,
                    attention_instances=self.attention_instances,
                )
            # Get hidden states (last entry of hidden_states when available)
            if outputs.hidden_states is not None:
                hidden_states = outputs.hidden_states[-1]
            else:
                hidden_states = outputs.logits
            # Remove batch dimension
            if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
                hidden_states = hidden_states.squeeze(0)
            # Return intermediate tensors for PP
            if not self.pp_group.is_last_rank:
                return IntermediateTensors({"hidden_states": hidden_states})
            return hidden_states
        finally:
            # Always drop the attention context, even on error.
            clear_attention_context()

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load weights using AutoWeightsLoader with HF->vLLM name mapping.

        Returns:
            The set of fully-qualified parameter names that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=self.skip_prefixes,
            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
        )
        loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        logger.info("Loaded %d weight tensors", len(loaded))
        return set(loaded)

View File

@@ -1,533 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend for causal language models.
"""Transformers modeling backend mixin for causal language models.
This module provides a wrapper class that enables vLLM to use any HuggingFace
causal language model, including custom models that define their implementation
via `auto_map` in config.json.
This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.
Key optimizations:
1. Use meta device for delayed memory allocation
2. Replace nn.Linear with vLLM's optimized Linear classes
3. Replace RMSNorm with vLLM's fused RMSNorm
4. Replace input embeddings with VocabParallelEmbedding
5. Use vLLM's weight loading infrastructure (AutoWeightsLoader)
Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
"""
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.attention.layer import Attention
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
)
from vllm.model_executor.models.transformers.utils import (
init_on_device_without_buffers,
replace_linear_class,
replace_rms_norm_class,
log_replacement,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from transformers import PreTrainedModel
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
logger = init_logger(__name__)
# Note: In v0.6.2, the vLLM Attention.forward requires (query, key, value, kv_cache, attn_metadata).
# The transformers backend integration works differently than in latest vLLM.
# We keep the vllm_flash_attention_forward for reference, but it may not be compatible
# with all transformers versions or MLU backends.
# Global variable to store current attention metadata (set during forward pass)
_current_attn_metadata = None
_current_kv_caches = None
def set_attention_context(attn_metadata, kv_caches):
"""Set the current attention context for vLLM attention functions."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = attn_metadata
_current_kv_caches = kv_caches
def clear_attention_context():
"""Clear the current attention context."""
global _current_attn_metadata, _current_kv_caches
_current_attn_metadata = None
_current_kv_caches = None
def vllm_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# Transformers kwargs
scaling: float = None,
# vLLM kwargs
attention_instances: Dict[int, Attention] = None,
**kwargs,
):
class CausalMixin:
"""
vLLM's optimized attention function that replaces HuggingFace's attention.
This function is registered to transformers' ALL_ATTENTION_FUNCTIONS.
Mixin class that adds causal language model functionality.
Note: In v0.6.2, this function may have limited functionality due to
API differences in the Attention layer. For full functionality, the model
should fall back to HuggingFace's native attention when vLLM attention
is not properly configured.
"""
# Get the attention instance for this layer
layer_idx = getattr(module, 'layer_idx', 0)
This mixin provides:
- LogitsProcessor for logits computation
- Sampler for token sampling
- compute_logits method for VllmModelForTextGeneration protocol
- sample method for VllmModelForTextGeneration protocol
if attention_instances is None or layer_idx not in attention_instances:
# Fall back to standard attention computation
logger.debug("No attention instance for layer %d, using standard attention", layer_idx)
# Standard scaled dot-product attention
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
self_attn = attention_instances[layer_idx]
# v0.6.2 Attention.forward requires: (query, key, value, kv_cache, attn_metadata)
# We need to get these from the global context
global _current_attn_metadata, _current_kv_caches
if _current_attn_metadata is None or _current_kv_caches is None:
# No context set, fall back to standard attention
logger.debug("No attention context, using standard attention for layer %d", layer_idx)
attn_weights = torch.matmul(query, key.transpose(-2, -1))
if scaling is not None:
attn_weights = attn_weights * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)
return attn_output, None
# Update scale if provided
if scaling is not None:
self_attn.impl.scale = float(scaling)
# Reshape tensors for vLLM: [batch, heads, seq, head_dim] -> [seq, heads * head_dim]
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
# Get KV cache for this layer
kv_cache = _current_kv_caches[layer_idx] if layer_idx < len(_current_kv_caches) else None
# Call vLLM attention
output = self_attn.forward(query, key, value, kv_cache, _current_attn_metadata)
return output, None
# Try to register vLLM attention to transformers
_vllm_attention_registered = False
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
_vllm_attention_registered = True
logger.info("Registered vLLM attention function to transformers")
except (ImportError, AttributeError) as e:
logger.warning("Could not register vLLM attention function - "
"transformers version may not support custom attention: %s", e)
class TransformersForCausalLM(nn.Module):
"""
A wrapper class that adapts any HuggingFace causal language model
to the vLLM interface with memory optimizations.
Key optimizations (following latest vLLM):
1. Meta device initialization - no GPU memory until weights are loaded
2. Module replacement - Linear/RMSNorm replaced with vLLM optimized versions
3. VocabParallelEmbedding for input embeddings
4. AutoWeightsLoader for efficient weight loading with name mapping
Interface compliance:
- Implements VllmModel protocol (vllm_config init, forward with required args)
- Implements VllmModelForTextGeneration protocol (compute_logits, sample)
Should be used with Base class:
class TransformersForCausalLM(CausalMixin, Base): ...
"""
# Weight name mapping from HuggingFace to vLLM format
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# Add `model.` prefix for base model checkpoints
"": "model.",
"model.model.": "model.",
# Heads will be adjacent to `model`
"model.lm_head.": "lm_head.",
"lm_head.": "lm_head.",
# Handle different model architectures
"transformer.": "model.",
"model.transformer.": "model.",
# Handle embeddings
"embed_tokens.": "model.embed_tokens.",
"model.embed_tokens.": "model.embed_tokens.",
# Handle attention weights
"self_attn.": "self_attn.",
"attention.": "self_attn.",
}
)
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
# Call next class in MRO (should be Base)
super().__init__(vllm_config=vllm_config, prefix=prefix)
config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# Handle tied word embeddings - skip loading lm_head weights
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
if tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
logger.info("Model has tied word embeddings, skipping lm_head weight loading")
self.config = config.hf_config
self.text_config = getattr(self.config, "text_config", self.config)
self.model_config = config
self.cache_config = cache_config
self.device_config = vllm_config.device_config
self.quant_config = quant_config
self.prefix = prefix
# Get model dimensions from config
self.hidden_size = getattr(self.text_config, "hidden_size", 4096)
self.vocab_size = getattr(self.text_config, "vocab_size", 32000)
# Weight loading configuration
self.skip_prefixes: List[str] = []
self.ignore_unexpected_prefixes: List[str] = [
"model.layers.*.self_attn.rotary_emb.inv_freq", # Skip RoPE weights
"model.norm.bias", # Some models don't have bias in final norm
]
logger.info("Using Transformers modeling backend for %s",
config.hf_config.architectures)
# Configure vLLM attention backend
self._configure_attention_backend()
# Load the HuggingFace model structure on meta device (no memory allocation)
self._load_hf_model_on_meta()
# Replace modules with vLLM optimized versions
self._replace_modules()
# Replace input embeddings with VocabParallelEmbedding
self._replace_input_embeddings()
# Create attention instances for KV cache
self.attention_instances = self._create_attention_instances()
# Initialize parameters (allocate memory on target device)
self._init_parameters()
# Setup logits processor and sampler
# Setup logits processor
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.vocab_size,
logits_as_input=False,
scale=logit_scale,
)
# Setup sampler
self.sampler = Sampler()
def _load_hf_model_on_meta(self) -> None:
    """Build the HuggingFace model skeleton on the meta device.

    No real storage is allocated here; tensors are materialized later by
    `_init_parameters` and filled in during weight loading.
    """
    from transformers import AutoModelForCausalLM

    logger.info("Creating model structure on meta device...")
    # Everything constructed inside this context lives on the meta device,
    # so no GPU memory is consumed by model creation.
    with init_on_device_without_buffers("meta"):
        self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
            self.config,
            torch_dtype=self.model_config.dtype,
            trust_remote_code=self.model_config.trust_remote_code,
        )
    # Inference only: switch to eval mode and turn gradients off everywhere.
    self.model.eval()
    self.model.requires_grad_(False)
    logger.info("Model structure created on meta device")
def _replace_modules(self) -> None:
    """Swap HuggingFace submodules for vLLM optimized counterparts.

    Replacements performed:
    - nn.Linear  -> ReplicatedLinear (memory efficient, supports quantization)
    - *RMSNorm   -> vLLM's fused RMSNorm
    Replaced modules are not recursed into; everything else is walked.
    """
    logger.info("Replacing modules with vLLM optimized versions...")
    replaced_count = 0

    def _walk(parent: nn.Module, prefix: str = ""):
        nonlocal replaced_count
        for child_name, child in list(parent.named_children()):
            qual_name = maybe_prefix(prefix, child_name)
            cls_name = child.__class__.__name__
            if isinstance(child, nn.Linear):
                # Linear -> vLLM ReplicatedLinear; the replacement is final,
                # so no recursion into it.
                replacement = replace_linear_class(
                    child,
                    style="replicate",
                    quant_config=self.quant_config,
                    prefix=qual_name,
                )
                setattr(parent, child_name, replacement)
                log_replacement(qual_name, child, replacement)
                replaced_count += 1
                continue
            if cls_name.endswith("RMSNorm"):
                # RMSNorm -> vLLM fused implementation; also final.
                replacement = replace_rms_norm_class(child, self.hidden_size)
                setattr(parent, child_name, replacement)
                log_replacement(qual_name, child, replacement)
                replaced_count += 1
                continue
            if cls_name.endswith(("LayerNorm", "GroupNorm")):
                # Other normalization layers are noted but left in place.
                logger.debug("Found normalization layer %s: %s", qual_name, type(child).__name__)
            elif "Attention" in cls_name:
                # Attention modules stay as-is; separate vLLM Attention
                # instances are created in _create_attention_instances().
                logger.debug("Found attention layer %s: %s", qual_name, type(child).__name__)
            # Keep walking into anything that was not replaced.
            _walk(child, qual_name)

    _walk(self.model, "model")
    logger.info("Replaced %d modules with vLLM optimized versions", replaced_count)
def _replace_input_embeddings(self) -> None:
    """Swap the input embedding table for VocabParallelEmbedding.

    VocabParallelEmbedding shards the vocabulary, which keeps large
    vocabularies memory efficient.
    """
    orig_embed = self.model.get_input_embeddings()
    if orig_embed is None:
        logger.warning("Could not find input embeddings to replace")
        return
    # Some models scale their embeddings by a constant; remember the factor
    # so forward() can re-apply it after the swap.
    self.embed_scale = getattr(orig_embed, "embed_scale", None)
    # Determine the embedding width from whatever the module exposes.
    if hasattr(orig_embed, "embedding_dim"):
        embedding_dim = orig_embed.embedding_dim
    elif hasattr(orig_embed, "weight"):
        embedding_dim = orig_embed.weight.shape[1]
    else:
        embedding_dim = self.hidden_size
    logger.info("Replacing input embeddings with VocabParallelEmbedding "
                "(vocab_size=%d, embedding_dim=%d)", self.vocab_size, embedding_dim)
    self.model.set_input_embeddings(
        VocabParallelEmbedding(
            self.vocab_size,
            embedding_dim,
            org_num_embeddings=self.vocab_size,
            quant_config=self.quant_config,
        ))
def _init_parameters(self) -> None:
    """Materialize meta-device parameters on the target device.

    Allocates (uninitialized) storage for every parameter still living on
    the meta device; real values arrive later via weight loading.
    """
    logger.info("Initializing parameters on target device...")
    # device_config knows the right accelerator (supports MLU, CUDA, etc.).
    target_device = self.device_config.device
    if target_device is None:
        # Fallback for TPU or other special cases.
        target_device = torch.device("cpu")
    target_dtype = self.model_config.dtype
    meta_device = torch.device("meta")
    # Visit every submodule and replace its *direct* parameters, so that
    # setattr targets the module that actually owns each parameter.
    for submodule in self.model.modules():
        for param_name, param in list(submodule.named_parameters(recurse=False)):
            if param.device != meta_device:
                continue
            materialized = nn.Parameter(
                torch.empty_like(
                    param.data,
                    dtype=target_dtype,
                    device=target_device,
                ),
                requires_grad=False,
            )
            setattr(submodule, param_name, materialized)
    logger.info("Parameters initialized on %s", target_device)
def _create_attention_instances(self) -> Dict[int, Attention]:
    """
    Create vLLM Attention instances for each layer.

    This enables proper KV cache allocation and vLLM's optimized attention.

    Returns:
        Dict mapping layer_idx to its Attention instance.
    """
    attention_instances: Dict[int, Attention] = {}
    num_layers = getattr(self.text_config, "num_hidden_layers",
                         getattr(self.text_config, "num_layers", 32))
    num_heads = getattr(self.text_config, "num_attention_heads", 32)
    # Prefer an explicit head_dim from the config: some models (e.g. those
    # whose head size is decoupled from hidden_size) set head_dim such that
    # head_dim != hidden_size // num_heads, and using the wrong value would
    # mis-size the KV cache. Fall back to the classic derivation.
    head_size = getattr(self.text_config, "head_dim", None)
    if head_size is None:
        head_size = self.hidden_size // num_heads
    scale = 1.0 / (head_size ** 0.5)
    num_kv_heads = getattr(self.text_config, "num_key_value_heads", num_heads)
    logger.info("Creating %d attention instances for KV cache "
                "(num_heads=%d, head_size=%d, num_kv_heads=%d)",
                num_layers, num_heads, head_size, num_kv_heads)
    for layer_idx in range(num_layers):
        attention = Attention(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            prefix=f"model.layers.{layer_idx}.self_attn",
        )
        attention_instances[layer_idx] = attention
    return attention_instances
def _configure_attention_backend(self) -> None:
    """Select the attention implementation before the model is built.

    Only points the config at the 'vllm' implementation when the vLLM
    attention function has actually been registered; otherwise HF's
    default attention is kept.
    """
    global _vllm_attention_registered
    if not _vllm_attention_registered:
        # vLLM attention unavailable -> stay on HF's default attention.
        logger.info("Using default HuggingFace attention (vLLM attention not registered)")
        return
    # Must be set before from_config() reads the config.
    self.text_config._attn_implementation = "vllm"
    logger.info("Set attention implementation to 'vllm' in text_config")
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: "AttentionMetadata",
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Forward pass through the model.

    This method conforms to the VllmModel protocol by accepting:
    - input_ids: Token IDs
    - positions: Position IDs
    - kv_caches: KV cache tensors
    - attn_metadata: Attention metadata

    Returns:
        Last-layer hidden states; the batch dimension is squeezed away
        when it equals 1, yielding [seq_len, hidden_size].
    """
    # Publish attn_metadata / kv_caches so the registered "vllm" attention
    # function can find them while self.model(...) runs.
    set_attention_context(attn_metadata, kv_caches)
    try:
        # Prepare inputs - HF models expect a leading batch dimension.
        if inputs_embeds is not None:
            if inputs_embeds.dim() == 2:
                inputs_embeds = inputs_embeds.unsqueeze(0)
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            if input_ids.dim() == 1:
                input_ids = input_ids.unsqueeze(0)
            model_inputs = {"input_ids": input_ids}
        # Position IDs - add batch dimension as well.
        if positions is not None:
            if positions.dim() == 1:
                positions = positions.unsqueeze(0)
            model_inputs["position_ids"] = positions
        # Apply embed_scale if needed: embed manually, scale, and switch the
        # inputs from token ids to precomputed embeddings.
        if (
            self.embed_scale is not None
            and "input_ids" in model_inputs
            and "inputs_embeds" not in model_inputs
        ):
            inputs_embeds = self.model.get_input_embeddings()(model_inputs["input_ids"])
            inputs_embeds = inputs_embeds * self.embed_scale
            # model_inputs is rebuilt wholesale, so position_ids must be
            # re-attached afterwards.
            model_inputs = {"inputs_embeds": inputs_embeds}
            if positions is not None:
                model_inputs["position_ids"] = positions
        # Run the model; attention_instances routes attention through the
        # vLLM Attention layers created at init time.
        with torch.no_grad():
            outputs = self.model(
                **model_inputs,
                use_cache=False,
                return_dict=True,
                output_hidden_states=True,
                attention_instances=self.attention_instances,
            )
        # Get hidden states from the last layer.
        if outputs.hidden_states is not None:
            hidden_states = outputs.hidden_states[-1]
        else:
            # Fallback: use logits directly.
            hidden_states = outputs.logits
        # Remove the batch dimension so downstream vLLM code sees
        # [seq_len, hidden_size].
        if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
            hidden_states = hidden_states.squeeze(0)
        return hidden_states
    finally:
        # Always drop the context - even on failure - so stale metadata
        # cannot leak into the next forward call.
        clear_attention_context()
logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)",
self.vocab_size, logit_scale)
def compute_logits(
self,
@@ -538,24 +71,35 @@ class TransformersForCausalLM(nn.Module):
Compute logits from hidden states.
This method conforms to the VllmModelForTextGeneration protocol.
Args:
hidden_states: Hidden states from the model [seq_len, hidden_size]
sampling_metadata: Sampling metadata
Returns:
Logits tensor or None
"""
# Check if hidden_states are already logits
# Check if hidden_states are already logits (some models output logits directly)
if hidden_states.shape[-1] == self.vocab_size:
logits = hidden_states
else:
# Apply the LM head
lm_head = getattr(self.model, "lm_head", None)
if lm_head is None:
# Some models use different names
lm_head = getattr(self.model, "embed_out", None)
if lm_head is None:
lm_head = getattr(self.model, "output", None)
if lm_head is not None:
output = lm_head(hidden_states)
# Handle tuple output from vLLM Linear layers
# Handle tuple output from vLLM Linear layers (output, bias)
if isinstance(output, tuple):
logits = output[0]
else:
logits = output
else:
logger.warning("Could not find lm_head, using hidden_states as logits")
logits = hidden_states
return self.logits_processor(None, logits, sampling_metadata)
@@ -569,32 +113,13 @@ class TransformersForCausalLM(nn.Module):
Sample tokens from logits.
This method conforms to the VllmModelForTextGeneration protocol.
Args:
logits: Logits tensor
sampling_metadata: Sampling metadata
Returns:
SamplerOutput with sampled tokens
"""
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
    self,
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Set[str]:
    """Load checkpoint weights via vLLM's AutoWeightsLoader.

    Weight names are translated through ``self.hf_to_vllm_mapper`` and the
    configured skip / ignore prefix lists are honoured.

    Returns:
        The set of fully-qualified names that were loaded.
    """
    weights_loader = AutoWeightsLoader(
        self,
        skip_prefixes=self.skip_prefixes,
        ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
    )
    loaded_names = weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
    logger.info("Loaded %d weight tensors", len(loaded_names))
    return set(loaded_names)
def is_backend_compatible() -> bool:
    """Report whether the Transformers backend can serve the current model.

    The generic wrapper accepts any architecture, so this is always True.
    """
    return True

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for legacy models.
This module provides LegacyMixin for BERT-like encoder models that have
different weight naming conventions and special position handling.
Following latest vLLM architecture patterns adapted for v0.6.2.
"""
from typing import TYPE_CHECKING, List, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class LegacyMixin:
    """Mixin for legacy/encoder models such as BERT and RoBERTa.

    Provides:
    - Weight name mapping for legacy suffix conventions (.gamma/.beta)
    - Prefix mapping for BERT-like model structures
    - RoBERTa-specific position handling
    - Skip prefixes for unsupported output layers

    Intended to be combined with Base:
        class TransformersForLegacy(LegacyMixin, Base): ...
    """

    # Weight name mapping for legacy models.
    # NOTE: mappings are applied in order, so ordering matters!
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # BERT-like model roots are renamed to "model".
            "roberta": "model",
            "bert": "model",
        },
        orig_to_new_suffix={
            # Legacy parameter names used for norms.
            ".gamma": ".weight",
            ".beta": ".bias",
        },
    )

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (normally Base).
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Output-embedding style heads that vLLM does not serve.
        self.skip_prefixes.extend([
            "model.lm_head.",
            "model.predictions.",
            "model.qa_outputs.",
            "model.embeddings_project.",
            "model.discriminator_predictions.",
        ])
        # v0.6.2 has no skip_substrs support, so substring patterns are
        # kept here and checked manually during weight loading.
        self._legacy_skip_patterns: List[str] = [
            "position_ids",  # Some encoder models have position_ids buffer
            "score.bias",  # Final classifier bias not used by vLLM
        ]
        # RoBERTa-like models carry extra padding in their position ids.
        model_type = getattr(self.text_config, "model_type", "").lower()
        self.is_roberta = "roberta" in model_type
        self.padding_idx = getattr(self.text_config, "pad_token_id", 1)
        if self.is_roberta:
            logger.info("LegacyMixin detected RoBERTa model, enabling position padding")
        logger.info("LegacyMixin initialized for legacy/encoder model")

    def _should_skip_weight(self, name: str) -> bool:
        """Return True if the named weight must not be loaded."""
        return any(pattern in name for pattern in self._legacy_skip_patterns)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the base forward pass with RoBERTa position handling.

        RoBERTa models require position ids offset by padding_idx + 1.
        """
        if self.is_roberta and positions is not None:
            positions = positions + self.padding_idx + 1
        return super().forward(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

View File

@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixins for pooling/embedding models.
This module provides mixins for embedding and sequence classification models:
- EmbeddingMixin: For embedding/sentence similarity models
- SequenceClassificationMixin: For sequence classification/cross-encoding
Following latest vLLM architecture patterns adapted for v0.6.2.
"""
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class EmbeddingMixin:
    """Mixin adding embedding/pooling functionality.

    Provides:
    - A Pooler layer for extracting embeddings
    - pooling() for the VllmModelForPooling protocol

    Intended to be combined with Base:
        class TransformersForEmbedding(EmbeddingMixin, Base): ...
    """

    # Defaults used when the pooler config does not override them.
    default_pooling_type: PoolingType = PoolingType.CLS
    default_normalize: bool = True
    default_softmax: bool = False

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (normally Base).
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Build the pooler from the model's pooler config.
        self.pooler = self._build_pooler(vllm_config.model_config.pooler_config)
        logger.info("EmbeddingMixin initialized (pooling_type=%s, normalize=%s)",
                    self.pooler.pooling_type.name, self.pooler.normalize)

    def _build_pooler(self, pooler_config) -> Pooler:
        """Create the pooler from config, falling back to class defaults."""
        pooler = Pooler.from_config_with_defaults(
            pooler_config=pooler_config,
            pooling_type=self.default_pooling_type,
            normalize=self.default_normalize,
            softmax=self.default_softmax,
        )
        if pooler is not None:
            return pooler
        # Config yielded nothing usable: construct a default pooler.
        return Pooler(
            pooling_type=self.default_pooling_type,
            normalize=self.default_normalize,
            softmax=self.default_softmax,
        )

    def pooling(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """
        Apply pooling to hidden states.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            pooling_metadata: Pooling metadata

        Returns:
            PoolerOutput with pooled embeddings
        """
        return self.pooler(hidden_states, pooling_metadata)
class SequenceClassificationMixin(EmbeddingMixin):
"""
Mixin class that adds sequence classification functionality.
This mixin provides:
- Classifier layer for sequence classification
- pooling method with classification logits
Should be used with Base class:
class TransformersForSequenceClassification(SequenceClassificationMixin, Base): ...
"""
default_pooling_type: PoolingType = PoolingType.CLS
default_normalize: bool = False
default_softmax: bool = True
def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
    """Set up pooling (via EmbeddingMixin) plus a classification head."""
    # MRO: EmbeddingMixin.__init__ -> Base.__init__
    super().__init__(vllm_config=vllm_config, prefix=prefix)
    # Locate the classification head inside the HF model, if present.
    self.classifier = self._find_classifier()
    if self.classifier is None:
        logger.warning("Could not find classifier layer")
    else:
        # Move the classifier's parameters off the meta device.
        self._init_classifier_params()
        logger.info("SequenceClassificationMixin initialized with classifier")
def _find_classifier(self) -> Optional[nn.Module]:
"""Find the classifier layer in the model."""
# Common classifier layer names
classifier_names = ['classifier', 'score', 'fc', 'head']
for name in classifier_names:
if hasattr(self.model, name):
return getattr(self.model, name)
return None
def _init_classifier_params(self) -> None:
"""Initialize classifier parameters on target device."""
device = self.device_config.device
if device is None:
device = torch.device("cpu")
dtype = self.model_config.dtype
for name, param in list(self.classifier.named_parameters()):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data, dtype=dtype, device=device),
requires_grad=False,
)
setattr(self.classifier, name.split('.')[-1], new_param)
def pooling(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """
    Apply pooling and classification to hidden states.

    Args:
        hidden_states: Hidden states from the model [seq_len, hidden_size]
        pooling_metadata: Pooling metadata

    Returns:
        PoolerOutput with classification logits
    """
    # Base pooling first (CLS extraction, normalization/softmax per config).
    pooled = self.pooler(hidden_states, pooling_metadata)
    if pooled is None or self.classifier is None:
        return pooled
    # Rewrite each pooled vector in place with its classification logits.
    for pooled_output in pooled.outputs:
        if hasattr(pooled_output, 'data'):
            pooled_output.data = self.classifier(pooled_output.data)
    return pooled