forked from EngineX-Cambricon/enginex-mlu370-vllm
testing dynamic register
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tp_group
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
@@ -251,17 +252,6 @@ class Base(nn.Module):
|
||||
|
||||
logger.info("Creating model structure on meta device...")
|
||||
|
||||
# DEBUG: Print config info before any modifications
|
||||
logger.info("DEBUG: Config type: %s", type(self.config).__name__)
|
||||
logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__)
|
||||
logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s",
|
||||
getattr(self.text_config, 'hidden_size', 'N/A'),
|
||||
getattr(self.text_config, 'num_attention_heads', 'N/A'))
|
||||
logger.info("DEBUG: config.head_dim=%s (before fix)",
|
||||
getattr(self.config, 'head_dim', 'NOT SET'))
|
||||
logger.info("DEBUG: text_config.head_dim=%s (before fix)",
|
||||
getattr(self.text_config, 'head_dim', 'NOT SET'))
|
||||
|
||||
# Set attention implementation to vLLM's
|
||||
self.text_config._attn_implementation = "vllm"
|
||||
|
||||
@@ -270,7 +260,6 @@ class Base(nn.Module):
|
||||
# Some models may have incorrect head_dim, so we compute and set it
|
||||
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
|
||||
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
|
||||
logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim)
|
||||
|
||||
# Check and fix head_dim in text_config
|
||||
if hasattr(self.text_config, "head_dim"):
|
||||
@@ -297,11 +286,6 @@ class Base(nn.Module):
|
||||
# Some models also need _attn_implementation in config
|
||||
self.config._attn_implementation = "vllm"
|
||||
|
||||
logger.info("DEBUG: config.head_dim=%s (after fix)",
|
||||
getattr(self.config, 'head_dim', 'NOT SET'))
|
||||
logger.info("DEBUG: text_config.head_dim=%s (after fix)",
|
||||
getattr(self.text_config, 'head_dim', 'NOT SET'))
|
||||
|
||||
with init_on_device_without_buffers("meta"):
|
||||
self.model: "PreTrainedModel" = AutoModel.from_config(
|
||||
self.config,
|
||||
@@ -461,7 +445,8 @@ class Base(nn.Module):
|
||||
)
|
||||
replaced_count += 1
|
||||
|
||||
elif child.__class__.__name__.endswith("RMSNorm"):
|
||||
elif child.__class__.__name__.endswith("RMSNorm") and \
|
||||
not isinstance(child, RMSNorm):
|
||||
new_module = replace_rms_norm_class(child, self.hidden_size)
|
||||
replaced_count += 1
|
||||
|
||||
@@ -475,64 +460,8 @@ class Base(nn.Module):
|
||||
logger.info("Replaced %d modules", replaced_count)
|
||||
|
||||
def _add_attention_debug_hook(self) -> None:
|
||||
"""Add debug hooks to capture actual tensor shapes during forward."""
|
||||
# Monkey-patch apply_rotary_pos_emb in the transformers module
|
||||
try:
|
||||
import transformers.models.qwen2.modeling_qwen2 as qwen2_module
|
||||
original_apply_rotary = qwen2_module.apply_rotary_pos_emb
|
||||
|
||||
def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
|
||||
logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s",
|
||||
q.shape, k.shape, cos.shape, sin.shape)
|
||||
# After unsqueeze
|
||||
cos_unsqueezed = cos.unsqueeze(unsqueeze_dim)
|
||||
sin_unsqueezed = sin.unsqueeze(unsqueeze_dim)
|
||||
logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s",
|
||||
unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape)
|
||||
logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d",
|
||||
q.shape[3] if q.dim() >= 4 else -1,
|
||||
cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1)
|
||||
return original_apply_rotary(q, k, cos, sin, unsqueeze_dim)
|
||||
|
||||
qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary
|
||||
logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging")
|
||||
except Exception as e:
|
||||
logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e)
|
||||
|
||||
# Also add a forward pre-hook with kwargs support
|
||||
for name, module in self.model.named_modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
def _debug_hook(mod, args, kwargs):
|
||||
hidden = kwargs.get('hidden_states', args[0] if args else None)
|
||||
if hidden is not None:
|
||||
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
|
||||
logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET'))
|
||||
# Check mod.config.head_dim
|
||||
mod_config = getattr(mod, 'config', None)
|
||||
if mod_config:
|
||||
logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET'))
|
||||
logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s",
|
||||
id(mod_config), id(mod_config) == id(mod_config))
|
||||
# Try q_proj
|
||||
q_proj = getattr(mod, 'q_proj', None)
|
||||
if q_proj is not None:
|
||||
try:
|
||||
q_out = q_proj(hidden)
|
||||
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
|
||||
head_dim = getattr(mod, 'head_dim', 128)
|
||||
input_shape = hidden.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, head_dim)
|
||||
logger.info("DEBUG HOOK: view target shape=%s", hidden_shape)
|
||||
q_viewed = q_out.view(hidden_shape)
|
||||
logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape)
|
||||
q_transposed = q_viewed.transpose(1, 2)
|
||||
logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape)
|
||||
except Exception as e:
|
||||
logger.info("DEBUG HOOK: Error: %s", e)
|
||||
|
||||
module.register_forward_pre_hook(_debug_hook, with_kwargs=True)
|
||||
logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name)
|
||||
break
|
||||
"""No-op. Debug hooks removed after root cause identified."""
|
||||
pass
|
||||
|
||||
def _fix_attention_head_dim(self) -> None:
|
||||
"""
|
||||
@@ -546,50 +475,36 @@ class Base(nn.Module):
|
||||
correct_head_dim = self.hidden_size // getattr(
|
||||
self.text_config, "num_attention_heads", 32
|
||||
)
|
||||
logger.info("DEBUG: _fix_attention_head_dim called, correct_head_dim=%d", correct_head_dim)
|
||||
|
||||
fixed_count = 0
|
||||
attention_modules_found = []
|
||||
rotary_modules_fixed = []
|
||||
|
||||
for name, module in self.model.named_modules():
|
||||
module_name = module.__class__.__name__
|
||||
|
||||
# Fix head_dim in Attention modules
|
||||
if "Attention" in module_name:
|
||||
current_head_dim = getattr(module, 'head_dim', 'NOT SET')
|
||||
num_heads = getattr(module, 'num_heads', 'NOT SET')
|
||||
num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET')
|
||||
attention_modules_found.append(
|
||||
f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}"
|
||||
)
|
||||
|
||||
# Fix head_dim if it exists and is incorrect
|
||||
if hasattr(module, "head_dim"):
|
||||
if module.head_dim != correct_head_dim:
|
||||
logger.warning(
|
||||
"DEBUG: Fixing head_dim in %s: %d -> %d",
|
||||
"Fixing head_dim in %s: %d -> %d",
|
||||
name, module.head_dim, correct_head_dim
|
||||
)
|
||||
module.head_dim = correct_head_dim
|
||||
fixed_count += 1
|
||||
|
||||
# Fix rotary embeddings - need to recreate inv_freq buffer
|
||||
# Fix rotary embeddings - recreate inv_freq buffer if needed
|
||||
if "RotaryEmbedding" in module_name:
|
||||
# Check if rotary embedding has wrong dimension
|
||||
if hasattr(module, "inv_freq"):
|
||||
current_dim = module.inv_freq.shape[0] * 2 # inv_freq is half the dim
|
||||
current_dim = module.inv_freq.shape[0] * 2
|
||||
if current_dim != correct_head_dim:
|
||||
logger.warning(
|
||||
"DEBUG: Recreating rotary embedding %s: dim %d -> %d",
|
||||
"Recreating rotary embedding %s: dim %d -> %d",
|
||||
name, current_dim, correct_head_dim
|
||||
)
|
||||
# Recreate inv_freq with correct dimension
|
||||
base = getattr(module.config, 'rope_theta', 10000.0)
|
||||
if hasattr(module.config, 'rope_parameters'):
|
||||
base = module.config.rope_parameters.get('rope_theta', base)
|
||||
device = module.inv_freq.device
|
||||
# Create new inv_freq
|
||||
inv_freq = 1.0 / (
|
||||
base ** (
|
||||
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
|
||||
@@ -599,22 +514,9 @@ class Base(nn.Module):
|
||||
module.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
if hasattr(module, "original_inv_freq"):
|
||||
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
||||
rotary_modules_fixed.append(name)
|
||||
|
||||
# Print debug info
|
||||
if attention_modules_found:
|
||||
logger.info("DEBUG: Found %d attention modules", len(attention_modules_found))
|
||||
for info in attention_modules_found[:3]:
|
||||
logger.info("DEBUG: Attention module: %s", info)
|
||||
|
||||
if rotary_modules_fixed:
|
||||
logger.info("DEBUG: Fixed %d rotary embedding modules: %s",
|
||||
len(rotary_modules_fixed), rotary_modules_fixed)
|
||||
|
||||
if fixed_count > 0:
|
||||
logger.info("Fixed head_dim in %d attention modules", fixed_count)
|
||||
else:
|
||||
logger.info("DEBUG: No attention modules needed head_dim fix")
|
||||
|
||||
def _replace_input_embeddings(self) -> None:
|
||||
"""Replace input embeddings with VocabParallelEmbedding."""
|
||||
@@ -758,80 +660,6 @@ class Base(nn.Module):
|
||||
|
||||
# Forward through model
|
||||
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||
|
||||
# DEBUG: Print detailed model structure info before forward
|
||||
if not hasattr(self, '_debug_printed'):
|
||||
self._debug_printed = True
|
||||
logger.info("DEBUG: === Detailed model structure debug ===")
|
||||
|
||||
# Print transformers version
|
||||
try:
|
||||
import transformers
|
||||
logger.info("DEBUG: transformers version: %s", transformers.__version__)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Print TP world size
|
||||
logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)
|
||||
|
||||
# Print first attention module details
|
||||
for name, module in self.model.named_modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
|
||||
# Print all attributes
|
||||
for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
|
||||
'hidden_size', 'num_attention_heads',
|
||||
'num_key_value_groups']:
|
||||
val = getattr(module, attr, 'NOT SET')
|
||||
logger.info("DEBUG: %s = %s", attr, val)
|
||||
|
||||
# Print rotary_emb
|
||||
rotary = getattr(module, 'rotary_emb', None)
|
||||
if rotary:
|
||||
logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
|
||||
if hasattr(rotary, 'inv_freq'):
|
||||
logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
|
||||
else:
|
||||
logger.info("DEBUG: rotary_emb: None")
|
||||
|
||||
# Print projection shapes
|
||||
for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
|
||||
proj = getattr(module, proj_name, None)
|
||||
if proj:
|
||||
if hasattr(proj, 'weight'):
|
||||
logger.info("DEBUG: %s: type=%s, weight.shape=%s",
|
||||
proj_name, type(proj).__name__,
|
||||
proj.weight.shape if proj.weight is not None else 'None')
|
||||
elif hasattr(proj, 'output_size'):
|
||||
logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
|
||||
proj_name, type(proj).__name__,
|
||||
getattr(proj, 'input_size', 'N/A'),
|
||||
getattr(proj, 'output_size', 'N/A'),
|
||||
getattr(proj, 'output_size_per_partition', 'N/A'))
|
||||
break
|
||||
|
||||
# Print model-level rotary_emb
|
||||
model_rotary = getattr(self.model, 'rotary_emb', None)
|
||||
if model_rotary:
|
||||
logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
|
||||
if hasattr(model_rotary, 'inv_freq'):
|
||||
logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
|
||||
else:
|
||||
logger.info("DEBUG: No model-level rotary_emb")
|
||||
# Check nested
|
||||
for name, module in self.model.named_modules():
|
||||
if "RotaryEmbedding" in module.__class__.__name__:
|
||||
inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
|
||||
logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
|
||||
break
|
||||
|
||||
# Print config details
|
||||
for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
|
||||
'intermediate_size', 'num_hidden_layers']:
|
||||
logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))
|
||||
|
||||
logger.info("DEBUG: === End debug ===")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(
|
||||
**model_inputs,
|
||||
|
||||
@@ -8,7 +8,7 @@ module replacement functions.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -123,23 +123,102 @@ def replace_linear_class(
|
||||
)
|
||||
|
||||
|
||||
class TransformersRMSNorm(RMSNorm):
    """
    vLLM RMSNorm subclass that preserves tensor dimensions.

    vLLM's RMSNorm (especially the MLU backend) flattens input to 2D
    (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
    expects the batch dimension to be preserved. This subclass wraps
    the parent forward methods to save and restore the original tensor shape.

    Since this inherits from RMSNorm directly, weight loading via
    named_parameters() works correctly (weight path stays the same).

    Every backend-specific ``forward_*`` entry point is overridden with the
    same three steps: remember the input shape, delegate to the parent
    implementation, then hand the result to ``_restore_shape``.
    """

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Run the parent native forward and restore the input's shape."""
        # Capture the shape before delegating; the parent may return a
        # tensor with a different (flattened) shape.
        orig_shape = x.shape
        result = super().forward_native(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Run the parent CUDA forward and restore the input's shape."""
        orig_shape = x.shape
        result = super().forward_cuda(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_mlu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Run the parent MLU forward and restore the input's shape."""
        orig_shape = x.shape
        result = super().forward_mlu(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Run the parent XPU forward and restore the input's shape."""
        orig_shape = x.shape
        result = super().forward_xpu(x, residual)
        return self._restore_shape(result, orig_shape)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Run the parent HPU forward and restore the input's shape."""
        orig_shape = x.shape
        result = super().forward_hpu(x, residual)
        return self._restore_shape(result, orig_shape)

    @staticmethod
    def _restore_shape(result, orig_shape: Tuple):
        """Restore original tensor shape if it was changed.

        Args:
            result: Either a single tensor or a tuple of tensors (entries
                may be ``None``) as returned by a parent ``forward_*``.
            orig_shape: Shape of the input tensor before the parent call.

        Returns:
            ``result`` with every tensor whose shape differs from
            ``orig_shape`` reshaped to it. A tuple input yields a tuple of
            the same length; ``None`` entries are passed through unchanged.
        """
        # ``reshape`` is used instead of ``view``: it returns a view when the
        # strides allow it and otherwise copies, so a non-contiguous backend
        # output is handled instead of raising.
        if isinstance(result, tuple):
            return tuple(
                t if t is None or t.shape == orig_shape
                else t.reshape(orig_shape)
                for t in result
            )

        if result.shape != orig_shape:
            result = result.reshape(orig_shape)
        return result
|
||||
|
||||
|
||||
def replace_rms_norm_class(
|
||||
rms_norm: nn.Module,
|
||||
hidden_size: int,
|
||||
) -> RMSNorm:
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
|
||||
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm,
|
||||
wrapped to preserve tensor dimensions.
|
||||
|
||||
vLLM's RMSNorm provides:
|
||||
- Fused CUDA kernels for better performance
|
||||
- Support for fused add + norm operations
|
||||
|
||||
The wrapper ensures that the original tensor shape (including batch
|
||||
dimension) is preserved, which is required by transformers' model
|
||||
forward methods.
|
||||
|
||||
Args:
|
||||
rms_norm: The RMSNorm module to replace.
|
||||
hidden_size: The hidden size of the model.
|
||||
|
||||
Returns:
|
||||
The new vLLM RMSNorm layer.
|
||||
The new vLLM RMSNorm layer wrapped for shape preservation.
|
||||
"""
|
||||
# Try to get epsilon from various attribute names
|
||||
eps = getattr(rms_norm, "eps", None)
|
||||
@@ -153,7 +232,7 @@ def replace_rms_norm_class(
|
||||
if weight is not None:
|
||||
hidden_size = weight.size(0)
|
||||
|
||||
return RMSNorm(hidden_size=hidden_size, eps=eps)
|
||||
return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
|
||||
|
||||
|
||||
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user