From 4d0da98b9e01d0b1ee4915d6dc92e97c91dbc833 Mon Sep 17 00:00:00 2001 From: Chranos <826995883@qq.com> Date: Thu, 5 Feb 2026 18:21:31 +0800 Subject: [PATCH] testing dynamic register --- .../models/transformers/base.py | 23 +++-- .../models/transformers/causal.py | 84 ++++++++++++------- 2 files changed, 69 insertions(+), 38 deletions(-) diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py index 5007e1f..4812a90 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py @@ -364,6 +364,7 @@ class Base(nn.Module): Replace modules with vLLM optimized versions. Uses tp_plan for tensor parallel style selection. + Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin. """ logger.info("Replacing modules with vLLM optimized versions...") replaced_count = 0 @@ -371,6 +372,9 @@ class Base(nn.Module): # Get tensor parallel plan tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {} + # Modules to skip replacement (handled at wrapper level) + skip_modules = {"lm_head", "score", "classifier"} + def _recursive_replace(module: nn.Module, prefix: str = ""): nonlocal replaced_count @@ -378,6 +382,11 @@ class Base(nn.Module): # Skip PPMissingLayer if isinstance(child, PPMissingLayer): continue + + # Skip modules that are handled at wrapper level + if name in skip_modules: + logger.debug("Skipping %s (handled at wrapper level)", name) + continue qual_name = maybe_prefix(prefix, name) new_module = None @@ -552,20 +561,22 @@ class Base(nn.Module): model_inputs["position_ids"] = positions # Forward through model + # Note: return_dict=False returns tuple, first element is last hidden state with torch.no_grad(): outputs = self.model( **model_inputs, use_cache=False, - return_dict=True, - output_hidden_states=True, + return_dict=False, attention_instances=self.attention_instances, ) - # Get hidden states - if outputs.hidden_states is not None: - hidden_states = outputs.hidden_states[-1] + # Get hidden states from model output + # For models using return_dict=False, outputs is a tuple + # outputs[0] is usually the last hidden state + if isinstance(outputs, tuple): + hidden_states = outputs[0] else: - hidden_states = outputs.logits + hidden_states = outputs # Remove batch dimension if hidden_states.dim() == 3 and hidden_states.size(0) == 1: diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py index 9d38d3f..33508fd 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/causal.py @@ -7,6 +7,7 @@ functionality (lm_head, compute_logits, sample) to the Base class. Following latest vLLM architecture: - TransformersForCausalLM = CausalMixin + Base +- lm_head is created at the wrapper level (not inside self.model) """ from typing import TYPE_CHECKING, Optional @@ -16,6 +17,8 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix from vllm.model_executor.sampling_metadata import SamplingMetadata if TYPE_CHECKING: @@ -29,11 +32,17 @@ class CausalMixin: Mixin class that adds causal language model functionality. This mixin provides: + - ParallelLMHead for language model head (created at wrapper level) - LogitsProcessor for logits computation - Sampler for token sampling - compute_logits method for VllmModelForTextGeneration protocol - sample method for VllmModelForTextGeneration protocol + Following latest vLLM architecture: + - lm_head is a direct attribute of TransformersForCausalLM (not inside self.model) + - hf_to_vllm_mapper maps "model.lm_head." -> "lm_head." to handle this + - For tied embeddings, lm_head weight loading is skipped and weights are tied + Should be used with Base class: class TransformersForCausalLM(CausalMixin, Base): ... """ @@ -46,21 +55,43 @@ class CausalMixin: tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False) if tie_word_embeddings: self.skip_prefixes.append("lm_head.") - logger.info("Model has tied word embeddings, skipping lm_head weight loading") + logger.info("Model has tied word embeddings, will tie lm_head weights") - # Setup logits processor - logit_scale = getattr(self.text_config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.vocab_size, - logits_as_input=False, - scale=logit_scale, - ) + # Create lm_head at wrapper level (following latest vLLM architecture) + # This is outside self.model, so weights map "model.lm_head." -> "lm_head." + if self.pp_group.is_last_rank: + self.lm_head = ParallelLMHead( + self.vocab_size, + self.hidden_size, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + # Tie weights if needed + if tie_word_embeddings: + input_embeddings = self.model.get_input_embeddings() + if input_embeddings is not None: + self.lm_head = self.lm_head.tie_weights(input_embeddings) + logger.info("Tied lm_head weights with input embeddings") + + # Setup logits processor + logit_scale = getattr(self.text_config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.vocab_size, + logits_as_input=False, + scale=logit_scale, + ) + + logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)", + self.vocab_size, self.hidden_size, logit_scale) + else: + # For non-last PP ranks, use PPMissingLayer + self.lm_head = PPMissingLayer() + self.logits_processor = None + logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)") # Setup sampler self.sampler = Sampler() - - logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)", - self.vocab_size, logit_scale) def compute_logits( self, @@ -79,28 +110,17 @@ class CausalMixin: Returns: Logits tensor or None """ - # Check if hidden_states are already logits (some models output logits directly) - if hidden_states.shape[-1] == self.vocab_size: - logits = hidden_states + if self.logits_processor is None: + # Non-last PP rank + return None + + # Apply lm_head (at wrapper level, not self.model.lm_head) + output = self.lm_head(hidden_states) + # Handle tuple output from vLLM Linear layers (output, bias) + if isinstance(output, tuple): + logits = output[0] else: - # Apply the LM head - lm_head = getattr(self.model, "lm_head", None) - if lm_head is None: - # Some models use different names - lm_head = getattr(self.model, "embed_out", None) - if lm_head is None: - lm_head = getattr(self.model, "output", None) - - if lm_head is not None: - output = lm_head(hidden_states) - # Handle tuple output from vLLM Linear layers (output, bias) - if isinstance(output, tuple): - logits = output[0] - else: - logits = output - else: - logger.warning("Could not find lm_head, using hidden_states as logits") - logits = hidden_states + logits = output return self.logits_processor(None, logits, sampling_metadata)