forked from EngineX-Cambricon/enginex-mlu370-vllm
testing dynamic register
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import get_pp_group, get_tp_group
|
from vllm.distributed import get_pp_group, get_tp_group
|
||||||
from vllm.distributed.utils import get_pp_indices
|
from vllm.distributed.utils import get_pp_indices
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.models.utils import (
|
from vllm.model_executor.models.utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
@@ -251,17 +252,6 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
logger.info("Creating model structure on meta device...")
|
logger.info("Creating model structure on meta device...")
|
||||||
|
|
||||||
# DEBUG: Print config info before any modifications
|
|
||||||
logger.info("DEBUG: Config type: %s", type(self.config).__name__)
|
|
||||||
logger.info("DEBUG: text_config type: %s", type(self.text_config).__name__)
|
|
||||||
logger.info("DEBUG: hidden_size=%s, num_attention_heads=%s",
|
|
||||||
getattr(self.text_config, 'hidden_size', 'N/A'),
|
|
||||||
getattr(self.text_config, 'num_attention_heads', 'N/A'))
|
|
||||||
logger.info("DEBUG: config.head_dim=%s (before fix)",
|
|
||||||
getattr(self.config, 'head_dim', 'NOT SET'))
|
|
||||||
logger.info("DEBUG: text_config.head_dim=%s (before fix)",
|
|
||||||
getattr(self.text_config, 'head_dim', 'NOT SET'))
|
|
||||||
|
|
||||||
# Set attention implementation to vLLM's
|
# Set attention implementation to vLLM's
|
||||||
self.text_config._attn_implementation = "vllm"
|
self.text_config._attn_implementation = "vllm"
|
||||||
|
|
||||||
@@ -270,7 +260,6 @@ class Base(nn.Module):
|
|||||||
# Some models may have incorrect head_dim, so we compute and set it
|
# Some models may have incorrect head_dim, so we compute and set it
|
||||||
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
|
if hasattr(self.text_config, "num_attention_heads") and hasattr(self.text_config, "hidden_size"):
|
||||||
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
|
correct_head_dim = self.text_config.hidden_size // self.text_config.num_attention_heads
|
||||||
logger.info("DEBUG: Computed correct_head_dim = %d", correct_head_dim)
|
|
||||||
|
|
||||||
# Check and fix head_dim in text_config
|
# Check and fix head_dim in text_config
|
||||||
if hasattr(self.text_config, "head_dim"):
|
if hasattr(self.text_config, "head_dim"):
|
||||||
@@ -297,11 +286,6 @@ class Base(nn.Module):
|
|||||||
# Some models also need _attn_implementation in config
|
# Some models also need _attn_implementation in config
|
||||||
self.config._attn_implementation = "vllm"
|
self.config._attn_implementation = "vllm"
|
||||||
|
|
||||||
logger.info("DEBUG: config.head_dim=%s (after fix)",
|
|
||||||
getattr(self.config, 'head_dim', 'NOT SET'))
|
|
||||||
logger.info("DEBUG: text_config.head_dim=%s (after fix)",
|
|
||||||
getattr(self.text_config, 'head_dim', 'NOT SET'))
|
|
||||||
|
|
||||||
with init_on_device_without_buffers("meta"):
|
with init_on_device_without_buffers("meta"):
|
||||||
self.model: "PreTrainedModel" = AutoModel.from_config(
|
self.model: "PreTrainedModel" = AutoModel.from_config(
|
||||||
self.config,
|
self.config,
|
||||||
@@ -461,7 +445,8 @@ class Base(nn.Module):
|
|||||||
)
|
)
|
||||||
replaced_count += 1
|
replaced_count += 1
|
||||||
|
|
||||||
elif child.__class__.__name__.endswith("RMSNorm"):
|
elif child.__class__.__name__.endswith("RMSNorm") and \
|
||||||
|
not isinstance(child, RMSNorm):
|
||||||
new_module = replace_rms_norm_class(child, self.hidden_size)
|
new_module = replace_rms_norm_class(child, self.hidden_size)
|
||||||
replaced_count += 1
|
replaced_count += 1
|
||||||
|
|
||||||
@@ -475,64 +460,8 @@ class Base(nn.Module):
|
|||||||
logger.info("Replaced %d modules", replaced_count)
|
logger.info("Replaced %d modules", replaced_count)
|
||||||
|
|
||||||
def _add_attention_debug_hook(self) -> None:
|
def _add_attention_debug_hook(self) -> None:
|
||||||
"""Add debug hooks to capture actual tensor shapes during forward."""
|
"""No-op. Debug hooks removed after root cause identified."""
|
||||||
# Monkey-patch apply_rotary_pos_emb in the transformers module
|
pass
|
||||||
try:
|
|
||||||
import transformers.models.qwen2.modeling_qwen2 as qwen2_module
|
|
||||||
original_apply_rotary = qwen2_module.apply_rotary_pos_emb
|
|
||||||
|
|
||||||
def _debug_apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
|
|
||||||
logger.info("DEBUG ROTARY: q.shape=%s, k.shape=%s, cos.shape=%s, sin.shape=%s",
|
|
||||||
q.shape, k.shape, cos.shape, sin.shape)
|
|
||||||
# After unsqueeze
|
|
||||||
cos_unsqueezed = cos.unsqueeze(unsqueeze_dim)
|
|
||||||
sin_unsqueezed = sin.unsqueeze(unsqueeze_dim)
|
|
||||||
logger.info("DEBUG ROTARY: after unsqueeze(%d): cos.shape=%s, sin.shape=%s",
|
|
||||||
unsqueeze_dim, cos_unsqueezed.shape, sin_unsqueezed.shape)
|
|
||||||
logger.info("DEBUG ROTARY: q dim 3 = %d, cos dim 3 = %d",
|
|
||||||
q.shape[3] if q.dim() >= 4 else -1,
|
|
||||||
cos_unsqueezed.shape[3] if cos_unsqueezed.dim() >= 4 else -1)
|
|
||||||
return original_apply_rotary(q, k, cos, sin, unsqueeze_dim)
|
|
||||||
|
|
||||||
qwen2_module.apply_rotary_pos_emb = _debug_apply_rotary
|
|
||||||
logger.info("DEBUG: Patched apply_rotary_pos_emb for debugging")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("DEBUG: Failed to patch apply_rotary_pos_emb: %s", e)
|
|
||||||
|
|
||||||
# Also add a forward pre-hook with kwargs support
|
|
||||||
for name, module in self.model.named_modules():
|
|
||||||
if "Attention" in module.__class__.__name__:
|
|
||||||
def _debug_hook(mod, args, kwargs):
|
|
||||||
hidden = kwargs.get('hidden_states', args[0] if args else None)
|
|
||||||
if hidden is not None:
|
|
||||||
logger.info("DEBUG HOOK: Attention input hidden_states.shape=%s", hidden.shape)
|
|
||||||
logger.info("DEBUG HOOK: mod.head_dim=%s (at forward time)", getattr(mod, 'head_dim', 'NOT SET'))
|
|
||||||
# Check mod.config.head_dim
|
|
||||||
mod_config = getattr(mod, 'config', None)
|
|
||||||
if mod_config:
|
|
||||||
logger.info("DEBUG HOOK: mod.config.head_dim=%s", getattr(mod_config, 'head_dim', 'NOT SET'))
|
|
||||||
logger.info("DEBUG HOOK: mod.config id=%d, same as self.config=%s",
|
|
||||||
id(mod_config), id(mod_config) == id(mod_config))
|
|
||||||
# Try q_proj
|
|
||||||
q_proj = getattr(mod, 'q_proj', None)
|
|
||||||
if q_proj is not None:
|
|
||||||
try:
|
|
||||||
q_out = q_proj(hidden)
|
|
||||||
logger.info("DEBUG HOOK: q_proj output shape=%s", q_out.shape)
|
|
||||||
head_dim = getattr(mod, 'head_dim', 128)
|
|
||||||
input_shape = hidden.shape[:-1]
|
|
||||||
hidden_shape = (*input_shape, -1, head_dim)
|
|
||||||
logger.info("DEBUG HOOK: view target shape=%s", hidden_shape)
|
|
||||||
q_viewed = q_out.view(hidden_shape)
|
|
||||||
logger.info("DEBUG HOOK: q_proj viewed shape=%s", q_viewed.shape)
|
|
||||||
q_transposed = q_viewed.transpose(1, 2)
|
|
||||||
logger.info("DEBUG HOOK: q_proj transposed shape=%s", q_transposed.shape)
|
|
||||||
except Exception as e:
|
|
||||||
logger.info("DEBUG HOOK: Error: %s", e)
|
|
||||||
|
|
||||||
module.register_forward_pre_hook(_debug_hook, with_kwargs=True)
|
|
||||||
logger.info("DEBUG: Added debug hook (with_kwargs) to %s", name)
|
|
||||||
break
|
|
||||||
|
|
||||||
def _fix_attention_head_dim(self) -> None:
|
def _fix_attention_head_dim(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -546,50 +475,36 @@ class Base(nn.Module):
|
|||||||
correct_head_dim = self.hidden_size // getattr(
|
correct_head_dim = self.hidden_size // getattr(
|
||||||
self.text_config, "num_attention_heads", 32
|
self.text_config, "num_attention_heads", 32
|
||||||
)
|
)
|
||||||
logger.info("DEBUG: _fix_attention_head_dim called, correct_head_dim=%d", correct_head_dim)
|
|
||||||
|
|
||||||
fixed_count = 0
|
fixed_count = 0
|
||||||
attention_modules_found = []
|
|
||||||
rotary_modules_fixed = []
|
|
||||||
|
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
module_name = module.__class__.__name__
|
module_name = module.__class__.__name__
|
||||||
|
|
||||||
# Fix head_dim in Attention modules
|
# Fix head_dim in Attention modules
|
||||||
if "Attention" in module_name:
|
if "Attention" in module_name:
|
||||||
current_head_dim = getattr(module, 'head_dim', 'NOT SET')
|
|
||||||
num_heads = getattr(module, 'num_heads', 'NOT SET')
|
|
||||||
num_kv_heads = getattr(module, 'num_key_value_heads', 'NOT SET')
|
|
||||||
attention_modules_found.append(
|
|
||||||
f"{name}: head_dim={current_head_dim}, num_heads={num_heads}, num_kv_heads={num_kv_heads}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fix head_dim if it exists and is incorrect
|
|
||||||
if hasattr(module, "head_dim"):
|
if hasattr(module, "head_dim"):
|
||||||
if module.head_dim != correct_head_dim:
|
if module.head_dim != correct_head_dim:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"DEBUG: Fixing head_dim in %s: %d -> %d",
|
"Fixing head_dim in %s: %d -> %d",
|
||||||
name, module.head_dim, correct_head_dim
|
name, module.head_dim, correct_head_dim
|
||||||
)
|
)
|
||||||
module.head_dim = correct_head_dim
|
module.head_dim = correct_head_dim
|
||||||
fixed_count += 1
|
fixed_count += 1
|
||||||
|
|
||||||
# Fix rotary embeddings - need to recreate inv_freq buffer
|
# Fix rotary embeddings - recreate inv_freq buffer if needed
|
||||||
if "RotaryEmbedding" in module_name:
|
if "RotaryEmbedding" in module_name:
|
||||||
# Check if rotary embedding has wrong dimension
|
|
||||||
if hasattr(module, "inv_freq"):
|
if hasattr(module, "inv_freq"):
|
||||||
current_dim = module.inv_freq.shape[0] * 2 # inv_freq is half the dim
|
current_dim = module.inv_freq.shape[0] * 2
|
||||||
if current_dim != correct_head_dim:
|
if current_dim != correct_head_dim:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"DEBUG: Recreating rotary embedding %s: dim %d -> %d",
|
"Recreating rotary embedding %s: dim %d -> %d",
|
||||||
name, current_dim, correct_head_dim
|
name, current_dim, correct_head_dim
|
||||||
)
|
)
|
||||||
# Recreate inv_freq with correct dimension
|
|
||||||
base = getattr(module.config, 'rope_theta', 10000.0)
|
base = getattr(module.config, 'rope_theta', 10000.0)
|
||||||
if hasattr(module.config, 'rope_parameters'):
|
if hasattr(module.config, 'rope_parameters'):
|
||||||
base = module.config.rope_parameters.get('rope_theta', base)
|
base = module.config.rope_parameters.get('rope_theta', base)
|
||||||
device = module.inv_freq.device
|
device = module.inv_freq.device
|
||||||
# Create new inv_freq
|
|
||||||
inv_freq = 1.0 / (
|
inv_freq = 1.0 / (
|
||||||
base ** (
|
base ** (
|
||||||
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
|
torch.arange(0, correct_head_dim, 2, dtype=torch.int64)
|
||||||
@@ -599,22 +514,9 @@ class Base(nn.Module):
|
|||||||
module.register_buffer("inv_freq", inv_freq, persistent=False)
|
module.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
if hasattr(module, "original_inv_freq"):
|
if hasattr(module, "original_inv_freq"):
|
||||||
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
module.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
||||||
rotary_modules_fixed.append(name)
|
|
||||||
|
|
||||||
# Print debug info
|
|
||||||
if attention_modules_found:
|
|
||||||
logger.info("DEBUG: Found %d attention modules", len(attention_modules_found))
|
|
||||||
for info in attention_modules_found[:3]:
|
|
||||||
logger.info("DEBUG: Attention module: %s", info)
|
|
||||||
|
|
||||||
if rotary_modules_fixed:
|
|
||||||
logger.info("DEBUG: Fixed %d rotary embedding modules: %s",
|
|
||||||
len(rotary_modules_fixed), rotary_modules_fixed)
|
|
||||||
|
|
||||||
if fixed_count > 0:
|
if fixed_count > 0:
|
||||||
logger.info("Fixed head_dim in %d attention modules", fixed_count)
|
logger.info("Fixed head_dim in %d attention modules", fixed_count)
|
||||||
else:
|
|
||||||
logger.info("DEBUG: No attention modules needed head_dim fix")
|
|
||||||
|
|
||||||
def _replace_input_embeddings(self) -> None:
|
def _replace_input_embeddings(self) -> None:
|
||||||
"""Replace input embeddings with VocabParallelEmbedding."""
|
"""Replace input embeddings with VocabParallelEmbedding."""
|
||||||
@@ -758,80 +660,6 @@ class Base(nn.Module):
|
|||||||
|
|
||||||
# Forward through model
|
# Forward through model
|
||||||
# Note: return_dict=False returns tuple, first element is last hidden state
|
# Note: return_dict=False returns tuple, first element is last hidden state
|
||||||
|
|
||||||
# DEBUG: Print detailed model structure info before forward
|
|
||||||
if not hasattr(self, '_debug_printed'):
|
|
||||||
self._debug_printed = True
|
|
||||||
logger.info("DEBUG: === Detailed model structure debug ===")
|
|
||||||
|
|
||||||
# Print transformers version
|
|
||||||
try:
|
|
||||||
import transformers
|
|
||||||
logger.info("DEBUG: transformers version: %s", transformers.__version__)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Print TP world size
|
|
||||||
logger.info("DEBUG: TP world_size=%d", self.tp_group.world_size)
|
|
||||||
|
|
||||||
# Print first attention module details
|
|
||||||
for name, module in self.model.named_modules():
|
|
||||||
if "Attention" in module.__class__.__name__:
|
|
||||||
logger.info("DEBUG: First attention: %s (class=%s)", name, module.__class__.__name__)
|
|
||||||
# Print all attributes
|
|
||||||
for attr in ['head_dim', 'num_heads', 'num_key_value_heads',
|
|
||||||
'hidden_size', 'num_attention_heads',
|
|
||||||
'num_key_value_groups']:
|
|
||||||
val = getattr(module, attr, 'NOT SET')
|
|
||||||
logger.info("DEBUG: %s = %s", attr, val)
|
|
||||||
|
|
||||||
# Print rotary_emb
|
|
||||||
rotary = getattr(module, 'rotary_emb', None)
|
|
||||||
if rotary:
|
|
||||||
logger.info("DEBUG: rotary_emb: %s", type(rotary).__name__)
|
|
||||||
if hasattr(rotary, 'inv_freq'):
|
|
||||||
logger.info("DEBUG: rotary_emb.inv_freq.shape: %s", rotary.inv_freq.shape)
|
|
||||||
else:
|
|
||||||
logger.info("DEBUG: rotary_emb: None")
|
|
||||||
|
|
||||||
# Print projection shapes
|
|
||||||
for proj_name in ['q_proj', 'k_proj', 'v_proj', 'o_proj']:
|
|
||||||
proj = getattr(module, proj_name, None)
|
|
||||||
if proj:
|
|
||||||
if hasattr(proj, 'weight'):
|
|
||||||
logger.info("DEBUG: %s: type=%s, weight.shape=%s",
|
|
||||||
proj_name, type(proj).__name__,
|
|
||||||
proj.weight.shape if proj.weight is not None else 'None')
|
|
||||||
elif hasattr(proj, 'output_size'):
|
|
||||||
logger.info("DEBUG: %s: type=%s, in=%s, out=%s, out_per_part=%s",
|
|
||||||
proj_name, type(proj).__name__,
|
|
||||||
getattr(proj, 'input_size', 'N/A'),
|
|
||||||
getattr(proj, 'output_size', 'N/A'),
|
|
||||||
getattr(proj, 'output_size_per_partition', 'N/A'))
|
|
||||||
break
|
|
||||||
|
|
||||||
# Print model-level rotary_emb
|
|
||||||
model_rotary = getattr(self.model, 'rotary_emb', None)
|
|
||||||
if model_rotary:
|
|
||||||
logger.info("DEBUG: Model-level rotary_emb: %s", type(model_rotary).__name__)
|
|
||||||
if hasattr(model_rotary, 'inv_freq'):
|
|
||||||
logger.info("DEBUG: Model rotary_emb.inv_freq.shape: %s", model_rotary.inv_freq.shape)
|
|
||||||
else:
|
|
||||||
logger.info("DEBUG: No model-level rotary_emb")
|
|
||||||
# Check nested
|
|
||||||
for name, module in self.model.named_modules():
|
|
||||||
if "RotaryEmbedding" in module.__class__.__name__:
|
|
||||||
inv_freq_shape = module.inv_freq.shape if hasattr(module, 'inv_freq') else 'N/A'
|
|
||||||
logger.info("DEBUG: Found rotary at %s: inv_freq.shape=%s", name, inv_freq_shape)
|
|
||||||
break
|
|
||||||
|
|
||||||
# Print config details
|
|
||||||
for attr in ['head_dim', 'hidden_size', 'num_attention_heads', 'num_key_value_heads',
|
|
||||||
'intermediate_size', 'num_hidden_layers']:
|
|
||||||
logger.info("DEBUG: config.%s = %s", attr, getattr(self.config, attr, 'NOT SET'))
|
|
||||||
|
|
||||||
logger.info("DEBUG: === End debug ===")
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
**model_inputs,
|
**model_inputs,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ module replacement functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -123,23 +123,102 @@ def replace_linear_class(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformersRMSNorm(RMSNorm):
|
||||||
|
"""
|
||||||
|
vLLM RMSNorm subclass that preserves tensor dimensions.
|
||||||
|
|
||||||
|
vLLM's RMSNorm (especially the MLU backend) flattens input to 2D
|
||||||
|
(e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
|
||||||
|
expects the batch dimension to be preserved. This subclass wraps
|
||||||
|
the parent forward methods to save and restore the original tensor shape.
|
||||||
|
|
||||||
|
Since this inherits from RMSNorm directly, weight loading via
|
||||||
|
named_parameters() works correctly (weight path stays the same).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward_native(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
result = super().forward_native(x, residual)
|
||||||
|
return self._restore_shape(result, orig_shape)
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
result = super().forward_cuda(x, residual)
|
||||||
|
return self._restore_shape(result, orig_shape)
|
||||||
|
|
||||||
|
def forward_mlu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
result = super().forward_mlu(x, residual)
|
||||||
|
return self._restore_shape(result, orig_shape)
|
||||||
|
|
||||||
|
def forward_xpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
result = super().forward_xpu(x, residual)
|
||||||
|
return self._restore_shape(result, orig_shape)
|
||||||
|
|
||||||
|
def forward_hpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
result = super().forward_hpu(x, residual)
|
||||||
|
return self._restore_shape(result, orig_shape)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _restore_shape(result, orig_shape: Tuple):
|
||||||
|
"""Restore original tensor shape if it was changed."""
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
restored = []
|
||||||
|
for t in result:
|
||||||
|
if t is not None and t.shape != orig_shape:
|
||||||
|
t = t.view(orig_shape)
|
||||||
|
restored.append(t)
|
||||||
|
return tuple(restored)
|
||||||
|
else:
|
||||||
|
if result.shape != orig_shape:
|
||||||
|
result = result.view(orig_shape)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def replace_rms_norm_class(
|
def replace_rms_norm_class(
|
||||||
rms_norm: nn.Module,
|
rms_norm: nn.Module,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
) -> RMSNorm:
|
) -> nn.Module:
|
||||||
"""
|
"""
|
||||||
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
|
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm,
|
||||||
|
wrapped to preserve tensor dimensions.
|
||||||
|
|
||||||
vLLM's RMSNorm provides:
|
vLLM's RMSNorm provides:
|
||||||
- Fused CUDA kernels for better performance
|
- Fused CUDA kernels for better performance
|
||||||
- Support for fused add + norm operations
|
- Support for fused add + norm operations
|
||||||
|
|
||||||
|
The wrapper ensures that the original tensor shape (including batch
|
||||||
|
dimension) is preserved, which is required by transformers' model
|
||||||
|
forward methods.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
rms_norm: The RMSNorm module to replace.
|
rms_norm: The RMSNorm module to replace.
|
||||||
hidden_size: The hidden size of the model.
|
hidden_size: The hidden size of the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The new vLLM RMSNorm layer.
|
The new vLLM RMSNorm layer wrapped for shape preservation.
|
||||||
"""
|
"""
|
||||||
# Try to get epsilon from various attribute names
|
# Try to get epsilon from various attribute names
|
||||||
eps = getattr(rms_norm, "eps", None)
|
eps = getattr(rms_norm, "eps", None)
|
||||||
@@ -153,7 +232,7 @@ def replace_rms_norm_class(
|
|||||||
if weight is not None:
|
if weight is not None:
|
||||||
hidden_size = weight.size(0)
|
hidden_size = weight.size(0)
|
||||||
|
|
||||||
return RMSNorm(hidden_size=hidden_size, eps=eps)
|
return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
|
||||||
|
|
||||||
|
|
||||||
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
||||||
|
|||||||
Reference in New Issue
Block a user