Testing dynamic registration

This commit is contained in:
Chranos
2026-02-06 13:39:13 +08:00
parent d31ace279b
commit f088a6b45d

View File

@@ -248,6 +248,17 @@ class Base(nn.Module):
logger.info("Creating model structure on meta device...")
# DEBUG: Print config info before any modifications
logger.info("DEBUG: Config type: %s", type(self.config).__name__)
logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__)
logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s",
getattr(self.text_config, 'hidden_size', 'N/A'),
getattr(self.text_config, 'num_attention_heads', 'N/A'))
logger.info("DEBUG: config.head_dim=%s (before fix)",
getattr(self.config, 'head_dim', 'NOT SET'))
logger.info("DEBUG: text_config.head_dim=%s (before fix)",
getattr(self.text_config, 'head_dim', 'NOT SET'))
# Set attention implementation to vLLM's
self.text_config._attn_implementation = "vllm"
@@ -256,6 +267,7 @@ class Base(nn.Module):
# Some models may have incorrect head_dim, so we compute and set it
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim)
# Check and fix head_dim in text_config
if hasattr(self.text_config, "head_dim"):
@@ -282,6 +294,11 @@ class Base(nn.Module):
# Some models also need _attn_implementation in config
self.config._attn_implementation = "vllm"
logger.info("DEBUG: config.head_dim=%s (after fix)",
getattr(self.config, 'head_dim', 'NOT SET'))
logger.info("DEBUG: text_config.head_dim=%s (after fix)",
getattr(self.text_config, 'head_dim', 'NOT SET'))
with init_on_device_without_buffers("meta"):
self.model: "PreTrainedModel" = AutoModel.from_config(
self.config,
@@ -456,33 +473,85 @@ class Base(nn.Module):
def _fix_attention_head_dim(self) -> None:
    """Fix ``head_dim`` in attention modules and rotary embeddings after model creation.

    Some models ship configs with an incorrect ``head_dim``, which makes the
    Transformers attention modules and RoPE use wrong dimensions. This walks
    every submodule of ``self.model``, patches ``head_dim`` on attention
    modules, and rebuilds the ``inv_freq`` buffer of rotary-embedding modules
    whose dimension disagrees with the computed value.
    """
    # Derive the canonical head dim from the text config; falls back to 32
    # heads when the config does not declare num_attention_heads.
    correct_head_dim = self.hidden_size // getattr(
        self.text_config, "num_attention_heads", 32
    )
    logger.info(
        "DEBUG: _fix_attention_head_dim called, correct_head_dim=%d",
        correct_head_dim,
    )

    fixed_count = 0
    attention_modules_found = []
    rotary_modules_fixed = []

    for name, module in self.model.named_modules():
        module_name = module.__class__.__name__

        # --- Attention modules: patch the head_dim attribute in place. ---
        if "Attention" in module_name:
            current_head_dim = getattr(module, 'head_dim', 'NOT SET')
            num_heads = getattr(module, 'num_heads', 'NOT SET')
            num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET')
            attention_modules_found.append(
                f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}"
            )
            if hasattr(module, "head_dim") and module.head_dim != correct_head_dim:
                logger.warning(
                    "DEBUG: Fixing head_dim in %s: %d -> %d",
                    name, module.head_dim, correct_head_dim
                )
                module.head_dim = correct_head_dim
                fixed_count += 1

        # --- Rotary embeddings: recreate inv_freq at the corrected dim. ---
        if "RotaryEmbedding" in module_name and hasattr(module, "inv_freq"):
            # inv_freq stores one entry per frequency pair, i.e. dim // 2.
            current_dim = module.inv_freq.shape[0] * 2
            if current_dim != correct_head_dim:
                logger.warning(
                    "DEBUG: Recreating rotary embedding %s: dim %d -> %d",
                    name, current_dim, correct_head_dim
                )
                # NOTE(review): assumes the rotary module carries a .config
                # with rope_theta / rope_parameters — confirm for all models.
                base = getattr(module.config, 'rope_theta', 10000.0)
                if hasattr(module.config, 'rope_parameters'):
                    base = module.config.rope_parameters.get('rope_theta', base)
                device = module.inv_freq.device
                # Standard RoPE frequency schedule: base^(-2i/d).
                inv_freq = 1.0 / (
                    base ** (
                        torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
                        .to(device=device, dtype=torch.float) / correct_head_dim
                    )
                )
                # Non-persistent: inv_freq is derived state, not a checkpoint key.
                module.register_buffer("inv_freq", inv_freq, persistent=False)
                if hasattr(module, "original_inv_freq"):
                    module.register_buffer(
                        "original_inv_freq", inv_freq.clone(), persistent=False
                    )
                rotary_modules_fixed.append(name)

    # Summary diagnostics.
    if attention_modules_found:
        logger.info("DEBUG: Found %d attention modules", len(attention_modules_found))
        for info in attention_modules_found[:3]:
            logger.info("DEBUG: Attention module: %s", info)
    if rotary_modules_fixed:
        logger.info("DEBUG: Fixed %d rotary embedding modules: %s",
                    len(rotary_modules_fixed), rotary_modules_fixed)
    if fixed_count > 0:
        logger.info("Fixed head_dim in %d attention modules", fixed_count)
    else:
        logger.info("DEBUG: No attention modules needed head_dim fix")
def _replace_input_embeddings(self) -> None:
"""Replace input embeddings with VocabParallelEmbedding."""
@@ -626,6 +695,21 @@ class Base(nn.Module):
# Forward through model
# Note: return_dict=False returns tuple, first element is last hidden state
# DEBUG: Print attention module head_dim values just before forward
logger.info("DEBUG: Checking attention modules before forward...")
for name, module in self.model.named_modules():
if "Attention" in module.__class__.__name__:
head_dim = getattr(module, 'head_dim', 'NOT SET')
rotary_emb = getattr(module, 'rotary_emb', None)
if rotary_emb:
emb_dim = getattr(rotary_emb, 'dim', 'N/A')
logger.info("DEBUG: %s: head_dim=%s, rotary_emb.dim=%s",
name, head_dim, emb_dim)
else:
logger.info("DEBUG: %s: head_dim=%s, rotary_emb=None", name, head_dim)
break # Just print first one
with torch.no_grad():
outputs = self.model(
**model_inputs,