Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py
2026-02-06 15:05:48 +08:00

126 lines
4.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend mixin for causal language models.
This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.
Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
"""
from typing import TYPE_CHECKING, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
if TYPE_CHECKING:
    # VllmConfig is only referenced as the string annotation "VllmConfig" on
    # CausalMixin.__init__, so it is imported for type checkers only and is
    # not imported at runtime (TYPE_CHECKING is False when the module runs).
    from vllm.config import VllmConfig

logger = init_logger(__name__)
class CausalMixin:
    """
    Mixin class that adds causal language model functionality.

    This mixin provides:
    - LogitsProcessor for post-processing the logits computed here
      (scaling and per-request logits processors)
    - Sampler for token sampling
    - compute_logits method for VllmModelForTextGeneration protocol
    - sample method for VllmModelForTextGeneration protocol

    Should be used with Base class:
        class TransformersForCausalLM(CausalMixin, Base): ...

    The next class in the MRO is expected to provide ``text_config``,
    ``vocab_size``, ``skip_prefixes`` and ``model`` (the wrapped Transformers
    module); this mixin reads them but never creates them.
    """

    # Attribute names probed, in order, to locate the output projection on
    # the wrapped Transformers model (different architectures use different
    # names).
    _LM_HEAD_ATTR_NAMES = ("lm_head", "embed_out", "output")

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        """Initialize the base model, then set up logits processing and sampling.

        Args:
            vllm_config: Engine configuration, forwarded unchanged to the
                next class in the MRO.
            prefix: Module name prefix, forwarded unchanged to the next class
                in the MRO.
        """
        # Call next class in MRO (should be Base); it is expected to set up
        # text_config, vocab_size, skip_prefixes and model used below.
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # Handle tied word embeddings - skip loading lm_head weights
        tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
        if tie_word_embeddings:
            self.skip_prefixes.append("lm_head.")
            logger.info("Model has tied word embeddings, skipping lm_head weight loading")

        # A config may define logit_scale but leave it as None; treat that the
        # same as "not set" instead of handing scale=None to the processor.
        logit_scale = getattr(self.text_config, "logit_scale", None)
        if logit_scale is None:
            logit_scale = 1.0

        # compute_logits() applies the LM head itself and passes the processor
        # lm_head=None together with ready-made logits, so the processor must
        # be configured to take its input as logits (logits_as_input=True)
        # rather than trying to project it through the lm_head argument.
        self.logits_processor = LogitsProcessor(
            self.vocab_size,
            logits_as_input=True,
            scale=logit_scale,
        )

        # Setup sampler
        self.sampler = Sampler()

        # compute_logits runs every step; only warn once about a missing head.
        self._warned_missing_lm_head = False

        logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)",
                    self.vocab_size, logit_scale)

    def _find_lm_head(self) -> Optional[torch.nn.Module]:
        """Return the first non-None output projection on ``self.model``, or None."""
        for attr_name in self._LM_HEAD_ATTR_NAMES:
            lm_head = getattr(self.model, attr_name, None)
            if lm_head is not None:
                return lm_head
        return None

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """
        Compute logits from hidden states.

        This method conforms to the VllmModelForTextGeneration protocol.

        Args:
            hidden_states: Hidden states from the model; assumed to be
                [num_tokens, hidden_size] (TODO confirm against Base.forward).
                If the last dimension already equals vocab_size, the input is
                treated as logits and no LM head is applied.
            sampling_metadata: Sampling metadata. When it carries
                ``selected_token_indices``, only those rows are kept.

        Returns:
            Logits for the selected positions after LogitsProcessor
            post-processing, or None if the processor returns None.
        """
        # With logits_as_input=True, vLLM's LogitsProcessor does not prune to
        # the positions that will be sampled, so do it here. Pruning before the
        # LM head also avoids projecting every token onto the vocabulary.
        selected_token_indices = getattr(sampling_metadata, "selected_token_indices", None)
        if selected_token_indices is not None:
            hidden_states = hidden_states.index_select(0, selected_token_indices)

        if hidden_states.shape[-1] == self.vocab_size:
            # Some models output logits directly.
            logits = hidden_states
        else:
            lm_head = self._find_lm_head()
            if lm_head is not None:
                output = lm_head(hidden_states)
                # vLLM Linear layers return (output, bias); plain modules
                # return the tensor itself.
                logits = output[0] if isinstance(output, tuple) else output
            else:
                if not self._warned_missing_lm_head:
                    logger.warning("Could not find lm_head, using hidden_states as logits")
                    self._warned_missing_lm_head = True
                logits = hidden_states

        return self.logits_processor(None, logits, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Sample tokens from logits.

        This method conforms to the VllmModelForTextGeneration protocol.

        Args:
            logits: Logits tensor as returned by compute_logits.
            sampling_metadata: Sampling metadata.

        Returns:
            Whatever ``self.sampler`` returns (a SamplerOutput with the
            sampled tokens).
        """
        return self.sampler(logits, sampling_metadata)