testing dynamic register
This commit is contained in:
@@ -447,12 +447,18 @@ class _ModelRegistry:
|
||||
warn_on_fail=True,
|
||||
)
|
||||
if model_cls is not None:
|
||||
logger.info(
|
||||
"Found custom model class %s from auto_map[%s], "
|
||||
"using TransformersForCausalLM wrapper",
|
||||
model_cls.__name__,
|
||||
name
|
||||
)
|
||||
# Only log once per model class to avoid spam
|
||||
log_key = f"{model_cls.__name__}_{name}"
|
||||
if not hasattr(self, '_logged_custom_models'):
|
||||
self._logged_custom_models = set()
|
||||
if log_key not in self._logged_custom_models:
|
||||
logger.info(
|
||||
"Found custom model class %s from auto_map[%s], "
|
||||
"using TransformersForCausalLM wrapper",
|
||||
model_cls.__name__,
|
||||
name
|
||||
)
|
||||
self._logged_custom_models.add(log_key)
|
||||
# Return the wrapper architecture, not the actual class
|
||||
return "TransformersForCausalLM"
|
||||
|
||||
|
||||
@@ -219,6 +219,9 @@ class Base(nn.Module):
|
||||
# Replace modules (with tensor parallel support)
|
||||
self._replace_modules()
|
||||
|
||||
# Fix attention head_dim in case config was incorrect
|
||||
self._fix_attention_head_dim()
|
||||
|
||||
# Replace input embeddings
|
||||
self._replace_input_embeddings()
|
||||
|
||||
@@ -248,20 +251,36 @@ class Base(nn.Module):
|
||||
# Set attention implementation to vLLM's
|
||||
self.text_config._attn_implementation = "vllm"
|
||||
|
||||
# Ensure head_dim is correctly set in config
|
||||
# Ensure head_dim is correctly set in BOTH config and text_config
|
||||
# Transformers models use config.head_dim to compute attention dimensions
|
||||
# Some models may have incorrect head_dim, so we compute and set it
|
||||
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
|
||||
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
|
||||
|
||||
# Check and fix head_dim in text_config
|
||||
if hasattr(self.text_config, "head_dim"):
|
||||
if self.text_config.head_dim != correct_head_dim:
|
||||
logger.warning(
|
||||
"Correcting head_dim in config: %d -> %d",
|
||||
"Correcting head_dim in text_config: %d -> %d",
|
||||
self.text_config.head_dim, correct_head_dim
|
||||
)
|
||||
self.text_config.head_dim = correct_head_dim
|
||||
else:
|
||||
# Set head_dim if not present, some models need it
|
||||
self.text_config.head_dim = correct_head_dim
|
||||
|
||||
# Also set in self.config (which is passed to AutoModel.from_config)
|
||||
if hasattr(self.config, "head_dim"):
|
||||
if self.config.head_dim != correct_head_dim:
|
||||
logger.warning(
|
||||
"Correcting head_dim in config: %d -> %d",
|
||||
self.config.head_dim, correct_head_dim
|
||||
)
|
||||
self.config.head_dim = correct_head_dim
|
||||
else:
|
||||
self.config.head_dim = correct_head_dim
|
||||
|
||||
# Some models also need _attn_implementation in config
|
||||
self.config._attn_implementation = "vllm"
|
||||
|
||||
with init_on_device_without_buffers("meta"):
|
||||
self.model: "PreTrainedModel" = AutoModel.from_config(
|
||||
@@ -435,6 +454,36 @@ class Base(nn.Module):
|
||||
_recursive_replace(self.model, "model")
|
||||
logger.info("Replaced %d modules", replaced_count)
|
||||
|
||||
def _fix_attention_head_dim(self) -> None:
|
||||
"""
|
||||
Fix head_dim in attention modules after model creation.
|
||||
|
||||
Some models may have incorrect head_dim in config, which causes
|
||||
Transformers attention modules to use wrong dimensions for RoPE.
|
||||
This method corrects head_dim in all attention modules.
|
||||
"""
|
||||
correct_head_dim = self.hidden_size // getattr(
|
||||
self.text_config, "num_attention_heads", 32
|
||||
)
|
||||
|
||||
fixed_count = 0
|
||||
for name, module in self.model.named_modules():
|
||||
# Check if this is an attention module
|
||||
module_name = module.__class__.__name__
|
||||
if "Attention" in module_name:
|
||||
# Fix head_dim if it exists and is incorrect
|
||||
if hasattr(module, "head_dim"):
|
||||
if module.head_dim != correct_head_dim:
|
||||
logger.debug(
|
||||
"Fixing head_dim in %s: %d -> %d",
|
||||
name, module.head_dim, correct_head_dim
|
||||
)
|
||||
module.head_dim = correct_head_dim
|
||||
fixed_count += 1
|
||||
|
||||
if fixed_count > 0:
|
||||
logger.info("Fixed head_dim in %d attention modules", fixed_count)
|
||||
|
||||
def _replace_input_embeddings(self) -> None:
|
||||
"""Replace input embeddings with VocabParallelEmbedding."""
|
||||
input_embeddings = self.model.get_input_embeddings()
|
||||
|
||||
Reference in New Issue
Block a user