Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py
Chranos ebdc6fed03 fix: pass lm_head to LogitsProcessor instead of calling forward()
In vLLM v0.6.2, ParallelLMHead.forward() raises RuntimeError since
its weights should be used through LogitsProcessor.linear_method.apply().
Pass lm_head as first arg to LogitsProcessor which handles the
hidden_states -> logits projection internally.
2026-02-06 14:21:14 +08:00

143 lines
5.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for causal language models.
This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.
Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
- lm_head is created at the wrapper level (not inside self.model)
"""
from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
class CausalMixin:
    """Mixin adding causal-language-model functionality to the Base class.

    Provides:
    - a ``ParallelLMHead`` created at the wrapper level (outside
      ``self.model``), matching the latest vLLM architecture where checkpoint
      weights map ``"model.lm_head." -> "lm_head."``;
    - a ``LogitsProcessor`` for the hidden_states -> logits projection;
    - a ``Sampler`` plus the ``compute_logits``/``sample`` methods required
      by the ``VllmModelForTextGeneration`` protocol.

    When the config declares tied word embeddings, loading of ``lm_head.``
    weights is skipped and the head is tied to the input embeddings instead.

    Intended usage, with Base next in the MRO::

        class TransformersForCausalLM(CausalMixin, Base): ...
    """

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (expected to be Base); it is
        # responsible for self.text_config, self.pp_group, self.model, etc.
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # With tied embeddings the checkpoint carries no separate lm_head
        # weight: skip it during loading and tie after construction instead.
        tied_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
        if tied_embeddings:
            self.skip_prefixes.append("lm_head.")
            logger.info("Model has tied word embeddings, will tie lm_head weights")

        if not self.pp_group.is_last_rank:
            # Intermediate pipeline-parallel ranks never produce logits.
            self.lm_head = PPMissingLayer()
            self.logits_processor = None
            logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")
        else:
            # lm_head lives at the wrapper level (not inside self.model), so
            # weights map "model.lm_head." -> "lm_head." (latest vLLM layout).
            self.lm_head = ParallelLMHead(
                self.vocab_size,
                self.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if tied_embeddings:
                embeddings = self.model.get_input_embeddings()
                if embeddings is not None:
                    self.lm_head = self.lm_head.tie_weights(embeddings)
                    logger.info("Tied lm_head weights with input embeddings")

            scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.vocab_size,
                logits_as_input=False,
                scale=scale,
            )
            logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
                        self.vocab_size, self.hidden_size, scale)

        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        Conforms to the ``VllmModelForTextGeneration`` protocol.

        Args:
            hidden_states: Model output of shape ``[seq_len, hidden_size]``.
            sampling_metadata: Sampling metadata.

        Returns:
            The logits tensor, or ``None`` on non-last PP ranks.
        """
        if self.logits_processor is None:
            # Non-last pipeline-parallel rank: no logits to compute here.
            return None
        # In v0.6.2, LogitsProcessor applies the lm_head projection internally
        # via lm_head.linear_method.apply(); lm_head is passed as the first arg.
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits.

        Conforms to the ``VllmModelForTextGeneration`` protocol.

        Args:
            logits: Logits tensor.
            sampling_metadata: Sampling metadata.

        Returns:
            ``SamplerOutput`` with the sampled tokens.
        """
        return self.sampler(logits, sampling_metadata)