From fba02652c8b0ac3de9cfbd6a1c7e7cd280b85fb2 Mon Sep 17 00:00:00 2001 From: Chranos <826995883@qq.com> Date: Fri, 6 Feb 2026 14:04:04 +0800 Subject: [PATCH] testing dynamic register --- .../models/transformers/base.py | 74 +++++++++++++------ 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py index 9cdeb99..a30dde2 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py @@ -475,33 +475,63 @@ class Base(nn.Module): logger.info("Replaced %d modules", replaced_count) def _add_attention_debug_hook(self) -> None: - """Add a forward pre-hook to the first attention module for debugging.""" + """Add debug hooks to capture actual tensor shapes during forward.""" + # Monkey-patch apply_rotary_pos_emb in the transformers module + try: + import transformers.models.qwen2.modeling_qwen2 as qwen2_module + original_apply_rotary = qwen2_module.apply_rotary_pos_emb + + def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1): + logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s", + q.shape, k.shape, cos.shape, sin.shape) + # After unsqueeze + cos_unsqueezed = cos.unsqueeze(unsqueeze_dim) + sin_unsqueezed = sin.unsqueeze(unsqueeze_dim) + logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s", + unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape) + logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d", + q.shape[3] if q.dim() >= 4 else -1, + cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1) + return original_apply_rotary(q, k, cos, sin, unsqueeze_dim) + + qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary + logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging") + except Exception as e: + logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e) + + # Also add a forward pre-hook with kwargs support for name, module in self.model.named_modules(): if "Attention" in module.__class__.__name__: - def _debug_hook(mod, args, kwargs=None): - hidden = args[0] if args else None + def _debug_hook(mod, args, kwargs): + hidden = kwargs.get('hidden_states', args[0] if args else None) if hidden is not None: logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape) - # Print q_proj output shape - q_proj = getattr(mod, 'q_proj', None) - if q_proj is not None and hidden is not None: - try: - q_out = q_proj(hidden) - logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape) - head_dim = getattr(mod, 'head_dim', 'NOT SET') - num_heads = getattr(mod, 'num_heads', 'NOT SET') - logger.info("DEBUG HOOK: Will try view with num_heads=%s, head_dim=%s", - num_heads, head_dim) - if isinstance(head_dim, int) and isinstance(num_heads, int): - expected = num_heads * head_dim - actual = q_out.shape[-1] - logger.info("DEBUG HOOK: q_proj output last dim=%d, expected (num_heads*head_dim)=%d, match=%s", - actual, expected, actual == expected) - except Exception as e: - logger.info("DEBUG HOOK: Error testing q_proj: %s", e) + logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET')) + # Check mod.config.head_dim + mod_config = getattr(mod, 'config', None) + if mod_config: + logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET')) + logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s", + id(mod_config), id(mod_config) == id(mod_config)) + # Try q_proj + q_proj = getattr(mod, 'q_proj', None) + if q_proj is not None: + try: + q_out = q_proj(hidden) + logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape) + head_dim = getattr(mod, 'head_dim', 128) + input_shape = hidden.shape[:-1] + hidden_shape = (*input_shape, -1, head_dim) + logger.info("DEBUG HOOK: view target shape=%s", hidden_shape) + q_viewed = q_out.view(hidden_shape) + logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape) + q_transposed = q_viewed.transpose(1, 2) + logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape) + except Exception as e: + logger.info("DEBUG HOOK: Error: %s", e) - module.register_forward_pre_hook(_debug_hook) - logger.info("DEBUG: Added debug hook to %s", name) + module.register_forward_pre_hook(_debug_hook, with_kwargs=True) + logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name) break def _fix_attention_head_dim(self) -> None: