testing dynamic register
This commit is contained in:
@@ -364,6 +364,7 @@ class Base(nn.Module):
|
||||
Replace modules with vLLM optimized versions.
|
||||
|
||||
Uses tp_plan for tensor parallel style selection.
|
||||
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
|
||||
"""
|
||||
logger.info("Replacing modules with vLLM optimized versions...")
|
||||
replaced_count = 0
|
||||
@@ -371,6 +372,9 @@ class Base(nn.Module):
|
||||
# Get tensor parallel plan
|
||||
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
|
||||
|
||||
# Modules to skip replacement (handled at wrapper level)
|
||||
skip_modules = {"lm_head", "score", "classifier"}
|
||||
|
||||
def _recursive_replace(module: nn.Module, prefix: str = ""):
|
||||
nonlocal replaced_count
|
||||
|
||||
@@ -378,6 +382,11 @@ class Base(nn.Module):
|
||||
# Skip PPMissingLayer
|
||||
if isinstance(child, PPMissingLayer):
|
||||
continue
|
||||
|
||||
# Skip modules that are handled at wrapper level
|
||||
if name in skip_modules:
|
||||
logger.debug("Skipping %s (handled at wrapper level)", name)
|
||||
continue
|
||||
|
||||
qual_name = maybe_prefix(prefix, name)
|
||||
new_module = None
|
||||
@@ -552,20 +561,22 @@ class Base(nn.Module):
|
||||
model_inputs["position_ids"] = positions
|
||||
|
||||
# Forward through model
|
||||
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||
with torch.no_grad():
|
||||
outputs = self.model(
|
||||
**model_inputs,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
output_hidden_states=True,
|
||||
return_dict=False,
|
||||
attention_instances=self.attention_instances,
|
||||
)
|
||||
|
||||
# Get hidden states
|
||||
if outputs.hidden_states is not None:
|
||||
hidden_states = outputs.hidden_states[-1]
|
||||
# Get hidden states from model output
|
||||
# For models using return_dict=False, outputs is a tuple
|
||||
# outputs[0] is usually the last hidden state
|
||||
if isinstance(outputs, tuple):
|
||||
hidden_states = outputs[0]
|
||||
else:
|
||||
hidden_states = outputs.logits
|
||||
hidden_states = outputs
|
||||
|
||||
# Remove batch dimension
|
||||
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
|
||||
|
||||
@@ -7,6 +7,7 @@ functionality (lm_head, compute_logits, sample) to the Base class.
|
||||
|
||||
Following latest vLLM architecture:
|
||||
- TransformersForCausalLM = CausalMixin + Base
|
||||
- lm_head is created at the wrapper level (not inside self.model)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
@@ -16,6 +17,8 @@ import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -29,11 +32,17 @@ class CausalMixin:
|
||||
Mixin class that adds causal language model functionality.
|
||||
|
||||
This mixin provides:
|
||||
- ParallelLMHead for language model head (created at wrapper level)
|
||||
- LogitsProcessor for logits computation
|
||||
- Sampler for token sampling
|
||||
- compute_logits method for VllmModelForTextGeneration protocol
|
||||
- sample method for VllmModelForTextGeneration protocol
|
||||
|
||||
Following latest vLLM architecture:
|
||||
- lm_head is a direct attribute of TransformersForCausalLM (not inside self.model)
|
||||
- hf_to_vllm_mapper maps "model.lm_head." -> "lm_head." to handle this
|
||||
- For tied embeddings, lm_head weight loading is skipped and weights are tied
|
||||
|
||||
Should be used with Base class:
|
||||
class TransformersForCausalLM(CausalMixin, Base): ...
|
||||
"""
|
||||
@@ -46,21 +55,43 @@ class CausalMixin:
|
||||
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
|
||||
if tie_word_embeddings:
|
||||
self.skip_prefixes.append("lm_head.")
|
||||
logger.info("Model has tied word embeddings, skipping lm_head weight loading")
|
||||
logger.info("Model has tied word embeddings, will tie lm_head weights")
|
||||
|
||||
# Setup logits processor
|
||||
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.vocab_size,
|
||||
logits_as_input=False,
|
||||
scale=logit_scale,
|
||||
)
|
||||
# Create lm_head at wrapper level (following latest vLLM architecture)
|
||||
# This is outside self.model, so weights map "model.lm_head." -> "lm_head."
|
||||
if self.pp_group.is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.vocab_size,
|
||||
self.hidden_size,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
# Tie weights if needed
|
||||
if tie_word_embeddings:
|
||||
input_embeddings = self.model.get_input_embeddings()
|
||||
if input_embeddings is not None:
|
||||
self.lm_head = self.lm_head.tie_weights(input_embeddings)
|
||||
logger.info("Tied lm_head weights with input embeddings")
|
||||
|
||||
# Setup logits processor
|
||||
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.vocab_size,
|
||||
logits_as_input=False,
|
||||
scale=logit_scale,
|
||||
)
|
||||
|
||||
logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
|
||||
self.vocab_size, self.hidden_size, logit_scale)
|
||||
else:
|
||||
# For non-last PP ranks, use PPMissingLayer
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.logits_processor = None
|
||||
logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")
|
||||
|
||||
# Setup sampler
|
||||
self.sampler = Sampler()
|
||||
|
||||
logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)",
|
||||
self.vocab_size, logit_scale)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@@ -79,28 +110,17 @@ class CausalMixin:
|
||||
Returns:
|
||||
Logits tensor or None
|
||||
"""
|
||||
# Check if hidden_states are already logits (some models output logits directly)
|
||||
if hidden_states.shape[-1] == self.vocab_size:
|
||||
logits = hidden_states
|
||||
if self.logits_processor is None:
|
||||
# Non-last PP rank
|
||||
return None
|
||||
|
||||
# Apply lm_head (at wrapper level, not self.model.lm_head)
|
||||
output = self.lm_head(hidden_states)
|
||||
# Handle tuple output from vLLM Linear layers (output, bias)
|
||||
if isinstance(output, tuple):
|
||||
logits = output[0]
|
||||
else:
|
||||
# Apply the LM head
|
||||
lm_head = getattr(self.model, "lm_head", None)
|
||||
if lm_head is None:
|
||||
# Some models use different names
|
||||
lm_head = getattr(self.model, "embed_out", None)
|
||||
if lm_head is None:
|
||||
lm_head = getattr(self.model, "output", None)
|
||||
|
||||
if lm_head is not None:
|
||||
output = lm_head(hidden_states)
|
||||
# Handle tuple output from vLLM Linear layers (output, bias)
|
||||
if isinstance(output, tuple):
|
||||
logits = output[0]
|
||||
else:
|
||||
logits = output
|
||||
else:
|
||||
logger.warning("Could not find lm_head, using hidden_states as logits")
|
||||
logits = hidden_states
|
||||
logits = output
|
||||
|
||||
return self.logits_processor(None, logits, sampling_metadata)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user