# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for causal language models.

This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.

Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
"""

from typing import TYPE_CHECKING, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class CausalMixin:
    """Mixin class that adds causal language model functionality.

    This mixin provides:
    - LogitsProcessor for logits computation
    - Sampler for token sampling
    - compute_logits method for VllmModelForTextGeneration protocol
    - sample method for VllmModelForTextGeneration protocol

    Should be used with Base class:
        class TransformersForCausalLM(CausalMixin, Base): ...
    """

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        """Initialize sampling components on top of the Base model.

        Args:
            vllm_config: Global vLLM configuration, forwarded to Base.
            prefix: Weight-name prefix for this module, forwarded to Base.
        """
        # Call next class in MRO (should be Base, which sets up
        # self.text_config, self.vocab_size, self.skip_prefixes, self.model).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Handle tied word embeddings - skip loading lm_head weights, since
        # they are shared with the input embedding table.
        tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings",
                                      False)
        if tie_word_embeddings:
            self.skip_prefixes.append("lm_head.")
            logger.info("Model has tied word embeddings, "
                        "skipping lm_head weight loading")

        # Setup logits processor.
        #
        # BUG FIX: compute_logits() applies the LM head itself and then calls
        # self.logits_processor(None, logits, sampling_metadata). With
        # logits_as_input=False, LogitsProcessor.forward would attempt to
        # apply the lm_head argument (None here) and crash with an
        # AttributeError. logits_as_input=True makes the processor treat its
        # input tensor as ready-made logits and only apply scaling/gathering,
        # which matches how compute_logits() invokes it.
        logit_scale = getattr(self.text_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.vocab_size,
            logits_as_input=True,
            scale=logit_scale,
        )

        # Setup sampler.
        self.sampler = Sampler()

        logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)",
                    self.vocab_size, logit_scale)

    def _find_lm_head(self):
        """Locate the output projection module on the wrapped model.

        Checks the common attribute names used by Transformers models, in
        order: ``lm_head``, ``embed_out``, ``output``.

        Returns:
            The projection module, or None if none of the names exist.
        """
        for attr in ("lm_head", "embed_out", "output"):
            head = getattr(self.model, attr, None)
            if head is not None:
                return head
        return None

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states.

        This method conforms to the VllmModelForTextGeneration protocol.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            sampling_metadata: Sampling metadata

        Returns:
            Logits tensor or None
        """
        # Check if hidden_states are already logits (some models output
        # logits directly).
        if hidden_states.shape[-1] == self.vocab_size:
            logits = hidden_states
        else:
            # Apply the LM head (looked up under its common aliases).
            lm_head = self._find_lm_head()
            if lm_head is not None:
                output = lm_head(hidden_states)
                # Handle tuple output from vLLM Linear layers (output, bias).
                logits = output[0] if isinstance(output, tuple) else output
            else:
                # Best-effort fallback; the shapes will be wrong for
                # sampling, so warn loudly rather than fail silently.
                logger.warning(
                    "Could not find lm_head, using hidden_states as logits")
                logits = hidden_states

        # logits_processor was built with logits_as_input=True, so the
        # (unused) lm_head argument is passed as None.
        return self.logits_processor(None, logits, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample tokens from logits.

        This method conforms to the VllmModelForTextGeneration protocol.

        Args:
            logits: Logits tensor
            sampling_metadata: Sampling metadata

        Returns:
            SamplerOutput with sampled tokens
        """
        return self.sampler(logits, sampling_metadata)