testing dynamic register
This commit is contained in:
@@ -8,7 +8,7 @@ module replacement functions.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -123,23 +123,102 @@ def replace_linear_class(
|
||||
)
|
||||
|
||||
|
||||
class TransformersRMSNorm(RMSNorm):
    """vLLM RMSNorm subclass that preserves tensor dimensions.

    vLLM's RMSNorm (especially the MLU backend) flattens input to 2D
    (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
    expects the batch dimension to be preserved. Each forward_* override
    records the input shape before delegating to the parent implementation,
    then restores that shape on the result.

    Since this inherits from RMSNorm directly, weight loading via
    named_parameters() works correctly (the weight path stays the same).
    """

    def _forward_restored(self, parent_forward, x, residual):
        # Shared wrapper: capture the shape BEFORE calling the parent, since
        # backend kernels may hand back a flattened [batch*seq, hidden] tensor.
        orig_shape = x.shape
        return self._restore_shape(parent_forward(x, residual), orig_shape)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_restored(super().forward_native, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_restored(super().forward_cuda, x, residual)

    def forward_mlu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_restored(super().forward_mlu, x, residual)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_restored(super().forward_xpu, x, residual)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_restored(super().forward_hpu, x, residual)

    @staticmethod
    def _restore_shape(result, orig_shape: torch.Size):
        """Restore the original shape on ``result`` (a tensor, or a tuple of
        optional tensors when the fused add+norm path returns a residual).

        Uses ``reshape()`` instead of ``view()`` so non-contiguous backend
        outputs are handled too — ``reshape`` is a zero-copy view whenever
        ``view`` would succeed, and falls back to a copy otherwise.
        """
        if isinstance(result, tuple):
            return tuple(
                t if t is None or t.shape == orig_shape else t.reshape(orig_shape)
                for t in result
            )
        if result.shape != orig_shape:
            result = result.reshape(orig_shape)
        return result
|
||||
|
||||
|
||||
def replace_rms_norm_class(
|
||||
rms_norm: nn.Module,
|
||||
hidden_size: int,
|
||||
) -> RMSNorm:
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
|
||||
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm,
|
||||
wrapped to preserve tensor dimensions.
|
||||
|
||||
vLLM's RMSNorm provides:
|
||||
- Fused CUDA kernels for better performance
|
||||
- Support for fused add + norm operations
|
||||
|
||||
The wrapper ensures that the original tensor shape (including batch
|
||||
dimension) is preserved, which is required by transformers' model
|
||||
forward methods.
|
||||
|
||||
Args:
|
||||
rms_norm: The RMSNorm module to replace.
|
||||
hidden_size: The hidden size of the model.
|
||||
|
||||
Returns:
|
||||
The new vLLM RMSNorm layer.
|
||||
The new vLLM RMSNorm layer wrapped for shape preservation.
|
||||
"""
|
||||
# Try to get epsilon from various attribute names
|
||||
eps = getattr(rms_norm, "eps", None)
|
||||
@@ -153,7 +232,7 @@ def replace_rms_norm_class(
|
||||
if weight is not None:
|
||||
hidden_size = weight.size(0)
|
||||
|
||||
return RMSNorm(hidden_size=hidden_size, eps=eps)
|
||||
return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
|
||||
|
||||
|
||||
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user