From f088a6b45d4c752ffba171068f6e465c1d928fa3 Mon Sep 17 00:00:00 2001 From: Chranos <826995883@qq.com> Date: Fri, 6 Feb 2026 13:39:13 +0800 Subject: [PATCH] testing dynamic register --- .../models/transformers/base.py | 96 +++++++++++++++++-- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py index 24b73eb..f61bd8f 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py @@ -248,6 +248,17 @@ class Base(nn.Module): logger.info("Creating model structure on meta device...") + # DEBUG: Print config info before any modifications + logger.info("DEBUG: Config type: %s", type(self.config).__name__) + logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__) + logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s", + getattr(self.text_config, 'hidden_size', 'N/A'), + getattr(self.text_config, 'num_attention_heads', 'N/A')) + logger.info("DEBUG: config.head_dim=%s (before fix)", + getattr(self.config, 'head_dim', 'NOT SET')) + logger.info("DEBUG: text_config.head_dim=%s (before fix)", + getattr(self.text_config, 'head_dim', 'NOT SET')) + # Set attention implementation to vLLM's self.text_config._attn_implementation = "vllm" @@ -256,6 +267,7 @@ class Base(nn.Module): # Some models may have incorrect head_dim, so we compute and set it if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"): correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads + logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim) # Check and fix head_dim in text_config if hasattr(self.text_config, "head_dim"): @@ -282,6 +294,11 @@ class Base(nn.Module): # Some models also need _attn_implementation in config self.config._attn_implementation = "vllm" + logger.info("DEBUG: config.head_dim=%s (after fix)", + getattr(self.config, 'head_dim', 'NOT SET')) + logger.info("DEBUG: text_config.head_dim=%s (after fix)", + getattr(self.text_config, 'head_dim', 'NOT SET')) + with init_on_device_without_buffers("meta"): self.model: "PreTrainedModel" = AutoModel.from_config( self.config, @@ -456,33 +473,85 @@ class Base(nn.Module): def _fix_attention_head_dim(self) -> None: """ - Fix head_dim in attention modules after model creation. + Fix head_dim in attention modules and rotary embeddings after model creation. Some models may have incorrect head_dim in config, which causes - Transformers attention modules to use wrong dimensions for RoPE. - This method corrects head_dim in all attention modules. + Transformers attention modules and RoPE to use wrong dimensions. + This method corrects head_dim in all attention modules and recreates + rotary embeddings if needed. """ correct_head_dim = self.hidden_size // getattr( self.text_config, "num_attention_heads", 32 ) + logger.info("DEBUG: _fix_attention_head_dim called, correct_head_dim=%d", correct_head_dim) fixed_count = 0 + attention_modules_found = [] + rotary_modules_fixed = [] + for name, module in self.model.named_modules(): - # Check if this is an attention module module_name = module.__class__.__name__ + + # Fix head_dim in Attention modules if "Attention" in module_name: + current_head_dim = getattr(module, 'head_dim', 'NOT SET') + num_heads = getattr(module, 'num_heads', 'NOT SET') + num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET') + attention_modules_found.append( + f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}" + ) + # Fix head_dim if it exists and is incorrect if hasattr(module, "head_dim"): if module.head_dim != correct_head_dim: - logger.debug( - "Fixing head_dim in %s: %d -> %d", + logger.warning( + "DEBUG: Fixing head_dim in %s: %d -> %d", name, module.head_dim, correct_head_dim ) module.head_dim = correct_head_dim fixed_count += 1 + + # Fix rotary embeddings - need to recreate inv_freq buffer + if "RotaryEmbedding" in module_name: + # Check if rotary embedding has wrong dimension + if hasattr(module, "inv_freq"): + current_dim = module.inv_freq.shape[0] * 2 # inv_freq is half the dim + if current_dim != correct_head_dim: + logger.warning( + "DEBUG: Recreating rotary embedding %s: dim %d -> %d", + name, current_dim, correct_head_dim + ) + # Recreate inv_freq with correct dimension + base = getattr(module.config, 'rope_theta', 10000.0) + if hasattr(module.config, 'rope_parameters'): + base = module.config.rope_parameters.get('rope_theta', base) + device = module.inv_freq.device + # Create new inv_freq + inv_freq = 1.0 / ( + base ** ( + torch.arange(0, correct_head_dim, 2, dtype=torch.int64) + .to(device=device, dtype=torch.float) / correct_head_dim + ) + ) + module.register_buffer("inv_freq", inv_freq, persistent=False) + if hasattr(module, "original_inv_freq"): + module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + rotary_modules_fixed.append(name) + + # Print debug info + if attention_modules_found: + logger.info("DEBUG: Found %d attention modules", len(attention_modules_found)) + for info in attention_modules_found[:3]: + logger.info("DEBUG: Attention module: %s", info) + + if rotary_modules_fixed: + logger.info("DEBUG: Fixed %d rotary embedding modules: %s", + len(rotary_modules_fixed), rotary_modules_fixed) if fixed_count > 0: logger.info("Fixed head_dim in %d attention modules", fixed_count) + else: + logger.info("DEBUG: No attention modules needed head_dim fix") def _replace_input_embeddings(self) -> None: """Replace input embeddings with VocabParallelEmbedding.""" @@ -626,6 +695,21 @@ class Base(nn.Module): # Forward through model # Note: return_dict=False returns tuple, first element is last hidden state + + # DEBUG: Print attention module head_dim values just before forward + logger.info("DEBUG: Checking attention modules before forward...") + for name, module in self.model.named_modules(): + if "Attention" in module.__class__.__name__: + head_dim = getattr(module, 'head_dim', 'NOT SET') + rotary_emb = getattr(module, 'rotary_emb', None) + if rotary_emb: + emb_dim = getattr(rotary_emb, 'dim', 'N/A') + logger.info("DEBUG: %s: head_dim=%s, rotary_emb.dim=%s", + name, head_dim, emb_dim) + else: + logger.info("DEBUG: %s: head_dim=%s, rotary_emb=None", name, head_dim) + break # Just print first one + with torch.no_grad(): outputs = self.model( **model_inputs,