# Forked from EngineX-Cambricon/enginex-mlu370-vllm
# SPDX-License-Identifier: Apache-2.0
|
|
# Copyright 2024 The vLLM team.
|
|
"""Transformers modeling backend utilities for v0.6.2.
|
|
|
|
This module provides utility functions for the Transformers backend,
|
|
including context managers for meta device initialization and
|
|
module replacement functions.
|
|
"""
|
|
|
|
from contextlib import contextmanager
|
|
from typing import TYPE_CHECKING, Literal, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig,
|
|
)
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """Context manager that places newly registered parameters on *device*.

    While active, every ``nn.Parameter`` registered on any module is moved
    to the given device immediately after registration.  Buffers are left
    untouched (hence "without buffers"), so buffer tensors keep whatever
    device they were created on.

    Typical use is ``device="meta"`` to build a model skeleton without
    allocating real memory:

        with init_on_device_without_buffers("meta"):
            model = AutoModel.from_config(config)

    Args:
        device: Target device for all parameters, as a string or
            ``torch.device`` (e.g. ``"meta"``).
    """
    target = torch.device(device) if isinstance(device, str) else device

    original_register = nn.Module.register_parameter

    def _register_on_target(module, name, param):
        # Delegate to the stock implementation first so all of torch's
        # bookkeeping/validation still happens.
        original_register(module, name, param)
        if param is None:
            return
        # Rebuild the parameter on the target device, preserving its
        # concrete Parameter subclass and any extra instance attributes.
        stored = module._parameters[name]
        param_cls = type(stored)
        extra_attrs = stored.__dict__
        extra_attrs["requires_grad"] = param.requires_grad
        module._parameters[name] = param_cls(stored.to(target), **extra_attrs)

    try:
        nn.Module.register_parameter = _register_on_target
        yield
    finally:
        # Always undo the monkey-patch, even if model construction raised.
        nn.Module.register_parameter = original_register
|
|
|
|
|
|
# Tensor-parallel styles accepted by replace_linear_class; see its docstring
# for the meaning of each value.
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
|
|
|
|
|
|
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional["QuantizationConfig"] = None,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    This replacement provides:
    - Memory efficiency through proper tensor allocation
    - Support for quantization
    - Tensor parallel support (when using ColumnParallel/RowParallel)

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear:
            - "colwise": Column parallel (split output dim)
            - "colwise_rep": Column parallel with gather output
            - "rowwise": Row parallel (split input dim)
            - "rowwise_rep": Row parallel without parallel input
            - "replicate": Replicated (no parallelism)
        quant_config: Quantization config for the new linear.
        prefix: The name of the layer for weight loading.

    Returns:
        The new vLLM linear layer.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    style_map = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }
    if style not in style_map:
        # Unknown styles used to fall back to ReplicatedLinear silently;
        # keep the fallback for backward compatibility, but make it visible
        # so misconfigured tensor-parallel plans are diagnosable.
        logger.warning(
            "Unknown parallel style %r; falling back to ReplicatedLinear.", style
        )
    vllm_linear_cls, vllm_linear_kwargs = style_map.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,  # Return tensor only, not (tensor, bias) tuple
        **vllm_linear_kwargs,
    )
|
|
|
|
|
|
def replace_rms_norm_class(
    rms_norm: nn.Module,
    hidden_size: int,
) -> RMSNorm:
    """
    Swap a Transformers RMSNorm module for vLLM's optimized RMSNorm.

    vLLM's RMSNorm provides:
    - Fused CUDA kernels for better performance
    - Support for fused add + norm operations

    Args:
        rms_norm: The RMSNorm module to replace.
        hidden_size: The hidden size of the model.

    Returns:
        The new vLLM RMSNorm layer.
    """
    # HF models disagree on the epsilon attribute name; probe the known
    # candidates in order and fall back to a conventional default.
    eps = None
    for attr_name in ("eps", "variance_epsilon"):
        eps = getattr(rms_norm, attr_name, None)
        if eps is not None:
            break
    if eps is None:
        eps = 1e-6

    # When the source module carries a weight, trust its width over the
    # caller-supplied hidden size.
    weight = getattr(rms_norm, "weight", None)
    size = hidden_size if weight is None else weight.size(0)

    return RMSNorm(hidden_size=size, eps=eps)
|
|
|
|
|
|
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug-level record of one module being swapped for another."""
    old_cls = type(old_module).__name__
    new_cls = type(new_module).__name__
    # Lazy %-formatting: the message is only built if debug logging is on.
    logger.debug("Replaced %s: %s -> %s", name, old_cls, new_cls)
|
|
|
|
|
|
def maybe_prefix(prefix: str, name: str) -> str:
    """Join *prefix* and *name* with a dot, or return *name* if the prefix is empty."""
    return f"{prefix}.{name}" if prefix else name
|