testing dynamic register
This commit is contained in:
@@ -222,6 +222,9 @@ class Base(nn.Module):
|
|||||||
# Fix attention head_dim in case config was incorrect
|
# Fix attention head_dim in case config was incorrect
|
||||||
self._fix_attention_head_dim()
|
self._fix_attention_head_dim()
|
||||||
|
|
||||||
|
# Add debug hook to first attention module to capture tensor shapes
|
||||||
|
self._add_attention_debug_hook()
|
||||||
|
|
||||||
# Replace input embeddings
|
# Replace input embeddings
|
||||||
self._replace_input_embeddings()
|
self._replace_input_embeddings()
|
||||||
|
|
||||||
@@ -471,6 +474,36 @@ class Base(nn.Module):
|
|||||||
_recursive_replace(self.model, "model")
|
_recursive_replace(self.model, "model")
|
||||||
logger.info("Replaced %d modules", replaced_count)
|
logger.info("Replaced %d modules", replaced_count)
|
||||||
|
|
||||||
|
def _add_attention_debug_hook(self) -> None:
|
||||||
|
"""Add a forward pre-hook to the first attention module for debugging."""
|
||||||
|
for name, module in self.model.named_modules():
|
||||||
|
if "Attention" in module.__class__.__name__:
|
||||||
|
def _debug_hook(mod, args, kwargs=None):
|
||||||
|
hidden = args[0] if args else None
|
||||||
|
if hidden is not None:
|
||||||
|
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
|
||||||
|
# Print q_proj output shape
|
||||||
|
q_proj = getattr(mod, 'q_proj', None)
|
||||||
|
if q_proj is not None and hidden is not None:
|
||||||
|
try:
|
||||||
|
q_out = q_proj(hidden)
|
||||||
|
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
|
||||||
|
head_dim = getattr(mod, 'head_dim', 'NOT SET')
|
||||||
|
num_heads = getattr(mod, 'num_heads', 'NOT SET')
|
||||||
|
logger.info("DEBUG HOOK: Will try view with num_heads=%s, head_dim=%s",
|
||||||
|
num_heads, head_dim)
|
||||||
|
if isinstance(head_dim, int) and isinstance(num_heads, int):
|
||||||
|
expected = num_heads * head_dim
|
||||||
|
actual = q_out.shape[-1]
|
||||||
|
logger.info("DEBUG HOOK: q_proj output last dim=%d, expected (num_heads*head_dim)=%d, match=%s",
|
||||||
|
actual, expected, actual == expected)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info("DEBUG HOOK: Error testing q_proj: %s", e)
|
||||||
|
|
||||||
|
module.register_forward_pre_hook(_debug_hook)
|
||||||
|
logger.info("DEBUG: Added debug hook to %s", name)
|
||||||
|
break
|
||||||
|
|
||||||
def _fix_attention_head_dim(self) -> None:
|
def _fix_attention_head_dim(self) -> None:
|
||||||
"""
|
"""
|
||||||
Fix head_dim in attention modules and rotary embeddings after model creation.
|
Fix head_dim in attention modules and rotary embeddings after model creation.
|
||||||
@@ -696,19 +729,78 @@ class Base(nn.Module):
|
|||||||
# Forward through model
|
# Forward through model
|
||||||
# Note: return_dict=False returns tuple, first element is last hidden state
|
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||||
|
|
||||||
# DEBUG: Print attention module head_dim values just before forward
|
# DEBUG: Print detailed model structure info before forward
|
||||||
logger.info("DEBUG: Checking attention modules before forward...")
|
if not hasattr(self, '_debug_printed'):
|
||||||
for name, module in self.model.named_modules():
|
self._debug_printed = True
|
||||||
if "Attention" in module.__class__.__name__:
|
logger.info("DEBUG: === Detailed model structure debug ===")
|
||||||
head_dim = getattr(module, 'head_dim', 'NOT SET')
|
|
||||||
rotary_emb = getattr(module, 'rotary_emb', None)
|
# Print transformers version
|
||||||
if rotary_emb:
|
try:
|
||||||
emb_dim = getattr(rotary_emb, 'dim', 'N/A')
|
import transformers
|
||||||
logger.info("DEBUG: %s: head_dim=%s, rotary_emb.dim=%s",
|
logger.info("DEBUG: transformers version: %s", transformers.__version__)
|
||||||
name, head_dim, emb_dim)
|
except Exception:
|
||||||
else:
|
pass
|
||||||
logger.info("DEBUG: %s: head_dim=%s, rotary_emb=None", name, head_dim)
|
|
||||||
break # Just print first one
|
# Print TP world size
|
||||||
|
logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)
|
||||||
|
|
||||||
|
# Print first attention module details
|
||||||
|
for name, module in self.model.named_modules():
|
||||||
|
if "Attention" in module.__class__.__name__:
|
||||||
|
logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
|
||||||
|
# Print all attributes
|
||||||
|
for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
|
||||||
|
'hidden_size', 'num_attention_heads',
|
||||||
|
'num_key_value_groups']:
|
||||||
|
val = getattr(module, attr, 'NOT SET')
|
||||||
|
logger.info("DEBUG: %s = %s", attr, val)
|
||||||
|
|
||||||
|
# Print rotary_emb
|
||||||
|
rotary = getattr(module, 'rotary_emb', None)
|
||||||
|
if rotary:
|
||||||
|
logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
|
||||||
|
if hasattr(rotary, 'inv_freq'):
|
||||||
|
logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
|
||||||
|
else:
|
||||||
|
logger.info("DEBUG: rotary_emb: None")
|
||||||
|
|
||||||
|
# Print projection shapes
|
||||||
|
for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
|
||||||
|
proj = getattr(module, proj_name, None)
|
||||||
|
if proj:
|
||||||
|
if hasattr(proj, 'weight'):
|
||||||
|
logger.info("DEBUG: %s: type=%s, weight.shape=%s",
|
||||||
|
proj_name, type(proj).__name__,
|
||||||
|
proj.weight.shape if proj.weight is not None else 'None')
|
||||||
|
elif hasattr(proj, 'output_size'):
|
||||||
|
logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
|
||||||
|
proj_name, type(proj).__name__,
|
||||||
|
getattr(proj, 'input_size', 'N/A'),
|
||||||
|
getattr(proj, 'output_size', 'N/A'),
|
||||||
|
getattr(proj, 'output_size_per_partition', 'N/A'))
|
||||||
|
break
|
||||||
|
|
||||||
|
# Print model-level rotary_emb
|
||||||
|
model_rotary = getattr(self.model, 'rotary_emb', None)
|
||||||
|
if model_rotary:
|
||||||
|
logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
|
||||||
|
if hasattr(model_rotary, 'inv_freq'):
|
||||||
|
logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
|
||||||
|
else:
|
||||||
|
logger.info("DEBUG: No model-level rotary_emb")
|
||||||
|
# Check nested
|
||||||
|
for name, module in self.model.named_modules():
|
||||||
|
if "RotaryEmbedding" in module.__class__.__name__:
|
||||||
|
inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
|
||||||
|
logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Print config details
|
||||||
|
for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
|
||||||
|
'intermediate_size', 'num_hidden_layers']:
|
||||||
|
logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))
|
||||||
|
|
||||||
|
logger.info("DEBUG: === End debug ===")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
|
|||||||
Reference in New Issue
Block a user