# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend for causal language models.

This module provides a wrapper class that enables vLLM to use any HuggingFace
causal language model, including custom models that define their
implementation via `auto_map` in config.json.

The key insight is that we use HuggingFace's AutoModelForCausalLM to build the
actual model, then wrap it with the vLLM interface (compute_logits, sample,
etc).
"""
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from vllm.attention import AttentionMetadata

logger = init_logger(__name__)


class TransformersForCausalLM(nn.Module):
    """Adapt any HuggingFace causal language model to the vLLM interface.

    This class provides:
    1. forward() - processes input through the model
    2. compute_logits() - computes output logits
    3. sample() - samples tokens from logits
    4. load_weights() - loads model weights

    The actual HuggingFace model is built with
    ``AutoModelForCausalLM.from_config`` and stored in ``self.model``; its
    weights are filled in afterwards through ``load_weights``.

    Interface compliance:
    - Implements VllmModel protocol (vllm_config init, forward with
      required args)
    - Implements VllmModelForTextGeneration protocol (compute_logits, sample)
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        """Build the wrapped HuggingFace model and the vLLM output layers.

        Args:
            vllm_config: vLLM configuration; its ``model_config``,
                ``cache_config`` and ``quant_config`` are stored on the
                instance.
            prefix: Name prefix for this module; stored but not otherwise
                used by this class.
        """
        super().__init__()

        config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config.hf_config
        self.model_config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.prefix = prefix

        logger.info("Using Transformers modeling backend for %s",
                    config.hf_config.architectures)

        # Load the actual HuggingFace model
        self._load_hf_model()

        # Setup logits processor and sampler
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size,
            logits_as_input=False,
        )
        self.sampler = Sampler()

    def _load_hf_model(self) -> None:
        """Instantiate the HuggingFace model from its config (no weights)."""
        # Imported lazily so that importing this module does not require
        # pulling in transformers' model classes.
        from transformers import AutoModelForCausalLM

        # We build from the config only - the weights are loaded separately
        # through load_weights().
        logger.info("Loading HuggingFace model from config...")
        self.model: "PreTrainedModel" = AutoModelForCausalLM.from_config(
            self.config,
            torch_dtype=self.model_config.dtype,
            trust_remote_code=self.model_config.trust_remote_code,
        )

        # Inference only: switch to eval mode and freeze every parameter.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _ensure_batched(tensor: torch.Tensor,
                        unbatched_ndim: int) -> torch.Tensor:
        """Prepend a size-1 batch dim when ``tensor`` has ``unbatched_ndim``
        dimensions; otherwise return it unchanged."""
        if tensor.dim() == unbatched_ndim:
            return tensor.unsqueeze(0)
        return tensor

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped HuggingFace model on one batch of tokens.

        Args:
            input_ids: Token IDs. A 1-D tensor gets a leading batch
                dimension added before it is passed to the model.
            positions: Position IDs, batched the same way as ``input_ids``.
                Skipped when ``None``.
            kv_caches: KV cache tensors (not used in basic HF forward).
            attn_metadata: Attention metadata (not used in basic HF forward).
            intermediate_tensors: Accepted for interface compatibility;
                not used.
            inputs_embeds: Optional precomputed embeddings. When given they
                replace ``input_ids``; a 2-D tensor gets a leading batch
                dimension so it lines up with the batched position IDs.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The last entry of ``outputs.hidden_states`` when the model
            returned hidden states, otherwise ``outputs.logits``. A 3-D
            result has its leading dimension squeezed (a no-op unless that
            dimension has size 1).

        Note: This is a simplified implementation that does not use vLLM's
        optimized attention mechanisms. For production use with KV caching,
        a more sophisticated implementation would be needed.
        """
        # For simplicity, we use HuggingFace's native forward.
        # This won't have vLLM's optimizations but will work.
        if inputs_embeds is not None:
            model_inputs = {
                "inputs_embeds": self._ensure_batched(inputs_embeds, 2),
            }
        else:
            model_inputs = {"input_ids": self._ensure_batched(input_ids, 1)}

        if positions is not None:
            model_inputs["position_ids"] = self._ensure_batched(positions, 1)

        with torch.no_grad():
            outputs = self.model(
                **model_inputs,
                use_cache=False,
                return_dict=True,
            )

        # Prefer the last hidden state when the model produced hidden states;
        # otherwise hand back the logits and let compute_logits() recognise
        # them by their trailing dimension.
        all_hidden_states = getattr(outputs, "hidden_states", None)
        if all_hidden_states is not None:
            hidden_states = all_hidden_states[-1]
        else:
            hidden_states = outputs.logits

        # Drop the batch dimension that was added on the way in.
        if hidden_states.dim() == 3:
            hidden_states = hidden_states.squeeze(0)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from the output of ``forward``.

        This method conforms to the VllmModelForTextGeneration protocol.

        Args:
            hidden_states: Output of ``forward``. If its last dimension
                equals ``vocab_size`` it is treated as logits already;
                otherwise ``self.model.lm_head`` is applied first.
            sampling_metadata: Passed through to the logits processor.

        Returns:
            Whatever the logits processor returns for these logits.
        """
        # NOTE: the vocab-size check is a heuristic; it cannot tell hidden
        # states from logits if hidden_size happens to equal vocab_size.
        if hidden_states.shape[-1] == self.config.vocab_size:
            logits = hidden_states
        else:
            # Apply the LM head
            logits = self.model.lm_head(hidden_states)

        # logits_as_input=False was set in __init__; the LM-head argument is
        # passed as None because the projection is already applied above.
        return self.logits_processor(None, logits, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits.

        This method conforms to the VllmModelForTextGeneration protocol.
        """
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        """Load weights into the wrapped HuggingFace model.

        Each checkpoint entry is matched against
        ``self.model.named_parameters()`` first by its exact name and then
        with a ``"model."`` prefix. Entries that match neither are skipped
        (logged at debug level) rather than raising.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The checkpoint names (as given in ``weights``) that were loaded.
        """
        loaded_params: Set[str] = set()
        model_params = dict(self.model.named_parameters())

        for name, loaded_weight in weights:
            # Exact name first, then the common "model." prefix.
            target = name if name in model_params else f"model.{name}"
            param = model_params.get(target)
            if param is None:
                logger.debug(
                    "Skipping weight %s: no matching parameter in the "
                    "HuggingFace model", name)
                continue

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params


def is_backend_compatible() -> bool:
    """Check if the current model is compatible with the Transformers backend.

    This is a simplified check that always returns ``True`` - in practice,
    compatibility depends on whether the model follows standard HuggingFace
    conventions.
    """
    return True