testing dynamic register

This commit is contained in:
Chranos
2026-02-06 14:17:06 +08:00
parent fba02652c8
commit b702adf015
2 changed files with 93 additions and 186 deletions

View File

@@ -22,6 +22,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
@@ -251,17 +252,6 @@ class Base(nn.Module):
logger.info("Creating model structure on meta device...")
# DEBUG: Print config info before any modifications
logger.info("DEBUG: Config type: %s", type(self.config).__name__)
logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__)
logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s",
getattr(self.text_config, 'hidden_size', 'N/A'),
getattr(self.text_config, 'num_attention_heads', 'N/A'))
logger.info("DEBUG: config.head_dim=%s (before fix)",
getattr(self.config, 'head_dim', 'NOT SET'))
logger.info("DEBUG: text_config.head_dim=%s (before fix)",
getattr(self.text_config, 'head_dim', 'NOT SET'))
# Set attention implementation to vLLM's
self.text_config._attn_implementation = "vllm"
@@ -270,7 +260,6 @@ class Base(nn.Module):
# Some models may have incorrect head_dim, so we compute and set it
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim)
# Check and fix head_dim in text_config
if hasattr(self.text_config, "head_dim"):
@@ -297,11 +286,6 @@ class Base(nn.Module):
# Some models also need _attn_implementation in config
self.config._attn_implementation = "vllm"
logger.info("DEBUG: config.head_dim=%s (after fix)",
getattr(self.config, 'head_dim', 'NOT SET'))
logger.info("DEBUG: text_config.head_dim=%s (after fix)",
getattr(self.text_config, 'head_dim', 'NOT SET'))
with init_on_device_without_buffers("meta"):
self.model: "PreTrainedModel" = AutoModel.from_config(
self.config,
@@ -461,7 +445,8 @@ class Base(nn.Module):
)
replaced_count += 1
elif child.__class__.__name__.endswith("RMSNorm"):
elif child.__class__.__name__.endswith("RMSNorm") and \
not isinstance(child, RMSNorm):
new_module = replace_rms_norm_class(child, self.hidden_size)
replaced_count += 1
@@ -475,64 +460,8 @@ class Base(nn.Module):
logger.info("Replaced %d modules", replaced_count)
def _add_attention_debug_hook(self) -> None:
"""Add debug hooks to capture actual tensor shapes during forward."""
# Monkey-patch apply_rotary_pos_emb in the transformers module
try:
import transformers.models.qwen2.modeling_qwen2 as qwen2_module
original_apply_rotary = qwen2_module.apply_rotary_pos_emb
def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s",
q.shape, k.shape, cos.shape, sin.shape)
# After unsqueeze
cos_unsqueezed = cos.unsqueeze(unsqueeze_dim)
sin_unsqueezed = sin.unsqueeze(unsqueeze_dim)
logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s",
unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape)
logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d",
q.shape[3] if q.dim() >= 4 else -1,
cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1)
return original_apply_rotary(q, k, cos, sin, unsqueeze_dim)
qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary
logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging")
except Exception as e:
logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e)
# Also add a forward pre-hook with kwargs support
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
def _debug_hook(mod, args, kwargs):
hidden = kwargs.get('hidden_states', args[0] if args else None)
if hidden is not None:
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET'))
# Check mod.config.head_dim
mod_config = getattr(mod, 'config', None)
if mod_config:
logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET'))
logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s",
id(mod_config), id(mod_config) == id(mod_config))
# Try q_proj
q_proj = getattr(mod, 'q_proj', None)
if q_proj is not None:
try:
q_out = q_proj(hidden)
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
head_dim = getattr(mod, 'head_dim', 128)
input_shape = hidden.shape[:-1]
hidden_shape = (*input_shape, -1, head_dim)
logger.info("DEBUG HOOK: view target shape=%s", hidden_shape)
q_viewed = q_out.view(hidden_shape)
logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape)
q_transposed = q_viewed.transpose(1, 2)
logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape)
except Exception as e:
logger.info("DEBUG HOOK: Error: %s", e)
module.register_forward_pre_hook(_debug_hook, with_kwargs=True)
logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name)
break
"""No-op. Debug hooks removed after root cause identified."""
pass
def _fix_attention_head_dim(self) -> None:
"""
@@ -546,50 +475,36 @@ class Base(nn.Module):
correct_head_dim = self.hidden_size // getattr(
self.text_config, "num_attention_heads", 32
)
logger.info("DEBUG: _fix_attention_head_dim called, correct_head_dim=%d", correct_head_dim)
fixed_count = 0
attention_modules_found = []
rotary_modules_fixed = []
for name, module in self.model.named_modules():
module_name = module.__class__.__name__
# Fix head_dim in Attention modules
if "Attention" in module_name:
current_head_dim = getattr(module, 'head_dim', 'NOT SET')
num_heads = getattr(module, 'num_heads', 'NOT SET')
num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET')
attention_modules_found.append(
f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}"
)
# Fix head_dim if it exists and is incorrect
if hasattr(module, "head_dim"):
if module.head_dim != correct_head_dim:
logger.warning(
"DEBUG: Fixing head_dim in %s: %d -> %d",
"Fixing head_dim in %s: %d -> %d",
name, module.head_dim, correct_head_dim
)
module.head_dim = correct_head_dim
fixed_count += 1
# Fix rotary embeddings - need to recreate inv_freq buffer
# Fix rotary embeddings - recreate inv_freq buffer if needed
if "RotaryEmbedding" in module_name:
# Check if rotary embedding has wrong dimension
if hasattr(module, "inv_freq"):
current_dim = module.inv_freq.shape[0] * 2 # inv_freq is half the dim
current_dim = module.inv_freq.shape[0] * 2
if current_dim != correct_head_dim:
logger.warning(
"DEBUG: Recreating rotary embedding %s: dim %d -> %d",
"Recreating rotary embedding %s: dim %d -> %d",
name, current_dim, correct_head_dim
)
# Recreate inv_freq with correct dimension
base = getattr(module.config, 'rope_theta', 10000.0)
if hasattr(module.config, 'rope_parameters'):
base = module.config.rope_parameters.get('rope_theta', base)
device = module.inv_freq.device
# Create new inv_freq
inv_freq = 1.0 / (
base ** (
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
@@ -599,22 +514,9 @@ class Base(nn.Module):
module.register_buffer("inv_freq", inv_freq, persistent=False)
if hasattr(module, "original_inv_freq"):
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
rotary_modules_fixed.append(name)
# Print debug info
if attention_modules_found:
logger.info("DEBUG: Found %d attention modules", len(attention_modules_found))
for info in attention_modules_found[:3]:
logger.info("DEBUG: Attention module: %s", info)
if rotary_modules_fixed:
logger.info("DEBUG: Fixed %d rotary embedding modules: %s",
len(rotary_modules_fixed), rotary_modules_fixed)
if fixed_count > 0:
logger.info("Fixed head_dim in %d attention modules", fixed_count)
else:
logger.info("DEBUG: No attention modules needed head_dim fix")
def _replace_input_embeddings(self) -> None:
"""Replace input embeddings with VocabParallelEmbedding."""
@@ -758,80 +660,6 @@ class Base(nn.Module):
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
# DEBUG: Print detailed model structure info before forward
if not hasattr(self, '_debug_printed'):
self._debug_printed = True
logger.info("DEBUG: === Detailed model structure debug ===")
# Print transformers version
try:
import transformers
logger.info("DEBUG: transformers version: %s", transformers.__version__)
except Exception:
pass
# Print TP world size
logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)
# Print first attention module details
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
# Print all attributes
for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
'hidden_size', 'num_attention_heads',
'num_key_value_groups']:
val = getattr(module, attr, 'NOT SET')
logger.info("DEBUG: %s = %s", attr, val)
# Print rotary_emb
rotary = getattr(module, 'rotary_emb', None)
if rotary:
logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
if hasattr(rotary, 'inv_freq'):
logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
else:
logger.info("DEBUG: rotary_emb: None")
# Print projection shapes
for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
proj = getattr(module, proj_name, None)
if proj:
if hasattr(proj, 'weight'):
logger.info("DEBUG: %s: type=%s, weight.shape=%s",
proj_name, type(proj).__name__,
proj.weight.shape if proj.weight is not None else 'None')
elif hasattr(proj, 'output_size'):
logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
proj_name, type(proj).__name__,
getattr(proj, 'input_size', 'N/A'),
getattr(proj, 'output_size', 'N/A'),
getattr(proj, 'output_size_per_partition', 'N/A'))
break
# Print model-level rotary_emb
model_rotary = getattr(self.model, 'rotary_emb', None)
if model_rotary:
logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
if hasattr(model_rotary, 'inv_freq'):
logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
else:
logger.info("DEBUG: No model-level rotary_emb")
# Check nested
for name, module in self.model.named_modules():
if "RotaryEmbedding" in module.__class__.__name__:
inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
break
# Print config details
for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
'intermediate_size', 'num_hidden_layers']:
logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))
logger.info("DEBUG: === End debug ===")
with torch.no_grad():
outputs = self.model(
**model_inputs,

View File

@@ -8,7 +8,7 @@ module replacement functions.
"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -123,23 +123,102 @@ def replace_linear_class(
)
class TransformersRMSNorm(RMSNorm):
    """
    vLLM RMSNorm subclass that preserves tensor dimensions.

    vLLM's RMSNorm (especially the MLU backend) flattens input to 2D
    (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
    expects the batch dimension to be preserved. This subclass wraps
    each backend ``forward_*`` entry point to save the incoming shape
    and restore it on the way out.

    Since this inherits from RMSNorm directly, weight loading via
    named_parameters() works correctly (weight path stays the same).
    """

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        orig_shape = x.shape
        result = super().forward_native(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        orig_shape = x.shape
        result = super().forward_cuda(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_mlu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        orig_shape = x.shape
        result = super().forward_mlu(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        orig_shape = x.shape
        result = super().forward_xpu(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        orig_shape = x.shape
        result = super().forward_hpu(x, residual)
        return self._restore_shape(result, orig_shape)

    @staticmethod
    def _restore_shape(result, orig_shape: torch.Size):
        """Restore the original tensor shape if it was changed.

        ``result`` is either a single tensor or a tuple (e.g.
        ``(output, residual)`` when a residual was passed to the norm).
        Uses ``reshape`` rather than ``view`` so that non-contiguous
        backend outputs are handled instead of raising a RuntimeError.
        """
        if isinstance(result, tuple):
            # Restore every non-None element; None entries pass through.
            return tuple(
                t.reshape(orig_shape)
                if t is not None and t.shape != orig_shape
                else t
                for t in result
            )
        if result.shape != orig_shape:
            result = result.reshape(orig_shape)
        return result
def replace_rms_norm_class(
rms_norm: nn.Module,
hidden_size: int,
) -> RMSNorm:
) -> nn.Module:
"""
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm,
wrapped to preserve tensor dimensions.
vLLM's RMSNorm provides:
- Fused CUDA kernels for better performance
- Support for fused add + norm operations
The wrapper ensures that the original tensor shape (including batch
dimension) is preserved, which is required by transformers' model
forward methods.
Args:
rms_norm: The RMSNorm module to replace.
hidden_size: The hidden size of the model.
Returns:
The new vLLM RMSNorm layer.
The new vLLM RMSNorm layer wrapped for shape preservation.
"""
# Try to get epsilon from various attribute names
eps = getattr(rms_norm, "eps", None)
@@ -153,7 +232,7 @@ def replace_rms_norm_class(
if weight is not None:
hidden_size = weight.size(0)
return RMSNorm(hidden_size=hidden_size, eps=eps)
return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):