# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.

This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and
module replacement functions.
"""

from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig,
    )

logger = init_logger(__name__)

@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """Context manager that forces newly registered parameters onto ``device``.

    While active, every parameter registered on any ``nn.Module`` is moved to
    ``device`` immediately after registration; buffers are left untouched on
    their default device. Initializing on the "meta" device builds the model
    structure without allocating real memory, which is essential for memory
    efficiency.

    Args:
        device: Device to initialize all parameters on (e.g., "meta").

    Example:
        with init_on_device_without_buffers("meta"):
            model = AutoModel.from_config(config)
        # Now model is on meta device, no GPU memory allocated
    """
    target_device = torch.device(device) if isinstance(device, str) else device

    original_register = nn.Module.register_parameter

    def _register_on_target(module, name, param):
        # Register normally first, then rebuild the parameter on the target
        # device while preserving its class, metadata, and requires_grad.
        original_register(module, name, param)
        if param is not None:
            current = module._parameters[name]
            param_cls = type(current)
            meta = current.__dict__
            meta["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                current.to(target_device), **meta
            )

    try:
        nn.Module.register_parameter = _register_on_target
        yield
    finally:
        # Always restore the original hook, even if the body raised.
        nn.Module.register_parameter = original_register
# Tensor-parallel replacement styles accepted by replace_linear_class().
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional["QuantizationConfig"] = None,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    This replacement provides:
    - Memory efficiency through proper tensor allocation
    - Support for quantization
    - Tensor parallel support (when using ColumnParallel/RowParallel)

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear:
            - "colwise": Column parallel (split output dim)
            - "colwise_rep": Column parallel with gather output
            - "rowwise": Row parallel (split input dim)
            - "rowwise_rep": Row parallel without parallel input
            - "replicate": Replicated (no parallelism)
        quant_config: Quantization config for the new linear.
        prefix: The name of the layer for weight loading.

    Returns:
        The new vLLM linear layer.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    style_map = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }
    if style not in style_map:
        # An unknown style string used to fall back to ReplicatedLinear
        # silently, which can hide typos and quietly disable tensor
        # parallelism. Keep the fallback for backward compatibility, but
        # surface it so the misconfiguration is visible.
        logger.warning(
            "Unknown parallel style %r; falling back to ReplicatedLinear.",
            style,
        )
    vllm_linear_cls, vllm_linear_kwargs = style_map.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,  # Return tensor only, not (tensor, bias) tuple
        **vllm_linear_kwargs,
    )
class TransformersRMSNorm(RMSNorm):
    """
    Shape-preserving subclass of vLLM's RMSNorm.

    Some vLLM RMSNorm backends (notably MLU) flatten the input to 2D
    (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), while transformers
    models expect the batch dimension to survive the call. Each forward_*
    override records the incoming shape and restores it on the result.

    Because this inherits from RMSNorm directly, weight loading via
    named_parameters() works correctly (the weight path stays the same).
    """

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        shape_before = x.shape
        out = super().forward_native(x, residual)
        return self._restore_shape(out, shape_before)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        shape_before = x.shape
        out = super().forward_cuda(x, residual)
        return self._restore_shape(out, shape_before)

    def forward_mlu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        shape_before = x.shape
        out = super().forward_mlu(x, residual)
        return self._restore_shape(out, shape_before)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        shape_before = x.shape
        out = super().forward_xpu(x, residual)
        return self._restore_shape(out, shape_before)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        shape_before = x.shape
        out = super().forward_hpu(x, residual)
        return self._restore_shape(out, shape_before)

    @staticmethod
    def _restore_shape(result, orig_shape: Tuple):
        """Restore the original tensor shape if the backend changed it.

        Handles both a bare tensor result and a tuple result (the fused
        residual path may return (output, residual), where entries can be
        None).
        """
        if isinstance(result, tuple):
            return tuple(
                t if t is None or t.shape == orig_shape else t.view(orig_shape)
                for t in result
            )
        if result.shape == orig_shape:
            return result
        return result.view(orig_shape)
def replace_rms_norm_class(
    rms_norm: nn.Module,
    hidden_size: int,
) -> nn.Module:
    """
    Replace a Transformers RMSNorm with vLLM's optimized RMSNorm, wrapped
    to preserve tensor dimensions.

    vLLM's RMSNorm provides:
    - Fused CUDA kernels for better performance
    - Support for fused add + norm operations

    The wrapper ensures that the original tensor shape (including batch
    dimension) is preserved, which is required by transformers' model
    forward methods.

    Args:
        rms_norm: The RMSNorm module to replace.
        hidden_size: The hidden size of the model.

    Returns:
        The new vLLM RMSNorm layer wrapped for shape preservation.
    """
    # Transformers implementations disagree on the epsilon attribute name;
    # probe the known candidates in order, defaulting to 1e-6.
    eps = None
    for attr_name in ("eps", "variance_epsilon"):
        eps = getattr(rms_norm, attr_name, None)
        if eps is not None:
            break
    if eps is None:
        eps = 1e-6

    # Prefer the actual weight size over the caller-supplied hidden_size
    # when the original module exposes a weight tensor.
    existing_weight = getattr(rms_norm, "weight", None)
    if existing_weight is not None:
        hidden_size = existing_weight.size(0)

    return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug log recording that `old_module` was swapped for `new_module`."""
    old_cls = type(old_module).__name__
    new_cls = type(new_module).__name__
    logger.debug("Replaced %s: %s -> %s", name, old_cls, new_cls)
def maybe_prefix(prefix: str, name: str) -> str:
    """Join `prefix` and `name` with a dot, or return `name` when prefix is empty."""
    return f"{prefix}.{name}" if prefix else name