testing dynamic register

This commit is contained in:
Chranos
2026-02-06 14:04:04 +08:00
parent 5d2f4000cc
commit fba02652c8

View File

@@ -475,33 +475,63 @@ class Base(nn.Module):
logger.info("Replaced %d modules", replaced_count)
def _add_attention_debug_hook(self) -> None:
"""Add a forward pre-hook to the first attention module for debugging."""
"""Add debug hooks to capture actual tensor shapes during forward."""
# Monkey-patch apply_rotary_pos_emb in the transformers module
try:
import transformers.models.qwen2.modeling_qwen2 as qwen2_module
original_apply_rotary = qwen2_module.apply_rotary_pos_emb
def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s",
q.shape, k.shape, cos.shape, sin.shape)
# After unsqueeze
cos_unsqueezed = cos.unsqueeze(unsqueeze_dim)
sin_unsqueezed = sin.unsqueeze(unsqueeze_dim)
logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s",
unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape)
logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d",
q.shape[3] if q.dim() >= 4 else -1,
cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1)
return original_apply_rotary(q, k, cos, sin, unsqueeze_dim)
qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary
logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging")
except Exception as e:
logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e)
# Also add a forward pre-hook with kwargs support
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
def _debug_hook(mod, args, kwargs=None):
hidden = args[0] if args else None
def _debug_hook(mod, args, kwargs):
hidden = kwargs.get('hidden_states', args[0] if args else None)
if hidden is not None:
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
# Print q_proj output shape
q_proj = getattr(mod, 'q_proj', None)
if q_proj is not None and hidden is not None:
try:
q_out = q_proj(hidden)
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
head_dim = getattr(mod, 'head_dim', 'NOT SET')
num_heads = getattr(mod, 'num_heads', 'NOT SET')
logger.info("DEBUG HOOK: Will try view with num_heads=%s, head_dim=%s",
num_heads, head_dim)
if isinstance(head_dim, int) and isinstance(num_heads, int):
expected = num_heads * head_dim
actual = q_out.shape[-1]
logger.info("DEBUG HOOK: q_proj output last dim=%d, expected (num_heads*head_dim)=%d, match=%s",
actual, expected, actual == expected)
except Exception as e:
logger.info("DEBUG HOOK: Error testing q_proj: %s", e)
logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET'))
# Check mod.config.head_dim
mod_config = getattr(mod, 'config', None)
if mod_config:
logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET'))
logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s",
id(mod_config), id(mod_config) == id(mod_config))
# Try q_proj
q_proj = getattr(mod, 'q_proj', None)
if q_proj is not None:
try:
q_out = q_proj(hidden)
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
head_dim = getattr(mod, 'head_dim', 128)
input_shape = hidden.shape[:-1]
hidden_shape = (*input_shape, -1, head_dim)
logger.info("DEBUG HOOK: view target shape=%s", hidden_shape)
q_viewed = q_out.view(hidden_shape)
logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape)
q_transposed = q_viewed.transpose(1, 2)
logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape)
except Exception as e:
logger.info("DEBUG HOOK: Error: %s", e)
module.register_forward_pre_hook(_debug_hook)
logger.info("DEBUG: Added debug hook to %s", name)
module.register_forward_pre_hook(_debug_hook, with_kwargs=True)
logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name)
break
def _fix_attention_head_dim(self) -> None: