diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py index a30dde2..85d8651 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/base.py @@ -22,6 +22,7 @@ from vllm.config import VllmConfig from vllm.distributed import get_pp_group, get_tp_group from vllm.distributed.utils import get_pp_indices from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.models.utils import ( AutoWeightsLoader, @@ -251,17 +252,6 @@ class Base(nn.Module): logger.info("Creating model structure on meta device...") - # DEBUG: Print config info before any modifications - logger.info("DEBUG: Config type: %s", type(self.config).__name__) - logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__) - logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s", - getattr(self.text_config, 'hidden_size', 'N/A'), - getattr(self.text_config, 'num_attention_heads', 'N/A')) - logger.info("DEBUG: config.head_dim=%s (before fix)", - getattr(self.config, 'head_dim', 'NOT SET')) - logger.info("DEBUG: text_config.head_dim=%s (before fix)", - getattr(self.text_config, 'head_dim', 'NOT SET')) - # Set attention implementation to vLLM's self.text_config._attn_implementation = "vllm" @@ -270,7 +260,6 @@ class Base(nn.Module): # Some models may have incorrect head_dim, so we compute and set it if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"): correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads - logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim) # Check and fix head_dim in text_config if hasattr(self.text_config, "head_dim"): @@ -297,11 +286,6 @@ class Base(nn.Module): # Some models also need _attn_implementation in config self.config._attn_implementation = "vllm" - logger.info("DEBUG: config.head_dim=%s (after fix)", - getattr(self.config, 'head_dim', 'NOT SET')) - logger.info("DEBUG: text_config.head_dim=%s (after fix)", - getattr(self.text_config, 'head_dim', 'NOT SET')) - with init_on_device_without_buffers("meta"): self.model: "PreTrainedModel" = AutoModel.from_config( self.config, @@ -461,7 +445,8 @@ class Base(nn.Module): ) replaced_count += 1 - elif child.__class__.__name__.endswith("RMSNorm"): + elif child.__class__.__name__.endswith("RMSNorm") and \ + not isinstance(child, RMSNorm): new_module = replace_rms_norm_class(child, self.hidden_size) replaced_count += 1 @@ -475,64 +460,8 @@ class Base(nn.Module): logger.info("Replaced %d modules", replaced_count) def _add_attention_debug_hook(self) -> None: - """Add debug hooks to capture actual tensor shapes during forward.""" - # Monkey-patch apply_rotary_pos_emb in the transformers module - try: - import transformers.models.qwen2.modeling_qwen2 as qwen2_module - original_apply_rotary = qwen2_module.apply_rotary_pos_emb - - def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1): - logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s", - q.shape, k.shape, cos.shape, sin.shape) - # After unsqueeze - cos_unsqueezed = cos.unsqueeze(unsqueeze_dim) - sin_unsqueezed = sin.unsqueeze(unsqueeze_dim) - logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s", - unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape) - logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d", - q.shape[3] if q.dim() >= 4 else -1, - cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1) - return original_apply_rotary(q, k, cos, sin, unsqueeze_dim) - - qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary - logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging") - except Exception as e: - logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e) - - # Also add a forward pre-hook with kwargs support - for name, module in self.model.named_modules(): - if "Attention" in module.__class__.__name__: - def _debug_hook(mod, args, kwargs): - hidden = kwargs.get('hidden_states', args[0] if args else None) - if hidden is not None: - logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape) - logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET')) - # Check mod.config.head_dim - mod_config = getattr(mod, 'config', None) - if mod_config: - logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET')) - logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s", - id(mod_config), id(mod_config) == id(mod_config)) - # Try q_proj - q_proj = getattr(mod, 'q_proj', None) - if q_proj is not None: - try: - q_out = q_proj(hidden) - logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape) - head_dim = getattr(mod, 'head_dim', 128) - input_shape = hidden.shape[:-1] - hidden_shape = (*input_shape, -1, head_dim) - logger.info("DEBUG HOOK: view target shape=%s", hidden_shape) - q_viewed = q_out.view(hidden_shape) - logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape) - q_transposed = q_viewed.transpose(1, 2) - logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape) - except Exception as e: - logger.info("DEBUG HOOK: Error: %s", e) - - module.register_forward_pre_hook(_debug_hook, with_kwargs=True) - logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name) - break + """No-op. Debug hooks removed after root cause identified.""" + pass def _fix_attention_head_dim(self) -> None: """ @@ -546,50 +475,36 @@ class Base(nn.Module): correct_head_dim = self.hidden_size // getattr( self.text_config, "num_attention_heads", 32 ) - logger.info("DEBUG: _fix_attention_head_dim called, correct_head_dim=%d", correct_head_dim) fixed_count = 0 - attention_modules_found = [] - rotary_modules_fixed = [] for name, module in self.model.named_modules(): module_name = module.__class__.__name__ # Fix head_dim in Attention modules if "Attention" in module_name: - current_head_dim = getattr(module, 'head_dim', 'NOT SET') - num_heads = getattr(module, 'num_heads', 'NOT SET') - num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET') - attention_modules_found.append( - f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}" - ) - - # Fix head_dim if it exists and is incorrect if hasattr(module, "head_dim"): if module.head_dim != correct_head_dim: logger.warning( - "DEBUG: Fixing head_dim in %s: %d -> %d", + "Fixing head_dim in %s: %d -> %d", name, module.head_dim, correct_head_dim ) module.head_dim = correct_head_dim fixed_count += 1 - # Fix rotary embeddings - need to recreate inv_freq buffer + # Fix rotary embeddings - recreate inv_freq buffer if needed if "RotaryEmbedding" in module_name: - # Check if rotary embedding has wrong dimension if hasattr(module, "inv_freq"): - current_dim = module.inv_freq.shape[0] * 2 # inv_freq is half the dim + current_dim = module.inv_freq.shape[0] * 2 if current_dim != correct_head_dim: logger.warning( - "DEBUG: Recreating rotary embedding %s: dim %d -> %d", + "Recreating rotary embedding %s: dim %d -> %d", name, current_dim, correct_head_dim ) - # Recreate inv_freq with correct dimension base = getattr(module.config, 'rope_theta', 10000.0) if hasattr(module.config, 'rope_parameters'): base = module.config.rope_parameters.get('rope_theta', base) device = module.inv_freq.device - # Create new inv_freq inv_freq = 1.0 / ( base ** ( torch.arange(0, correct_head_dim, 2, dtype=torch.int64) @@ -599,22 +514,9 @@ class Base(nn.Module): module.register_buffer("inv_freq", inv_freq, persistent=False) if hasattr(module, "original_inv_freq"): module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) - rotary_modules_fixed.append(name) - - # Print debug info - if attention_modules_found: - logger.info("DEBUG: Found %d attention modules", len(attention_modules_found)) - for info in attention_modules_found[:3]: - logger.info("DEBUG: Attention module: %s", info) - - if rotary_modules_fixed: - logger.info("DEBUG: Fixed %d rotary embedding modules: %s", - len(rotary_modules_fixed), rotary_modules_fixed) if fixed_count > 0: logger.info("Fixed head_dim in %d attention modules", fixed_count) - else: - logger.info("DEBUG: No attention modules needed head_dim fix") def _replace_input_embeddings(self) -> None: """Replace input embeddings with VocabParallelEmbedding.""" @@ -758,80 +660,6 @@ class Base(nn.Module): # Forward through model # Note: return_dict=False returns tuple, first element is last hidden state - - # DEBUG: Print detailed model structure info before forward - if not hasattr(self, '_debug_printed'): - self._debug_printed = True - logger.info("DEBUG: === Detailed model structure debug ===") - - # Print transformers version - try: - import transformers - logger.info("DEBUG: transformers version: %s", transformers.__version__) - except Exception: - pass - - # Print TP world size - logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size) - - # Print first attention module details - for name, module in self.model.named_modules(): - if "Attention" in module.__class__.__name__: - logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__) - # Print all attributes - for attr in ['head_dim', 'num_heads', 'num_key_value_heads', - 'hidden_size', 'num_attention_heads', - 'num_key_value_groups']: - val = getattr(module, attr, 'NOT SET') - logger.info("DEBUG: %s = %s", attr, val) - - # Print rotary_emb - rotary = getattr(module, 'rotary_emb', None) - if rotary: - logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__) - if hasattr(rotary, 'inv_freq'): - logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape) - else: - logger.info("DEBUG: rotary_emb: None") - - # Print projection shapes - for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']: - proj = getattr(module, proj_name, None) - if proj: - if hasattr(proj, 'weight'): - logger.info("DEBUG: %s: type=%s, weight.shape=%s", - proj_name, type(proj).__name__, - proj.weight.shape if proj.weight is not None else 'None') - elif hasattr(proj, 'output_size'): - logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s", - proj_name, type(proj).__name__, - getattr(proj, 'input_size', 'N/A'), - getattr(proj, 'output_size', 'N/A'), - getattr(proj, 'output_size_per_partition', 'N/A')) - break - - # Print model-level rotary_emb - model_rotary = getattr(self.model, 'rotary_emb', None) - if model_rotary: - logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__) - if hasattr(model_rotary, 'inv_freq'): - logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape) - else: - logger.info("DEBUG: No model-level rotary_emb") - # Check nested - for name, module in self.model.named_modules(): - if "RotaryEmbedding" in module.__class__.__name__: - inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A' - logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape) - break - - # Print config details - for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads', - 'intermediate_size', 'num_hidden_layers']: - logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET')) - - logger.info("DEBUG: === End debug ===") - with torch.no_grad(): outputs = self.model( **model_inputs, diff --git a/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py b/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py index abd2020..1d6f3f8 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py +++ b/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py @@ -8,7 +8,7 @@ module replacement functions. """ from contextlib import contextmanager -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union import torch import torch.nn as nn @@ -123,23 +123,102 @@ def replace_linear_class( ) +class TransformersRMSNorm(RMSNorm): + """ + vLLM RMSNorm subclass that preserves tensor dimensions. + + vLLM's RMSNorm (especially the MLU backend) flattens input to 2D + (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers + expects the batch dimension to be preserved. This subclass wraps + the parent forward methods to save and restore the original tensor shape. + + Since this inherits from RMSNorm directly, weight loading via + named_parameters() works correctly (weight path stays the same). + """ + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ): + orig_shape = x.shape + result = super().forward_native(x, residual) + return self._restore_shape(result, orig_shape) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ): + orig_shape = x.shape + result = super().forward_cuda(x, residual) + return self._restore_shape(result, orig_shape) + + def forward_mlu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ): + orig_shape = x.shape + result = super().forward_mlu(x, residual) + return self._restore_shape(result, orig_shape) + + def forward_xpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ): + orig_shape = x.shape + result = super().forward_xpu(x, residual) + return self._restore_shape(result, orig_shape) + + def forward_hpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ): + orig_shape = x.shape + result = super().forward_hpu(x, residual) + return self._restore_shape(result, orig_shape) + + @staticmethod + def _restore_shape(result, orig_shape: Tuple): + """Restore original tensor shape if it was changed.""" + if isinstance(result, tuple): + restored = [] + for t in result: + if t is not None and t.shape != orig_shape: + t = t.view(orig_shape) + restored.append(t) + return tuple(restored) + else: + if result.shape != orig_shape: + result = result.view(orig_shape) + return result + + def replace_rms_norm_class( rms_norm: nn.Module, hidden_size: int, -) -> RMSNorm: +) -> nn.Module: """ - Replace a Transformers RMSNorm with vLLM's optimized RMSNorm. + Replace a Transformers RMSNorm with vLLM's optimized RMSNorm, + wrapped to preserve tensor dimensions. vLLM's RMSNorm provides: - Fused CUDA kernels for better performance - Support for fused add + norm operations + The wrapper ensures that the original tensor shape (including batch + dimension) is preserved, which is required by transformers' model + forward methods. + Args: rms_norm: The RMSNorm module to replace. hidden_size: The hidden size of the model. Returns: - The new vLLM RMSNorm layer. + The new vLLM RMSNorm layer wrapped for shape preservation. """ # Try to get epsilon from various attribute names eps = getattr(rms_norm, "eps", None) @@ -153,7 +232,7 @@ def replace_rms_norm_class( if weight is not None: hidden_size = weight.size(0) - return RMSNorm(hidden_size=hidden_size, eps=eps) + return TransformersRMSNorm(hidden_size=hidden_size, eps=eps) def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):