testing dynamic register

Chranos
2026-02-06 14:17:06 +08:00
parent fba02652c8
commit b702adf015
2 changed files with 93 additions and 186 deletions


@@ -8,7 +8,7 @@ module replacement functions.
 """
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Literal, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -123,23 +123,102 @@ def replace_linear_class(
     )
 
 
+class TransformersRMSNorm(RMSNorm):
+    """
+    vLLM RMSNorm subclass that preserves tensor dimensions.
+
+    vLLM's RMSNorm (especially the MLU backend) flattens the input to 2D
+    (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
+    expects the batch dimension to be preserved. This subclass wraps the
+    parent forward methods to save and restore the original tensor shape.
+
+    Since this inherits from RMSNorm directly, weight loading via
+    named_parameters() works correctly (the weight path stays the same).
+    """
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        orig_shape = x.shape
+        result = super().forward_native(x, residual)
+        return self._restore_shape(result, orig_shape)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        orig_shape = x.shape
+        result = super().forward_cuda(x, residual)
+        return self._restore_shape(result, orig_shape)
+
+    def forward_mlu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        orig_shape = x.shape
+        result = super().forward_mlu(x, residual)
+        return self._restore_shape(result, orig_shape)
+
+    def forward_xpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        orig_shape = x.shape
+        result = super().forward_xpu(x, residual)
+        return self._restore_shape(result, orig_shape)
+
+    def forward_hpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ):
+        orig_shape = x.shape
+        result = super().forward_hpu(x, residual)
+        return self._restore_shape(result, orig_shape)
+
+    @staticmethod
+    def _restore_shape(result, orig_shape: Tuple):
+        """Restore original tensor shape if it was changed."""
+        if isinstance(result, tuple):
+            restored = []
+            for t in result:
+                if t is not None and t.shape != orig_shape:
+                    t = t.view(orig_shape)
+                restored.append(t)
+            return tuple(restored)
+        else:
+            if result.shape != orig_shape:
+                result = result.view(orig_shape)
+            return result
+
+
 def replace_rms_norm_class(
     rms_norm: nn.Module,
     hidden_size: int,
-) -> RMSNorm:
+) -> nn.Module:
     """
-    Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
+    Replace a Transformers RMSNorm with vLLM's optimized RMSNorm,
+    wrapped to preserve tensor dimensions.
 
     vLLM's RMSNorm provides:
     - Fused CUDA kernels for better performance
     - Support for fused add + norm operations
 
+    The wrapper ensures that the original tensor shape (including the batch
+    dimension) is preserved, which is required by transformers' model
+    forward methods.
+
     Args:
         rms_norm: The RMSNorm module to replace.
         hidden_size: The hidden size of the model.
 
     Returns:
-        The new vLLM RMSNorm layer.
+        The new vLLM RMSNorm layer wrapped for shape preservation.
     """
     # Try to get epsilon from various attribute names
     eps = getattr(rms_norm, "eps", None)
@@ -153,7 +232,7 @@ def replace_rms_norm_class(
     if weight is not None:
         hidden_size = weight.size(0)
 
-    return RMSNorm(hidden_size=hidden_size, eps=eps)
+    return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
 
 
 def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
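
For reference, helpers like these are normally driven by a module-replacement pass that walks a Hugging Face model and swaps each RMSNorm in place. The sketch below is illustrative only and not part of this commit: the swap_rms_norms / hf_model names and the traversal loop are assumptions; only replace_rms_norm_class and log_replacement come from the file changed above, and their import path is omitted because it depends on this fork's layout.

import torch.nn as nn

def swap_rms_norms(hf_model: nn.Module, hidden_size: int) -> None:
    # Hypothetical driver: replace every *RMSNorm submodule in place.
    # Assumes replace_rms_norm_class and log_replacement (from the diff
    # above) are importable in the current scope.
    targets = [
        (name, module)
        for name, module in hf_model.named_modules()
        if name and type(module).__name__.endswith("RMSNorm")
    ]
    for name, module in targets:
        new_module = replace_rms_norm_class(module, hidden_size)
        log_replacement(name, module, new_module)
        # Re-attach the replacement on the parent module.
        parent_name, _, child_name = name.rpartition(".")
        parent = hf_model.get_submodule(parent_name) if parent_name else hf_model
        setattr(parent, child_name, new_module)

Because TransformersRMSNorm subclasses vLLM's RMSNorm directly, the existing weight-loading path via named_parameters() still finds the weight under the same name, so no extra remapping is needed after such a swap.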