testing dynamic register
This commit is contained in:
@@ -248,6 +248,17 @@ class Base(nn.Module):
|
||||
|
||||
logger.info("Creating model structure on meta device...")
|
||||
|
||||
# DEBUG: Print config info before any modifications
|
||||
logger.info("DEBUG: Config type: %s", type(self.config).__name__)
|
||||
logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__)
|
||||
logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s",
|
||||
getattr(self.text_config, 'hidden_size', 'N/A'),
|
||||
getattr(self.text_config, 'num_attention_heads', 'N/A'))
|
||||
logger.info("DEBUG: config.head_dim=%s (before fix)",
|
||||
getattr(self.config, 'head_dim', 'NOT SET'))
|
||||
logger.info("DEBUG: text_config.head_dim=%s (before fix)",
|
||||
getattr(self.text_config, 'head_dim', 'NOT SET'))
|
||||
|
||||
# Set attention implementation to vLLM's
|
||||
self.text_config._attn_implementation = "vllm"
|
||||
|
||||
@@ -256,6 +267,7 @@ class Base(nn.Module):
|
||||
# Some models may have incorrect head_dim, so we compute and set it
|
||||
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
|
||||
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
|
||||
logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim)
|
||||
|
||||
# Check and fix head_dim in text_config
|
||||
if hasattr(self.text_config, "head_dim"):
|
||||
@@ -282,6 +294,11 @@ class Base(nn.Module):
|
||||
# Some models also need _attn_implementation in config
|
||||
self.config._attn_implementation = "vllm"
|
||||
|
||||
logger.info("DEBUG: config.head_dim=%s (after fix)",
|
||||
getattr(self.config, 'head_dim', 'NOT SET'))
|
||||
logger.info("DEBUG: text_config.head_dim=%s (after fix)",
|
||||
getattr(self.text_config, 'head_dim', 'NOT SET'))
|
||||
|
||||
with init_on_device_without_buffers("meta"):
|
||||
self.model: "PreTrainedModel" = AutoModel.from_config(
|
||||
self.config,
|
||||
@@ -456,33 +473,85 @@ class Base(nn.Module):
|
||||
|
||||
def _fix_attention_head_dim(self) -> None:
    """
    Fix ``head_dim`` in attention modules and rotary embeddings after model creation.

    Some models may carry an incorrect ``head_dim`` in their config, which causes
    Transformers attention modules and RoPE to use wrong dimensions. This method
    walks every submodule of ``self.model``, corrects ``head_dim`` on attention
    modules, and recreates the ``inv_freq`` buffer of rotary-embedding modules
    whose dimension does not match.

    NOTE(review): ``hidden_size // num_attention_heads`` is only correct for
    models whose head dimension is derived that way; models with an explicit
    config ``head_dim`` (e.g. some GQA variants) may legitimately differ —
    confirm before relying on this for such architectures.

    Side effects:
        Mutates ``module.head_dim`` on matching attention modules and
        re-registers the ``inv_freq`` / ``original_inv_freq`` buffers on
        matching rotary-embedding modules. Emits INFO/WARNING logs.
    """
    # Fall back to 32 heads if the config lacks num_attention_heads.
    # NOTE(review): a silent default here can mask a misconfigured model —
    # consider raising instead; kept for backward compatibility.
    correct_head_dim = self.hidden_size // getattr(
        self.text_config, "num_attention_heads", 32
    )
    logger.info(
        "DEBUG: _fix_attention_head_dim called, correct_head_dim=%d",
        correct_head_dim,
    )

    fixed_count = 0
    attention_modules_found = []
    rotary_modules_fixed = []

    for name, module in self.model.named_modules():
        # Classify by class name; there is no shared base class to
        # isinstance-check across model families.
        module_name = module.__class__.__name__

        # --- Fix head_dim in Attention modules -------------------------
        if "Attention" in module_name:
            current_head_dim = getattr(module, 'head_dim', 'NOT SET')
            num_heads = getattr(module, 'num_heads', 'NOT SET')
            num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET')
            attention_modules_found.append(
                f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}"
            )

            # Only touch modules that actually expose head_dim, and only
            # when the stored value disagrees with the computed one.
            if hasattr(module, "head_dim") and module.head_dim != correct_head_dim:
                logger.warning(
                    "DEBUG: Fixing head_dim in %s: %d -> %d",
                    name, module.head_dim, correct_head_dim
                )
                module.head_dim = correct_head_dim
                fixed_count += 1

        # --- Fix rotary embeddings: recreate the inv_freq buffer -------
        if "RotaryEmbedding" in module_name and hasattr(module, "inv_freq"):
            # inv_freq holds one frequency per even index, i.e. half the
            # rotary dim — assumes the standard HF RoPE layout; TODO confirm
            # for models with partial rotary factors.
            current_dim = module.inv_freq.shape[0] * 2
            if current_dim != correct_head_dim:
                logger.warning(
                    "DEBUG: Recreating rotary embedding %s: dim %d -> %d",
                    name, current_dim, correct_head_dim
                )
                # rope_theta may live directly on the config or inside a
                # rope_parameters dict depending on the model family.
                base = getattr(module.config, 'rope_theta', 10000.0)
                if hasattr(module.config, 'rope_parameters'):
                    base = module.config.rope_parameters.get('rope_theta', base)
                device = module.inv_freq.device
                # Standard RoPE frequency schedule:
                # inv_freq[i] = 1 / base^(2i / head_dim)
                inv_freq = 1.0 / (
                    base ** (
                        torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
                        .to(device=device, dtype=torch.float) / correct_head_dim
                    )
                )
                # persistent=False: keep the buffer out of the state_dict,
                # matching how Transformers registers inv_freq.
                module.register_buffer("inv_freq", inv_freq, persistent=False)
                if hasattr(module, "original_inv_freq"):
                    module.register_buffer(
                        "original_inv_freq", inv_freq.clone(), persistent=False
                    )
                rotary_modules_fixed.append(name)

    # --- Debug summary -------------------------------------------------
    if attention_modules_found:
        logger.info("DEBUG: Found %d attention modules", len(attention_modules_found))
        for info in attention_modules_found[:3]:
            logger.info("DEBUG: Attention module: %s", info)

    if rotary_modules_fixed:
        logger.info("DEBUG: Fixed %d rotary embedding modules: %s",
                    len(rotary_modules_fixed), rotary_modules_fixed)

    if fixed_count > 0:
        logger.info("Fixed head_dim in %d attention modules", fixed_count)
    else:
        logger.info("DEBUG: No attention modules needed head_dim fix")
|
||||
|
||||
def _replace_input_embeddings(self) -> None:
|
||||
"""Replace input embeddings with VocabParallelEmbedding."""
|
||||
@@ -626,6 +695,21 @@ class Base(nn.Module):
|
||||
|
||||
# Forward through model
|
||||
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||
|
||||
# DEBUG: Print attention module head_dim values just before forward
|
||||
logger.info("DEBUG: Checking attention modules before forward...")
|
||||
for name, module in self.model.named_modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
head_dim = getattr(module, 'head_dim', 'NOT SET')
|
||||
rotary_emb = getattr(module, 'rotary_emb', None)
|
||||
if rotary_emb:
|
||||
emb_dim = getattr(rotary_emb, 'dim', 'N/A')
|
||||
logger.info("DEBUG: %s: head_dim=%s, rotary_emb.dim=%s",
|
||||
name, head_dim, emb_dim)
|
||||
else:
|
||||
logger.info("DEBUG: %s: head_dim=%s, rotary_emb=None", name, head_dim)
|
||||
break # Just print first one
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(
|
||||
**model_inputs,
|
||||
|
||||
Reference in New Issue
Block a user