Testing dynamic registration

This commit is contained in:
Chranos
2026-02-05 18:21:31 +08:00
parent 05605419e3
commit 4d0da98b9e
2 changed files with 69 additions and 38 deletions

View File

@@ -364,6 +364,7 @@ class Base(nn.Module):
Replace modules with vLLM optimized versions.
Uses tp_plan for tensor parallel style selection.
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
"""
logger.info("Replacing modules with vLLM optimized versions...")
replaced_count = 0
@@ -371,6 +372,9 @@ class Base(nn.Module):
# Get tensor parallel plan
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
# Modules to skip replacement (handled at wrapper level)
skip_modules = {"lm_head", "score", "classifier"}
def _recursive_replace(module: nn.Module, prefix: str = ""):
nonlocal replaced_count
@@ -378,6 +382,11 @@ class Base(nn.Module):
# Skip PPMissingLayer
if isinstance(child, PPMissingLayer):
continue
# Skip modules that are handled at wrapper level
if name in skip_modules:
logger.debug("Skipping %s (handled at wrapper level)", name)
continue
qual_name = maybe_prefix(prefix, name)
new_module = None
@@ -552,20 +561,22 @@ class Base(nn.Module):
model_inputs["position_ids"] = positions
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
with torch.no_grad():
outputs = self.model(
**model_inputs,
use_cache=False,
return_dict=True,
output_hidden_states=True,
return_dict=False,
attention_instances=self.attention_instances,
)
# Get hidden states
if outputs.hidden_states is not None:
hidden_states = outputs.hidden_states[-1]
# Get hidden states from model output
# For models using return_dict=False, outputs is a tuple
# outputs[0] is usually the last hidden state
if isinstance(outputs, tuple):
hidden_states = outputs[0]
else:
hidden_states = outputs.logits
hidden_states = outputs
# Remove batch dimension
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:

View File

@@ -7,6 +7,7 @@ functionality (lm_head, compute_logits, sample) to the Base class.
Following latest vLLM architecture:
- TransformersForCausalLM = CausalMixin + Base
- lm_head is created at the wrapper level (not inside self.model)
"""
from typing import TYPE_CHECKING, Optional
@@ -16,6 +17,8 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
if TYPE_CHECKING:
@@ -29,11 +32,17 @@ class CausalMixin:
Mixin class that adds causal language model functionality.
This mixin provides:
- ParallelLMHead for language model head (created at wrapper level)
- LogitsProcessor for logits computation
- Sampler for token sampling
- compute_logits method for VllmModelForTextGeneration protocol
- sample method for VllmModelForTextGeneration protocol
Following latest vLLM architecture:
- lm_head is a direct attribute of TransformersForCausalLM (not inside self.model)
- hf_to_vllm_mapper maps "model.lm_head." -> "lm_head." to handle this
- For tied embeddings, lm_head weight loading is skipped and weights are tied
Should be used with Base class:
class TransformersForCausalLM(CausalMixin, Base): ...
"""
@@ -46,21 +55,43 @@ class CausalMixin:
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
if tie_word_embeddings:
self.skip_prefixes.append("lm_head.")
logger.info("Model has tied word embeddings, skipping lm_head weight loading")
logger.info("Model has tied word embeddings, will tie lm_head weights")
# Setup logits processor
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.vocab_size,
logits_as_input=False,
scale=logit_scale,
)
# Create lm_head at wrapper level (following latest vLLM architecture)
# This is outside self.model, so weights map "model.lm_head." -> "lm_head."
if self.pp_group.is_last_rank:
self.lm_head = ParallelLMHead(
self.vocab_size,
self.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
# Tie weights if needed
if tie_word_embeddings:
input_embeddings = self.model.get_input_embeddings()
if input_embeddings is not None:
self.lm_head = self.lm_head.tie_weights(input_embeddings)
logger.info("Tied lm_head weights with input embeddings")
# Setup logits processor
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.vocab_size,
logits_as_input=False,
scale=logit_scale,
)
logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
self.vocab_size, self.hidden_size, logit_scale)
else:
# For non-last PP ranks, use PPMissingLayer
self.lm_head = PPMissingLayer()
self.logits_processor = None
logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")
# Setup sampler
self.sampler = Sampler()
logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)",
self.vocab_size, logit_scale)
def compute_logits(
self,
@@ -79,28 +110,17 @@ class CausalMixin:
Returns:
Logits tensor or None
"""
# Check if hidden_states are already logits (some models output logits directly)
if hidden_states.shape[-1] == self.vocab_size:
logits = hidden_states
if self.logits_processor is None:
# Non-last PP rank
return None
# Apply lm_head (at wrapper level, not self.model.lm_head)
output = self.lm_head(hidden_states)
# Handle tuple output from vLLM Linear layers (output, bias)
if isinstance(output, tuple):
logits = output[0]
else:
# Apply the LM head
lm_head = getattr(self.model, "lm_head", None)
if lm_head is None:
# Some models use different names
lm_head = getattr(self.model, "embed_out", None)
if lm_head is None:
lm_head = getattr(self.model, "output", None)
if lm_head is not None:
output = lm_head(hidden_states)
# Handle tuple output from vLLM Linear layers (output, bias)
if isinstance(output, tuple):
logits = output[0]
else:
logits = output
else:
logger.warning("Could not find lm_head, using hidden_states as logits")
logits = hidden_states
logits = output
return self.logits_processor(None, logits, sampling_metadata)