testing dynamic register

This commit is contained in:
Chranos
2026-02-06 13:51:02 +08:00
parent 63a1a05999
commit e1a2afd244

View File

@@ -222,6 +222,9 @@ class Base(nn.Module):
# Fix attention head_dim in case config was incorrect
self._fix_attention_head_dim()
# Add debug hook to first attention module to capture tensor shapes
self._add_attention_debug_hook()
# Replace input embeddings
self._replace_input_embeddings()
@@ -471,6 +474,36 @@ class Base(nn.Module):
_recursive_replace(self.model, "model")
logger.info("Replaced %d modules", replaced_count)
def _add_attention_debug_hook(self) -> None:
"""Add a forward pre-hook to the first attention module for debugging."""
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
def _debug_hook(mod, args, kwargs=None):
hidden = args[0] if args else None
if hidden is not None:
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
# Print q_proj output shape
q_proj = getattr(mod, 'q_proj', None)
if q_proj is not None and hidden is not None:
try:
q_out = q_proj(hidden)
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
head_dim = getattr(mod, 'head_dim', 'NOT SET')
num_heads = getattr(mod, 'num_heads', 'NOT SET')
logger.info("DEBUG HOOK: Will try view with num_heads=%s, head_dim=%s",
num_heads, head_dim)
if isinstance(head_dim, int) and isinstance(num_heads, int):
expected = num_heads * head_dim
actual = q_out.shape[-1]
logger.info("DEBUG HOOK: q_proj output last dim=%d, expected (num_heads*head_dim)=%d, match=%s",
actual, expected, actual == expected)
except Exception as e:
logger.info("DEBUG HOOK: Error testing q_proj: %s", e)
module.register_forward_pre_hook(_debug_hook)
logger.info("DEBUG: Added debug hook to %s", name)
break
def _fix_attention_head_dim(self) -> None:
"""
Fix head_dim in attention modules and rotary embeddings after model creation.
@@ -696,19 +729,78 @@ class Base(nn.Module):
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
# DEBUG: Print attention module head_dim values just before forward
logger.info("DEBUG: Checking attention modules before forward...")
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
head_dim = getattr(module, 'head_dim', 'NOT SET')
rotary_emb = getattr(module, 'rotary_emb', None)
if rotary_emb:
emb_dim = getattr(rotary_emb, 'dim', 'N/A')
logger.info("DEBUG: %s: head_dim=%s, rotary_emb.dim=%s",
name, head_dim, emb_dim)
else:
logger.info("DEBUG: %s: head_dim=%s, rotary_emb=None", name, head_dim)
break # Just print first one
# DEBUG: Print detailed model structure info before forward
if not hasattr(self, '_debug_printed'):
self._debug_printed = True
logger.info("DEBUG: === Detailed model structure debug ===")
# Print transformers version
try:
import transformers
logger.info("DEBUG: transformers version: %s", transformers.__version__)
except Exception:
pass
# Print TP world size
logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)
# Print first attention module details
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
# Print all attributes
for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
'hidden_size', 'num_attention_heads',
'num_key_value_groups']:
val = getattr(module, attr, 'NOT SET')
logger.info("DEBUG: %s = %s", attr, val)
# Print rotary_emb
rotary = getattr(module, 'rotary_emb', None)
if rotary:
logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
if hasattr(rotary, 'inv_freq'):
logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
else:
logger.info("DEBUG: rotary_emb: None")
# Print projection shapes
for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
proj = getattr(module, proj_name, None)
if proj:
if hasattr(proj, 'weight'):
logger.info("DEBUG: %s: type=%s, weight.shape=%s",
proj_name, type(proj).__name__,
proj.weight.shape if proj.weight is not None else 'None')
elif hasattr(proj, 'output_size'):
logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
proj_name, type(proj).__name__,
getattr(proj, 'input_size', 'N/A'),
getattr(proj, 'output_size', 'N/A'),
getattr(proj, 'output_size_per_partition', 'N/A'))
break
# Print model-level rotary_emb
model_rotary = getattr(self.model, 'rotary_emb', None)
if model_rotary:
logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
if hasattr(model_rotary, 'inv_freq'):
logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
else:
logger.info("DEBUG: No model-level rotary_emb")
# Check nested
for name, module in self.model.named_modules():
if "RotaryEmbedding" in module.__class__.__name__:
inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
break
# Print config details
for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
'intermediate_size', 'num_hidden_layers']:
logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))
logger.info("DEBUG: === End debug ===")
with torch.no_grad():
outputs = self.model(