Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py
2026-02-06 14:17:06 +08:00

248 lines
7.9 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.
This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and
module replacement functions.
"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
# Module-level logger, created through vLLM's init_logger under this module's name.
logger = init_logger(__name__)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    Context manager that places every newly registered parameter on `device`.

    While the context is active, `nn.Module.register_parameter` is replaced
    by a wrapper that registers the parameter as usual and then swaps it for
    a copy living on `device`. Only parameter registration is patched, so
    buffers are left on whatever device they were created on.

    The original `register_parameter` is always restored on exit, including
    when the body raises.

    Args:
        device: Device to initialize all parameters on (e.g., "meta").
            A string is converted to a `torch.device`.

    Example:
        with init_on_device_without_buffers("meta"):
            model = AutoModel.from_config(config)
        # Parameters of `model` now live on the meta device.
    """
    if isinstance(device, str):
        device = torch.device(device)

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        # Let the stock implementation do its validation and bookkeeping.
        old_register_parameter(module, name, param)
        if param is None:
            return

        registered = module._parameters[name]
        param_cls = type(registered)
        # Copy the attribute dict instead of aliasing it: `registered` is the
        # caller's own parameter object, and writing "requires_grad" straight
        # into its __dict__ would leave a stray entry on that object.
        kwargs = dict(registered.__dict__)
        kwargs["requires_grad"] = param.requires_grad
        # Rebuild with the same class so parameter subclasses (and any extra
        # attributes they carry) survive the move to `device`.
        module._parameters[name] = param_cls(registered.to(device), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        # Undo the global patch no matter how the body exits.
        nn.Module.register_parameter = old_register_parameter
# Linear replacement styles.
# Names of the tensor-parallel layouts that replace_linear_class accepts for
# its `style` argument; each name selects a vLLM linear class and the extra
# keyword arguments passed to it.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional["QuantizationConfig"] = None,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Build the vLLM linear layer that stands in for a given `nn.Linear`.

    The new layer takes its input size, output size and bias presence from
    `linear`; no weights are copied here.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear:
            - "colwise": ColumnParallelLinear with default options
            - "colwise_rep": ColumnParallelLinear with gather_output=True
            - "rowwise": RowParallelLinear with default options
            - "rowwise_rep": RowParallelLinear with input_is_parallel=False
            - "replicate": ReplicatedLinear
            Any other string also yields a ReplicatedLinear.
        quant_config: Quantization config forwarded to the new linear.
        prefix: The name of the layer, forwarded to the new linear.

    Returns:
        The new vLLM linear layer.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    # Map the style name onto a layer class plus its style-specific options.
    if style == "colwise":
        layer_cls, extra_kwargs = ColumnParallelLinear, {}
    elif style == "colwise_rep":
        layer_cls, extra_kwargs = ColumnParallelLinear, {"gather_output": True}
    elif style == "rowwise":
        layer_cls, extra_kwargs = RowParallelLinear, {}
    elif style == "rowwise_rep":
        layer_cls, extra_kwargs = RowParallelLinear, {"input_is_parallel": False}
    else:
        # "replicate" and every unrecognised name share the replicated layer.
        layer_cls, extra_kwargs = ReplicatedLinear, {}

    has_bias = linear.bias is not None
    return layer_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=has_bias,
        quant_config=quant_config,
        prefix=prefix,
        # Ask for the output tensor alone rather than an (output, bias) tuple.
        return_bias=False,
        **extra_kwargs,
    )
class TransformersRMSNorm(RMSNorm):
    """
    RMSNorm subclass whose forward variants hand back tensors in the
    caller's original shape.

    Per the original author's note, vLLM's RMSNorm (especially the MLU
    backend) may flatten its input to 2D (e.g. [batch, seq, hidden] ->
    [batch*seq, hidden]), while transformers modeling code expects the
    leading dimensions to survive. Every forward_* override below records
    the shape of `x`, delegates to the parent implementation of the same
    name, and views the result(s) back to that shape when they differ.

    No parameters are added or renamed, so the module's parameter names are
    those of the parent RMSNorm.
    """

    def _forward_keeping_shape(self, parent_forward, x, residual):
        """Run `parent_forward(x, residual)` and restore the shape of `x`."""
        # Record the shape before delegating, then undo any reshaping.
        input_shape = x.shape
        return self._restore_shape(parent_forward(x, residual), input_shape)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Parent forward_native with the input shape restored."""
        return self._forward_keeping_shape(super().forward_native, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Parent forward_cuda with the input shape restored."""
        return self._forward_keeping_shape(super().forward_cuda, x, residual)

    def forward_mlu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Parent forward_mlu with the input shape restored."""
        return self._forward_keeping_shape(super().forward_mlu, x, residual)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Parent forward_xpu with the input shape restored."""
        return self._forward_keeping_shape(super().forward_xpu, x, residual)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        """Parent forward_hpu with the input shape restored."""
        return self._forward_keeping_shape(super().forward_hpu, x, residual)

    @staticmethod
    def _restore_shape(result, orig_shape: Tuple):
        """
        View `result` back to `orig_shape` wherever its shape differs.

        `result` is either a single tensor or a tuple of tensors; tuple
        entries that are None are passed through unchanged, and a tuple is
        returned for tuple input.
        """
        if not isinstance(result, tuple):
            if result.shape == orig_shape:
                return result
            return result.view(orig_shape)

        return tuple(
            item
            if item is None or item.shape == orig_shape
            else item.view(orig_shape)
            for item in result
        )
def replace_rms_norm_class(
    rms_norm: nn.Module,
    hidden_size: int,
) -> nn.Module:
    """
    Create a TransformersRMSNorm configured from an existing RMSNorm module.

    The epsilon is read from `rms_norm.eps`, falling back to
    `rms_norm.variance_epsilon`, and finally to 1e-6 when neither attribute
    is present (or both are None). If the module has a `weight`, the size of
    its first dimension overrides the `hidden_size` argument.

    Only the configuration is carried over; the weight values of `rms_norm`
    are not copied into the new layer.

    Args:
        rms_norm: The RMSNorm module to replace.
        hidden_size: Hidden size to use when `rms_norm` has no weight.

    Returns:
        A new TransformersRMSNorm, which restores the input shape on its
        forward outputs.
    """
    # Probe the known attribute names for epsilon, in priority order.
    for attr_name in ("eps", "variance_epsilon"):
        eps = getattr(rms_norm, attr_name, None)
        if eps is not None:
            break
    else:
        # Neither attribute supplied a value.
        eps = 1e-6

    # Prefer the actual weight length over the caller-supplied size.
    weight = getattr(rms_norm, "weight", None)
    norm_size = hidden_size if weight is None else weight.size(0)

    return TransformersRMSNorm(hidden_size=norm_size, eps=eps)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug log line recording that `name` was swapped to a new module.

    Only the class names of the old and new modules are logged.
    """
    old_cls_name = type(old_module).__name__
    new_cls_name = type(new_module).__name__
    logger.debug("Replaced %s: %s -> %s", name, old_cls_name, new_cls_name)
def maybe_prefix(prefix: str, name: str) -> str:
    """Return `name` qualified by `prefix`.

    A non-empty prefix yields "<prefix>.<name>"; an empty (falsy) prefix
    yields `name` unchanged.
    """
    return f"{prefix}.{name}" if prefix else name