testing dynamic register
This commit is contained in:
@@ -364,6 +364,7 @@ class Base(nn.Module):
|
|||||||
Replace modules with vLLM optimized versions.
|
Replace modules with vLLM optimized versions.
|
||||||
|
|
||||||
Uses tp_plan for tensor parallel style selection.
|
Uses tp_plan for tensor parallel style selection.
|
||||||
|
Note: lm_head is NOT replaced here - it's created at wrapper level by CausalMixin.
|
||||||
"""
|
"""
|
||||||
logger.info("Replacing modules with vLLM optimized versions...")
|
logger.info("Replacing modules with vLLM optimized versions...")
|
||||||
replaced_count = 0
|
replaced_count = 0
|
||||||
@@ -371,6 +372,9 @@ class Base(nn.Module):
|
|||||||
# Get tensor parallel plan
|
# Get tensor parallel plan
|
||||||
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
|
tp_plan = self._get_tp_plan() if self.tp_group.world_size > 1 else {}
|
||||||
|
|
||||||
|
# Modules to skip replacement (handled at wrapper level)
|
||||||
|
skip_modules = {"lm_head", "score", "classifier"}
|
||||||
|
|
||||||
def _recursive_replace(module: nn.Module, prefix: str = ""):
|
def _recursive_replace(module: nn.Module, prefix: str = ""):
|
||||||
nonlocal replaced_count
|
nonlocal replaced_count
|
||||||
|
|
||||||
@@ -378,6 +382,11 @@ class Base(nn.Module):
|
|||||||
# Skip PPMissingLayer
|
# Skip PPMissingLayer
|
||||||
if isinstance(child, PPMissingLayer):
|
if isinstance(child, PPMissingLayer):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip modules that are handled at wrapper level
|
||||||
|
if name in skip_modules:
|
||||||
|
logger.debug("Skipping %s (handled at wrapper level)", name)
|
||||||
|
continue
|
||||||
|
|
||||||
qual_name = maybe_prefix(prefix, name)
|
qual_name = maybe_prefix(prefix, name)
|
||||||
new_module = None
|
new_module = None
|
||||||
@@ -552,20 +561,22 @@ class Base(nn.Module):
|
|||||||
model_inputs["position_ids"] = positions
|
model_inputs["position_ids"] = positions
|
||||||
|
|
||||||
# Forward through model
|
# Forward through model
|
||||||
|
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
return_dict=True,
|
return_dict=False,
|
||||||
output_hidden_states=True,
|
|
||||||
attention_instances=self.attention_instances,
|
attention_instances=self.attention_instances,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get hidden states
|
# Get hidden states from model output
|
||||||
if outputs.hidden_states is not None:
|
# For models using return_dict=False, outputs is a tuple
|
||||||
hidden_states = outputs.hidden_states[-1]
|
# outputs[0] is usually the last hidden state
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
hidden_states = outputs[0]
|
||||||
else:
|
else:
|
||||||
hidden_states = outputs.logits
|
hidden_states = outputs
|
||||||
|
|
||||||
# Remove batch dimension
|
# Remove batch dimension
|
||||||
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
|
if hidden_states.dim() == 3 and hidden_states.size(0) == 1:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ functionality (lm_head, compute_logits, sample) to the Base class.
|
|||||||
|
|
||||||
Following latest vLLM architecture:
|
Following latest vLLM architecture:
|
||||||
- TransformersForCausalLM = CausalMixin + Base
|
- TransformersForCausalLM = CausalMixin + Base
|
||||||
|
- lm_head is created at the wrapper level (not inside self.model)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
@@ -16,6 +17,8 @@ import torch
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
|
from vllm.model_executor.models.utils import PPMissingLayer, maybe_prefix
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -29,11 +32,17 @@ class CausalMixin:
|
|||||||
Mixin class that adds causal language model functionality.
|
Mixin class that adds causal language model functionality.
|
||||||
|
|
||||||
This mixin provides:
|
This mixin provides:
|
||||||
|
- ParallelLMHead for language model head (created at wrapper level)
|
||||||
- LogitsProcessor for logits computation
|
- LogitsProcessor for logits computation
|
||||||
- Sampler for token sampling
|
- Sampler for token sampling
|
||||||
- compute_logits method for VllmModelForTextGeneration protocol
|
- compute_logits method for VllmModelForTextGeneration protocol
|
||||||
- sample method for VllmModelForTextGeneration protocol
|
- sample method for VllmModelForTextGeneration protocol
|
||||||
|
|
||||||
|
Following latest vLLM architecture:
|
||||||
|
- lm_head is a direct attribute of TransformersForCausalLM (not inside self.model)
|
||||||
|
- hf_to_vllm_mapper maps "model.lm_head." -> "lm_head." to handle this
|
||||||
|
- For tied embeddings, lm_head weight loading is skipped and weights are tied
|
||||||
|
|
||||||
Should be used with Base class:
|
Should be used with Base class:
|
||||||
class TransformersForCausalLM(CausalMixin, Base): ...
|
class TransformersForCausalLM(CausalMixin, Base): ...
|
||||||
"""
|
"""
|
||||||
@@ -46,21 +55,43 @@ class CausalMixin:
|
|||||||
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
|
tie_word_embeddings = getattr(self.text_config, "tie_word_embeddings", False)
|
||||||
if tie_word_embeddings:
|
if tie_word_embeddings:
|
||||||
self.skip_prefixes.append("lm_head.")
|
self.skip_prefixes.append("lm_head.")
|
||||||
logger.info("Model has tied word embeddings, skipping lm_head weight loading")
|
logger.info("Model has tied word embeddings, will tie lm_head weights")
|
||||||
|
|
||||||
# Setup logits processor
|
# Create lm_head at wrapper level (following latest vLLM architecture)
|
||||||
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
|
# This is outside self.model, so weights map "model.lm_head." -> "lm_head."
|
||||||
self.logits_processor = LogitsProcessor(
|
if self.pp_group.is_last_rank:
|
||||||
self.vocab_size,
|
self.lm_head = ParallelLMHead(
|
||||||
logits_as_input=False,
|
self.vocab_size,
|
||||||
scale=logit_scale,
|
self.hidden_size,
|
||||||
)
|
quant_config=self.quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tie weights if needed
|
||||||
|
if tie_word_embeddings:
|
||||||
|
input_embeddings = self.model.get_input_embeddings()
|
||||||
|
if input_embeddings is not None:
|
||||||
|
self.lm_head = self.lm_head.tie_weights(input_embeddings)
|
||||||
|
logger.info("Tied lm_head weights with input embeddings")
|
||||||
|
|
||||||
|
# Setup logits processor
|
||||||
|
logit_scale = getattr(self.text_config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(
|
||||||
|
self.vocab_size,
|
||||||
|
logits_as_input=False,
|
||||||
|
scale=logit_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("CausalMixin initialized (vocab_size=%d, hidden_size=%d, logit_scale=%s)",
|
||||||
|
self.vocab_size, self.hidden_size, logit_scale)
|
||||||
|
else:
|
||||||
|
# For non-last PP ranks, use PPMissingLayer
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
self.logits_processor = None
|
||||||
|
logger.info("CausalMixin initialized (PP non-last rank, using PPMissingLayer)")
|
||||||
|
|
||||||
# Setup sampler
|
# Setup sampler
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
logger.info("CausalMixin initialized (vocab_size=%d, logit_scale=%s)",
|
|
||||||
self.vocab_size, logit_scale)
|
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
@@ -79,28 +110,17 @@ class CausalMixin:
|
|||||||
Returns:
|
Returns:
|
||||||
Logits tensor or None
|
Logits tensor or None
|
||||||
"""
|
"""
|
||||||
# Check if hidden_states are already logits (some models output logits directly)
|
if self.logits_processor is None:
|
||||||
if hidden_states.shape[-1] == self.vocab_size:
|
# Non-last PP rank
|
||||||
logits = hidden_states
|
return None
|
||||||
|
|
||||||
|
# Apply lm_head (at wrapper level, not self.model.lm_head)
|
||||||
|
output = self.lm_head(hidden_states)
|
||||||
|
# Handle tuple output from vLLM Linear layers (output, bias)
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
logits = output[0]
|
||||||
else:
|
else:
|
||||||
# Apply the LM head
|
logits = output
|
||||||
lm_head = getattr(self.model, "lm_head", None)
|
|
||||||
if lm_head is None:
|
|
||||||
# Some models use different names
|
|
||||||
lm_head = getattr(self.model, "embed_out", None)
|
|
||||||
if lm_head is None:
|
|
||||||
lm_head = getattr(self.model, "output", None)
|
|
||||||
|
|
||||||
if lm_head is not None:
|
|
||||||
output = lm_head(hidden_states)
|
|
||||||
# Handle tuple output from vLLM Linear layers (output, bias)
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
logits = output[0]
|
|
||||||
else:
|
|
||||||
logits = output
|
|
||||||
else:
|
|
||||||
logger.warning("Could not find lm_head, using hidden_states as logits")
|
|
||||||
logits = hidden_states
|
|
||||||
|
|
||||||
return self.logits_processor(None, logits, sampling_metadata)
|
return self.logits_processor(None, logits, sampling_metadata)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user