In vLLM v0.6.2, ParallelLMHead.forward() raises RuntimeError because its weights are meant to be consumed through LogitsProcessor via linear_method.apply(). Pass lm_head as the first argument to LogitsProcessor, which performs the hidden_states -> logits projection internally.
143 lines · 5.3 KiB · Python
# SPDX-License-Identifier: Apache-2.0

# Copyright 2024 The vLLM team.

"""Transformers modeling backend mixin for causal language models.

This module provides CausalMixin that adds causal language model specific
functionality (lm_head, compute_logits, sample) to the Base class.

Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
- lm_head is created at the wrapper level (not inside self.model)
"""

from typing import TYPE_CHECKING, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class CausalMixin:
    """Mixin adding causal-language-model behaviour on top of the Base class.

    Provides:
    - ParallelLMHead as the language-model head (created at the wrapper level)
    - LogitsProcessor for turning hidden states into logits
    - Sampler for token sampling
    - compute_logits / sample methods satisfying the
      VllmModelForTextGeneration protocol

    Architecture notes (following latest vLLM):
    - lm_head is a direct attribute of TransformersForCausalLM (not inside
      self.model); hf_to_vllm_mapper maps "model.lm_head." -> "lm_head."
    - for tied embeddings, lm_head weight loading is skipped and the weights
      are tied to the input embeddings instead

    Intended usage, together with the Base class:
        class TransformersForCausalLM(CausalMixin, Base): ...
    """

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = "") -> None:
        # Delegate to the next class in the MRO (expected to be Base).
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # With tied word embeddings the checkpoint carries no separate
        # lm_head weights, so skip them during loading and tie afterwards.
        tied = getattr(self.text_config, "tie_word_embeddings", False)
        if tied:
            self.skip_prefixes.append("lm_head.")
            logger.info("Model has tied word embeddings, will tie lm_head weights")

        if not self.pp_group.is_last_rank:
            # Intermediate pipeline ranks never produce logits.
            self.lm_head = PPMissingLayer()
            self.logits_processor = None
            logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")
        else:
            # lm_head lives at the wrapper level (outside self.model), so
            # checkpoint weights map "model.lm_head." -> "lm_head.".
            self.lm_head = ParallelLMHead(
                self.vocab_size,
                self.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

            # Tie lm_head to the input embeddings when required.
            if tied:
                embed_tokens = self.model.get_input_embeddings()
                if embed_tokens is not None:
                    self.lm_head = self.lm_head.tie_weights(embed_tokens)
                    logger.info("Tied lm_head weights with input embeddings")

            # Logits processor handles scaling and the lm_head projection.
            scale = getattr(self.text_config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.vocab_size,
                logits_as_input=False,
                scale=scale,
            )

            logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
                        self.vocab_size, self.hidden_size, scale)

        # Sampler is set up on every pipeline rank.
        # NOTE(review): reconstructed placement — source indentation was
        # lost; both branches log "initialized" before this line, so it is
        # taken to be outside the rank branch. Confirm against upstream.
        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits.

        Conforms to the VllmModelForTextGeneration protocol.

        Args:
            hidden_states: Hidden states from the model [seq_len, hidden_size].
            sampling_metadata: Sampling metadata.

        Returns:
            Logits tensor, or None on non-last pipeline ranks.
        """
        if self.logits_processor is not None:
            # In v0.6.2, LogitsProcessor applies the lm_head projection
            # itself (via lm_head.linear_method.apply()), so lm_head is
            # passed as the first argument rather than called directly.
            return self.logits_processor(self.lm_head, hidden_states,
                                         sampling_metadata)
        # Non-last PP rank: nothing to compute here.
        return None

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from logits.

        Conforms to the VllmModelForTextGeneration protocol.

        Args:
            logits: Logits tensor.
            sampling_metadata: Sampling metadata.

        Returns:
            SamplerOutput with the sampled tokens.
        """
        return self.sampler(logits, sampling_metadata)