# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for causal language models.

This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.

Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
- lm_head is created at the wrapper level (not inside self.model)
"""
from typing import TYPE_CHECKING, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class CausalMixin:
    """Mixin layering causal-LM machinery onto the Base model wrapper.

    Provides:
    - a ParallelLMHead created at the wrapper level (not inside self.model)
    - a LogitsProcessor for projecting hidden states to vocab logits
    - a Sampler for token sampling
    - ``compute_logits`` / ``sample`` to satisfy the
      VllmModelForTextGeneration protocol

    Because lm_head is a direct attribute of the wrapper, checkpoint names
    are remapped "model.lm_head." -> "lm_head.". For models with tied word
    embeddings, lm_head checkpoint weights are skipped and the head is tied
    to the input embeddings instead.

    Intended usage:

        class TransformersForCausalLM(CausalMixin, Base):
            ...
    """

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (expected to be Base).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # When embeddings are tied, lm_head weights are absent from (or
        # redundant in) the checkpoint: skip loading them and tie later.
        tied = getattr(self.text_config, "tie_word_embeddings", False)
        if tied:
            self.skip_prefixes.append("lm_head.")
            logger.info("Model has tied word embeddings, will tie lm_head weights")

        if self.pp_group.is_last_rank:
            # lm_head lives at the wrapper level, outside self.model, which
            # is what makes the "model.lm_head." -> "lm_head." mapping work.
            self.lm_head = ParallelLMHead(
                self.vocab_size,
                self.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

            if tied:
                embed_tokens = self.model.get_input_embeddings()
                if embed_tokens is not None:
                    self.lm_head = self.lm_head.tie_weights(embed_tokens)
                    logger.info("Tied lm_head weights with input embeddings")

            scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.vocab_size,
                logits_as_input=False,
                scale=scale,
            )
            logger.info(
                "CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
                self.vocab_size, self.hidden_size, scale)
        else:
            # Intermediate pipeline-parallel ranks never produce logits.
            self.lm_head = PPMissingLayer()
            self.logits_processor = None
            logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")

        # Sampler exists on every rank.
        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        Conforms to the VllmModelForTextGeneration protocol.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size]
            sampling_metadata: Sampling metadata

        Returns:
            Logits tensor, or None on non-last pipeline-parallel ranks.
        """
        if self.logits_processor is None:
            # Non-last PP rank: no logits to compute.
            return None
        # The v0.6.2 LogitsProcessor applies the lm_head projection itself
        # (via lm_head.linear_method.apply), so lm_head goes in as arg 0.
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits.

        Conforms to the VllmModelForTextGeneration protocol.

        Args:
            logits: Logits tensor
            sampling_metadata: Sampling metadata

        Returns:
            SamplerOutput with the sampled tokens.
        """
        return self.sampler(logits, sampling_metadata)