testing dynamic register
@@ -222,6 +222,9 @@ class Base(nn.Module):
        # Fix attention head_dim in case config was incorrect
        self._fix_attention_head_dim()

        # Add debug hook to first attention module to capture tensor shapes
        self._add_attention_debug_hook()

        # Replace input embeddings
        self._replace_input_embeddings()
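Note on the first call above: in Hugging Face-style configs, head_dim is commonly derived as hidden_size // num_attention_heads, and an explicit-but-stale config value is the kind of mismatch _fix_attention_head_dim() has to reconcile. A minimal, illustrative sketch of that check (the numbers and the bare config object are hypothetical, not taken from this repo):

from types import SimpleNamespace

# Hypothetical config carrying a head_dim that no longer matches the geometry
config = SimpleNamespace(hidden_size=2048, num_attention_heads=16, head_dim=160)

derived = config.hidden_size // config.num_attention_heads  # 128
if config.head_dim != derived:
    # roughly the condition a fix-up like _fix_attention_head_dim() would need to notice
    print(f"head_dim mismatch: config={config.head_dim}, derived={derived}")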
@@ -471,6 +474,36 @@ class Base(nn.Module):
        _recursive_replace(self.model, "model")
        logger.info("Replaced %d modules", replaced_count)

    def _add_attention_debug_hook(self) -> None:
        """Add a forward pre-hook to the first attention module for debugging."""
        for name, module in self.model.named_modules():
            if "Attention" in module.__class__.__name__:
                def _debug_hook(mod, args, kwargs=None):
                    hidden = args[0] if args else None
                    if hidden is not None:
                        logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
                    # Print q_proj output shape
                    q_proj = getattr(mod, 'q_proj', None)
                    if q_proj is not None and hidden is not None:
                        try:
                            q_out = q_proj(hidden)
                            logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
                            head_dim = getattr(mod, 'head_dim', 'NOT SET')
                            num_heads = getattr(mod, 'num_heads', 'NOT SET')
                            logger.info("DEBUG HOOK: Will try view with num_heads=%s, head_dim=%s",
                                        num_heads, head_dim)
                            if isinstance(head_dim, int) and isinstance(num_heads, int):
                                expected = num_heads * head_dim
                                actual = q_out.shape[-1]
                                logger.info("DEBUG HOOK: q_proj output last dim=%d, expected (num_heads*head_dim)=%d, match=%s",
                                            actual, expected, actual == expected)
                        except Exception as e:
                            logger.info("DEBUG HOOK: Error testing q_proj: %s", e)

                module.register_forward_pre_hook(_debug_hook)
                logger.info("DEBUG: Added debug hook to %s", name)
                break

    def _fix_attention_head_dim(self) -> None:
        """
        Fix head_dim in attention modules and rotary embeddings after model creation.
@@ -696,19 +729,78 @@ class Base(nn.Module):
        # Forward through model
        # Note: return_dict=False returns tuple, first element is last hidden state

        # DEBUG: Print attention module head_dim values just before forward
        logger.info("DEBUG: Checking attention modules before forward...")
        for name, module in self.model.named_modules():
            if "Attention" in module.__class__.__name__:
                head_dim = getattr(module, 'head_dim', 'NOT SET')
                rotary_emb = getattr(module, 'rotary_emb', None)
                if rotary_emb:
                    emb_dim = getattr(rotary_emb, 'dim', 'N/A')
                    logger.info("DEBUG: %s: head_dim=%s, rotary_emb.dim=%s",
                                name, head_dim, emb_dim)
                else:
                    logger.info("DEBUG: %s: head_dim=%s, rotary_emb=None", name, head_dim)
                break  # Just print first one

        # DEBUG: Print detailed model structure info before forward
        if not hasattr(self, '_debug_printed'):
            self._debug_printed = True
            logger.info("DEBUG: === Detailed model structure debug ===")

            # Print transformers version
            try:
                import transformers
                logger.info("DEBUG: transformers version: %s", transformers.__version__)
            except Exception:
                pass

            # Print TP world size
            logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)

            # Print first attention module details
            for name, module in self.model.named_modules():
                if "Attention" in module.__class__.__name__:
                    logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
                    # Print all attributes
                    for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
                                 'hidden_size', 'num_attention_heads',
                                 'num_key_value_groups']:
                        val = getattr(module, attr, 'NOT SET')
                        logger.info("DEBUG: %s = %s", attr, val)

                    # Print rotary_emb
                    rotary = getattr(module, 'rotary_emb', None)
                    if rotary:
                        logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
                        if hasattr(rotary, 'inv_freq'):
                            logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
                    else:
                        logger.info("DEBUG: rotary_emb: None")

                    # Print projection shapes
                    for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
                        proj = getattr(module, proj_name, None)
                        if proj:
                            if hasattr(proj, 'weight'):
                                logger.info("DEBUG: %s: type=%s, weight.shape=%s",
                                            proj_name, type(proj).__name__,
                                            proj.weight.shape if proj.weight is not None else 'None')
                            elif hasattr(proj, 'output_size'):
                                logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
                                            proj_name, type(proj).__name__,
                                            getattr(proj, 'input_size', 'N/A'),
                                            getattr(proj, 'output_size', 'N/A'),
                                            getattr(proj, 'output_size_per_partition', 'N/A'))
                    break

            # Print model-level rotary_emb
            model_rotary = getattr(self.model, 'rotary_emb', None)
            if model_rotary:
                logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
                if hasattr(model_rotary, 'inv_freq'):
                    logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
            else:
                logger.info("DEBUG: No model-level rotary_emb")
                # Check nested
                for name, module in self.model.named_modules():
                    if "RotaryEmbedding" in module.__class__.__name__:
                        inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
                        logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
                        break

            # Print config details
            for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
                         'intermediate_size', 'num_hidden_layers']:
                logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))

            logger.info("DEBUG: === End debug ===")

        with torch.no_grad():
            outputs = self.model(
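One caveat when reading the q_proj diagnostics above under tensor parallelism (the out_per_part branch): with a Megatron-style column-parallel q_proj, each rank typically holds only output_size / world_size columns, so q_out.shape[-1] is expected to be smaller than the global num_heads * head_dim unless num_heads on the module is already the per-partition count. As a worked example under those assumptions: 16 heads, head_dim 128 and world_size 2 give a per-rank q_proj output width of 16 * 128 / 2 = 1024, while the global num_heads * head_dim is 2048.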