# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.

This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and module
replacement functions.
"""

from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig,
    )

logger = init_logger(__name__)


@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    A context manager under which models are initialized with all parameters
    on the specified device. Buffers, however, are not moved to that device.

    This is useful for creating the model structure without allocating GPU
    memory (e.g. with ``device="meta"``), which is essential for memory
    efficiency.

    Implementation note: this temporarily replaces
    ``nn.Module.register_parameter`` on the class itself, so the patch is
    visible to every module constructed while the context is active. The
    original method is restored on exit, even if the body raises.

    Args:
        device: Device to initialize all parameters on (e.g., "meta").

    Example:
        with init_on_device_without_buffers("meta"):
            model = AutoModel.from_config(config)
        # Now model is on meta device, no GPU memory allocated
    """
    if isinstance(device, str):
        device = torch.device(device)

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        # Register normally first, then swap in a copy that lives on `device`.
        old_register_parameter(module, name, param)
        if param is not None:
            registered = module._parameters[name]
            param_cls = type(registered)
            # Copy the instance attributes rather than aliasing ``__dict__``:
            # ``registered`` is typically the caller's own Parameter object,
            # and writing "requires_grad" into its ``__dict__`` would mutate it.
            kwargs = dict(registered.__dict__)
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                registered.to(device), **kwargs
            )

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


# Linear replacement styles
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]


def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional["QuantizationConfig"] = None,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    This replacement provides:
    - Memory efficiency through proper tensor allocation
    - Support for quantization
    - Tensor parallel support (when using ColumnParallel/RowParallel)

    Args:
        linear: `nn.Linear` to be replaced. Only its shape
            (`in_features`, `out_features`) and whether it has a bias are
            used; its weights are not copied.
        style: Tensor parallel style of the new linear:
            - "colwise": Column parallel (split output dim)
            - "colwise_rep": Column parallel with gather output
            - "rowwise": Row parallel (split input dim)
            - "rowwise_rep": Row parallel without parallel input
            - "replicate": Replicated (no parallelism)
            Any other string falls back to "replicate".
        quant_config: Quantization config for the new linear.
        prefix: The name of the layer for weight loading.

    Returns:
        The new vLLM linear layer, constructed with ``return_bias=False``
        so that its forward returns only the output tensor.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    vllm_linear_styles = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }
    if style not in vllm_linear_styles:
        # Unknown styles intentionally fall back to a replicated layer; log it
        # so the fallback is not completely silent.
        logger.debug(
            "Unsupported parallel style %r for %s; falling back to "
            "ReplicatedLinear.", style, prefix or "<unnamed linear>")
    vllm_linear_cls, vllm_linear_kwargs = vllm_linear_styles.get(
        style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,  # Return tensor only, not (tensor, bias) tuple
        **vllm_linear_kwargs,
    )


class TransformersRMSNorm(RMSNorm):
    """
    vLLM RMSNorm subclass that preserves tensor dimensions.

    vLLM's RMSNorm (especially the MLU backend) flattens input to 2D
    (e.g., [batch, seq, hidden] -> [batch*seq, hidden]), but transformers
    expects the batch dimension to be preserved. This subclass wraps the
    parent forward methods to save the input shape and restore it on every
    returned tensor (including the residual, when one is returned).

    Since this inherits from RMSNorm directly, weight loading via
    named_parameters() works correctly (weight path stays the same).
    """

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_preserving_shape(
            super().forward_native, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_preserving_shape(
            super().forward_cuda, x, residual)

    def forward_mlu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_preserving_shape(
            super().forward_mlu, x, residual)

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_preserving_shape(
            super().forward_xpu, x, residual)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        return self._forward_preserving_shape(
            super().forward_hpu, x, residual)

    def _forward_preserving_shape(self, parent_forward, x, residual):
        """Call a parent forward and reshape its output(s) back to x.shape."""
        orig_shape = x.shape
        result = parent_forward(x, residual)
        return self._restore_shape(result, orig_shape)

    @staticmethod
    def _restore_shape(result, orig_shape: Tuple):
        """Restore the original tensor shape on a tensor or tuple of tensors.

        `None` entries are passed through unchanged. `reshape` is used instead
        of `view` so that non-contiguous backend outputs do not raise; it
        still returns a view whenever one is possible.
        """

        def _restore(tensor):
            if tensor is not None and tensor.shape != orig_shape:
                return tensor.reshape(orig_shape)
            return tensor

        if isinstance(result, tuple):
            return tuple(_restore(t) for t in result)
        return _restore(result)


def replace_rms_norm_class(
    rms_norm: nn.Module,
    hidden_size: int,
) -> nn.Module:
    """
    Replace a Transformers RMSNorm with vLLM's optimized RMSNorm, wrapped to
    preserve tensor dimensions.

    vLLM's RMSNorm provides:
    - Fused CUDA kernels for better performance
    - Support for fused add + norm operations

    The wrapper ensures that the original tensor shape (including batch
    dimension) is preserved, which is required by transformers' model
    forward methods.

    Args:
        rms_norm: The RMSNorm module to replace. Epsilon is read from its
            `eps` or `variance_epsilon` attribute (first non-None wins,
            default 1e-6).
        hidden_size: The hidden size of the model. Overridden by
            `rms_norm.weight.size(0)` when the module has a weight.

    Returns:
        The new vLLM RMSNorm layer wrapped for shape preservation.
    """
    # Transformers implementations store epsilon under different names.
    eps = 1e-6
    for attr_name in ("eps", "variance_epsilon"):
        value = getattr(rms_norm, attr_name, None)
        if value is not None:
            eps = value
            break

    # Prefer the actual weight size over the model-level hidden size.
    weight = getattr(rms_norm, "weight", None)
    if weight is not None:
        hidden_size = weight.size(0)

    # NOTE(review): norms that scale by (1 + weight) (Gemma-style) or that
    # have no weight at all are presumably not equivalent to the replacement
    # created here — confirm before using this with such models.
    return TransformersRMSNorm(hidden_size=hidden_size, eps=eps)


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log module replacement for debugging."""
    logger.debug("Replaced %s: %s -> %s", name,
                 type(old_module).__name__, type(new_module).__name__)


def maybe_prefix(prefix: str, name: str) -> str:
    """Combine prefix and name with a dot separator.

    Returns `name` unchanged when `prefix` is empty.
    """
    return f"{prefix}.{name}" if prefix else name