# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.

This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and module
replacement functions.
"""

from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Union

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)

if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig,
    )

logger = init_logger(__name__)


@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However buffers are not
    initialized on specified device.

    While active, ``nn.Module.register_parameter`` is replaced on the
    ``nn.Module`` class itself, so it affects every module that registers a
    parameter while the context is open. Each non-``None`` parameter is
    re-created on ``device`` with the same parameter class, the same
    ``requires_grad`` flag, and its instance attributes forwarded as
    constructor keyword arguments. The original method is restored on exit,
    even if the body raises.

    This is useful for creating model structure without allocating GPU
    memory, which is essential for memory efficiency.

    Args:
        device: Device to initialize all parameters on (e.g., "meta").
            A string is converted with ``torch.device``.

    Example:
        with init_on_device_without_buffers("meta"):
            model = AutoModel.from_config(config)
        # Now model is on meta device, no GPU memory allocated
    """
    if isinstance(device, str):
        device = torch.device(device)

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        # Let the stock implementation do validation and registration first.
        old_register_parameter(module, name, param)
        if param is None:
            return
        registered = module._parameters[name]
        param_cls = type(registered)
        # Copy the instance attributes instead of aliasing ``__dict__``:
        # adding ``requires_grad`` to the alias would mutate the original
        # parameter object as a side effect.
        kwargs = dict(registered.__dict__)
        kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = param_cls(registered.to(device), **kwargs)

    try:
        nn.Module.register_parameter = register_empty_parameter
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter


# Linear replacement styles
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]

# Supported style -> (vLLM linear class, extra constructor kwargs).
# Built once at import time instead of on every replace_linear_class call.
# The kwargs dicts are only ever unpacked with ``**``, never mutated.
_STYLE_TO_LINEAR = {
    "colwise": (ColumnParallelLinear, {}),
    "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
    "rowwise": (RowParallelLinear, {}),
    "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
    "replicate": (ReplicatedLinear, {}),
}


def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional["QuantizationConfig"] = None,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    This replacement provides:
    - Memory efficiency through proper tensor allocation
    - Support for quantization
    - Tensor parallel support (when using ColumnParallel/RowParallel)

    Only the layer's shape (``in_features``/``out_features``) and whether it
    has a bias are carried over; the original weights are not copied.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear:
            - "colwise": Column parallel (split output dim)
            - "colwise_rep": Column parallel with gather output
            - "rowwise": Row parallel (split input dim)
            - "rowwise_rep": Row parallel without parallel input
            - "replicate": Replicated (no parallelism)
            Any other string falls back to "replicate" (logged at debug
            level).
        quant_config: Quantization config for the new linear.
        prefix: The name of the layer for weight loading.

    Returns:
        The new vLLM linear layer.

    Raises:
        ValueError: If ``style`` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    linear_spec = _STYLE_TO_LINEAR.get(style)
    if linear_spec is None:
        # Unrecognized styles are not an error: they degrade to a replicated
        # layer. Log it so the fallback is discoverable instead of silent.
        logger.debug(
            "Unsupported parallel style %r for layer %r; falling back to "
            "'replicate'",
            style,
            prefix,
        )
        linear_spec = _STYLE_TO_LINEAR["replicate"]
    vllm_linear_cls, vllm_linear_kwargs = linear_spec

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        **vllm_linear_kwargs,
    )


def replace_rms_norm_class(
    rms_norm: nn.Module,
    hidden_size: int,
) -> RMSNorm:
    """
    Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.

    vLLM's RMSNorm provides:
    - Fused CUDA kernels for better performance
    - Support for fused add + norm operations

    Only the configuration (size and epsilon) is carried over; the new
    RMSNorm is freshly constructed and the original weights are not copied.

    Args:
        rms_norm: The RMSNorm module to replace. Epsilon is read from its
            ``eps`` attribute, then ``variance_epsilon``, defaulting to 1e-6
            when neither is present (or both are ``None``).
        hidden_size: The hidden size of the model. Ignored when ``rms_norm``
            has a ``weight`` attribute; ``weight.size(0)`` is used instead.

    Returns:
        The new vLLM RMSNorm layer.
    """
    # Try to get epsilon from various attribute names
    eps = getattr(rms_norm, "eps", None)
    if eps is None:
        eps = getattr(rms_norm, "variance_epsilon", None)
    if eps is None:
        eps = 1e-6

    # Prefer the size of the module's own weight over the caller's value.
    weight = getattr(rms_norm, "weight", None)
    if weight is not None:
        hidden_size = weight.size(0)

    return RMSNorm(hidden_size=hidden_size, eps=eps)


def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Log module replacement for debugging (at debug level)."""
    logger.debug(
        "Replaced %s: %s -> %s",
        name,
        type(old_module).__name__,
        type(new_module).__name__,
    )


def maybe_prefix(prefix: str, name: str) -> str:
    """Combine prefix and name with a dot separator.

    An empty prefix returns ``name`` unchanged.

    >>> maybe_prefix("model", "layers")
    'model.layers'
    >>> maybe_prefix("", "layers")
    'layers'
    """
    if prefix:
        return f"{prefix}.{name}"
    return name