testing dynamic register

This commit is contained in:
Chranos
2026-02-06 14:17:06 +08:00
parent fba02652c8
commit b702adf015
2 changed files with 93 additions and 186 deletions

View File

@@ -22,6 +22,7 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
@@ -251,17 +252,6 @@ class Base(nn.Module):
logger.info("Creating model structure on meta device...")
# DEBUG: Print config info before any modifications
logger.info("DEBUG: Config type: %s", type(self.config).__name__)
logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__)
logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s",
getattr(self.text_config, 'hidden_size', 'N/A'),
getattr(self.text_config, 'num_attention_heads', 'N/A'))
logger.info("DEBUG: config.head_dim=%s (before fix)",
getattr(self.config, 'head_dim', 'NOT SET'))
logger.info("DEBUG: text_config.head_dim=%s (before fix)",
getattr(self.text_config, 'head_dim', 'NOT SET'))
# Set attention implementation to vLLM's
self.text_config._attn_implementation = "vllm"
@@ -270,7 +260,6 @@ class Base(nn.Module):
# Some models may have incorrect head_dim, so we compute and set it
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim)
# Check and fix head_dim in text_config
if hasattr(self.text_config, "head_dim"):
@@ -297,11 +286,6 @@ class Base(nn.Module):
# Some models also need _attn_implementation in config
self.config._attn_implementation = "vllm"
logger.info("DEBUG: config.head_dim=%s (after fix)",
getattr(self.config, 'head_dim', 'NOT SET'))
logger.info("DEBUG: text_config.head_dim=%s (after fix)",
getattr(self.text_config, 'head_dim', 'NOT SET'))
with init_on_device_without_buffers("meta"):
self.model: "PreTrainedModel" = AutoModel.from_config(
self.config,
@@ -461,7 +445,8 @@ class Base(nn.Module):
)
replaced_count += 1
elif child.__class__.__name__.endswith("RMSNorm"):
elif child.__class__.__name__.endswith("RMSNorm") and \
not isinstance(child, RMSNorm):
new_module = replace_rms_norm_class(child, self.hidden_size)
replaced_count += 1
@@ -475,64 +460,8 @@ class Base(nn.Module):
logger.info("Replaced %d modules", replaced_count)
# NOTE(review): this span comes from diff hunk "-475,64 +460,8" shown without
# +/- markers or indentation. The new-side count of 8 lines suggests that only
# the `def` line, the "No-op" docstring and `pass` (the last two lines of this
# span) survive the commit, and that everything in between is the removed
# debug implementation -- confirm against the checked-out file.
def _add_attention_debug_hook(self) -> None:
"""Add debug hooks to capture actual tensor shapes during forward."""
# Monkey-patch apply_rotary_pos_emb in the transformers module
# Only the Qwen2 modeling module is imported and patched here; other model
# families are not touched by this code. The patch rebinds a module-level
# attribute and nothing in this block restores the original afterwards.
try:
import transformers.models.qwen2.modeling_qwen2 as qwen2_module
# Keep a reference to the real implementation so the wrapper can delegate.
original_apply_rotary = qwen2_module.apply_rotary_pos_emb
# Wrapper: logs the shapes it receives, then calls the original with the
# same arguments it was given.
def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s",
q.shape, k.shape, cos.shape, sin.shape)
# After unsqueeze
# These unsqueezed copies are used for logging only; the original
# function below still receives the un-unsqueezed cos/sin.
cos_unsqueezed = cos.unsqueeze(unsqueeze_dim)
sin_unsqueezed = sin.unsqueeze(unsqueeze_dim)
logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s",
unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape)
# -1 is logged as a placeholder when a tensor has fewer than 4 dimensions.
logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d",
q.shape[3] if q.dim() >= 4 else -1,
cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1)
return original_apply_rotary(q, k, cos, sin, unsqueeze_dim)
qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary
logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging")
# Best-effort: any failure while importing or patching is logged as a warning
# and otherwise ignored.
except Exception as e:
logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e)
# Also add a forward pre-hook with kwargs support
# Modules are selected by class name containing "Attention". The `break`
# further down presumably sits inside that `if`, so only the first match
# would be hooked -- indentation is lost in this view, confirm.
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
def _debug_hook(mod, args, kwargs):
# Prefer the `hidden_states` keyword argument; otherwise fall back to the
# first positional argument, or None when there are no positionals.
hidden = kwargs.get('hidden_states', args[0] if args else None)
if hidden is not None:
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET'))
# Check mod.config.head_dim
mod_config = getattr(mod, 'config', None)
if mod_config:
logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET'))
# NOTE(review): the second logged value compares id(mod_config) with
# itself, so it is always True. Given the "same as self.config" label,
# the right-hand side was presumably meant to be id(self.config).
logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s",
id(mod_config), id(mod_config) == id(mod_config))
# Try q_proj
# This calls q_proj on the hook's input purely to log shapes, i.e. one
# extra projection each time the hook fires.
q_proj = getattr(mod, 'q_proj', None)
if q_proj is not None:
try:
q_out = q_proj(hidden)
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
# 128 is the fallback used when the module has no `head_dim` attribute.
head_dim = getattr(mod, 'head_dim', 128)
# Target shape: all leading dims of `hidden`, then (-1, head_dim).
# Presumably mirrors the reshape done in the attention forward --
# TODO confirm.
input_shape = hidden.shape[:-1]
hidden_shape = (*input_shape, -1, head_dim)
logger.info("DEBUG HOOK: view target shape=%s", hidden_shape)
q_viewed = q_out.view(hidden_shape)
logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape)
q_transposed = q_viewed.transpose(1, 2)
logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape)
# Failures in the probe above are logged at info level and swallowed.
except Exception as e:
logger.info("DEBUG HOOK: Error: %s", e)
# with_kwargs=True matches the (mod, args, kwargs) signature of _debug_hook.
# No `return` appears in the hook body above.
module.register_forward_pre_hook(_debug_hook, with_kwargs=True)
logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name)
break
# NOTE(review): the two lines below look like the post-commit body (a no-op
# that keeps the method name available to callers) -- confirm.
"""No-op. Debug hooks removed after root cause identified."""
pass
def _fix_attention_head_dim(self) -> None:
"""
@@ -546,50 +475,36 @@ class Base(nn.Module):
correct_head_dim = self.hidden_size // getattr(
self.text_config, "num_attention_heads", 32
)
logger.info("DEBUG: _fix_attention_head_dim called, correct_head_dim=%d", correct_head_dim)
fixed_count = 0
attention_modules_found = []
rotary_modules_fixed = []
for name, module in self.model.named_modules():
module_name = module.__class__.__name__
# Fix head_dim in Attention modules
if "Attention" in module_name:
current_head_dim = getattr(module, 'head_dim', 'NOT SET')
num_heads = getattr(module, 'num_heads', 'NOT SET')
num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET')
attention_modules_found.append(
f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}"
)
# Fix head_dim if it exists and is incorrect
if hasattr(module, "head_dim"):
if module.head_dim != correct_head_dim:
logger.warning(
"DEBUG: Fixing head_dim in %s: %d -> %d",
"Fixing head_dim in %s: %d -> %d",
name, module.head_dim, correct_head_dim
)
module.head_dim = correct_head_dim
fixed_count += 1
# Fix rotary embeddings - need to recreate inv_freq buffer
# Fix rotary embeddings - recreate inv_freq buffer if needed
if "RotaryEmbedding" in module_name:
# Check if rotary embedding has wrong dimension
if hasattr(module, "inv_freq"):
current_dim = module.inv_freq.shape[0] * 2 # inv_freq is half the dim
current_dim = module.inv_freq.shape[0] * 2
if current_dim != correct_head_dim:
logger.warning(
"DEBUG: Recreating rotary embedding %s: dim %d -> %d",
"Recreating rotary embedding %s: dim %d -> %d",
name, current_dim, correct_head_dim
)
# Recreate inv_freq with correct dimension
base = getattr(module.config, 'rope_theta', 10000.0)
if hasattr(module.config, 'rope_parameters'):
base = module.config.rope_parameters.get('rope_theta', base)
device = module.inv_freq.device
# Create new inv_freq
inv_freq = 1.0 / (
base ** (
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
@@ -599,22 +514,9 @@ class Base(nn.Module):
module.register_buffer("inv_freq", inv_freq, persistent=False)
if hasattr(module, "original_inv_freq"):
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
rotary_modules_fixed.append(name)
# Print debug info
if attention_modules_found:
logger.info("DEBUG: Found %d attention modules", len(attention_modules_found))
for info in attention_modules_found[:3]:
logger.info("DEBUG: Attention module: %s", info)
if rotary_modules_fixed:
logger.info("DEBUG: Fixed %d rotary embedding modules: %s",
len(rotary_modules_fixed), rotary_modules_fixed)
if fixed_count > 0:
logger.info("Fixed head_dim in %d attention modules", fixed_count)
else:
logger.info("DEBUG: No attention modules needed head_dim fix")
def _replace_input_embeddings(self) -> None:
"""Replace input embeddings with VocabParallelEmbedding."""
@@ -758,80 +660,6 @@ class Base(nn.Module):
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
# DEBUG: Print detailed model structure info before forward
if not hasattr(self, '_debug_printed'):
self._debug_printed = True
logger.info("DEBUG: === Detailed model structure debug ===")
# Print transformers version
try:
import transformers
logger.info("DEBUG: transformers version: %s", transformers.__version__)
except Exception:
pass
# Print TP world size
logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)
# Print first attention module details
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
# Print all attributes
for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
'hidden_size', 'num_attention_heads',
'num_key_value_groups']:
val = getattr(module, attr, 'NOT SET')
logger.info("DEBUG: %s = %s", attr, val)
# Print rotary_emb
rotary = getattr(module, 'rotary_emb', None)
if rotary:
logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
if hasattr(rotary, 'inv_freq'):
logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
else:
logger.info("DEBUG: rotary_emb: None")
# Print projection shapes
for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
proj = getattr(module, proj_name, None)
if proj:
if hasattr(proj, 'weight'):
logger.info("DEBUG: %s: type=%s, weight.shape=%s",
proj_name, type(proj).__name__,
proj.weight.shape if proj.weight is not None else 'None')
elif hasattr(proj, 'output_size'):
logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
proj_name, type(proj).__name__,
getattr(proj, 'input_size', 'N/A'),
getattr(proj, 'output_size', 'N/A'),
getattr(proj, 'output_size_per_partition', 'N/A'))
break
# Print model-level rotary_emb
model_rotary = getattr(self.model, 'rotary_emb', None)
if model_rotary:
logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
if hasattr(model_rotary, 'inv_freq'):
logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
else:
logger.info("DEBUG: No model-level rotary_emb")
# Check nested
for name, module in self.model.named_modules():
if "RotaryEmbedding" in module.__class__.__name__:
inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
break
# Print config details
for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
'intermediate_size', 'num_hidden_layers']:
logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))
logger.info("DEBUG: === End debug ===")
with torch.no_grad():
outputs = self.model(
**model_inputs,