Testing dynamic registration
This commit is contained in:
167
vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py
Normal file
167
vllm-v0.6.2/vllm/model_executor/models/transformers/utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright 2024 The vLLM team.
|
||||
"""Transformers modeling backend utilities for v0.6.2.
|
||||
|
||||
This module provides utility functions for the Transformers backend,
|
||||
including context managers for meta device initialization and
|
||||
module replacement functions.
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
    """Initialize model parameters on *device*, leaving buffers untouched.

    While the context is active, every ``nn.Module.register_parameter`` call
    rebuilds the freshly registered parameter on the target device (e.g.
    ``"meta"``), so a model's structure can be created without allocating
    real GPU memory. Buffers are registered normally and are NOT relocated.

    Args:
        device: Device to initialize all parameters on (e.g., "meta").

    Example:
        with init_on_device_without_buffers("meta"):
            model = AutoModel.from_config(config)
        # Now model is on meta device, no GPU memory allocated
    """
    target = torch.device(device) if isinstance(device, str) else device

    original_register = nn.Module.register_parameter

    def _register_on_target(module, name, param):
        # Register normally first, then rebuild the parameter on the target
        # device, preserving its concrete class, extra attributes, and its
        # requires_grad flag.
        original_register(module, name, param)
        if param is None:
            return
        current = module._parameters[name]
        param_cls = type(current)
        kwargs = current.__dict__
        kwargs["requires_grad"] = param.requires_grad
        module._parameters[name] = param_cls(current.to(target), **kwargs)

    try:
        nn.Module.register_parameter = _register_on_target
        yield
    finally:
        # Always restore the original hook, even if model creation raised.
        nn.Module.register_parameter = original_register
|
||||
|
||||
|
||||
# Linear replacement styles
# Tensor-parallel sharding style accepted by replace_linear_class():
# "colwise"/"colwise_rep" map to ColumnParallelLinear, "rowwise"/"rowwise_rep"
# to RowParallelLinear, and "replicate" to ReplicatedLinear (no sharding).
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
|
||||
|
||||
|
||||
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: Optional["QuantizationConfig"] = None,
    prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    This replacement provides:
    - Memory efficiency through proper tensor allocation
    - Support for quantization
    - Tensor parallel support (when using ColumnParallel/RowParallel)

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear:
            - "colwise": Column parallel (split output dim)
            - "colwise_rep": Column parallel with gather output
            - "rowwise": Row parallel (split input dim)
            - "rowwise_rep": Row parallel without parallel input
            - "replicate": Replicated (no parallelism)
        quant_config: Quantization config for the new linear.
        prefix: The name of the layer for weight loading.

    Returns:
        The new vLLM linear layer.

    Raises:
        ValueError: If `style` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    style_map = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }
    if style not in style_map:
        # Previously an unknown style string fell back silently, hiding
        # typos; surface the fallback so misconfigurations are visible.
        logger.warning(
            "Unknown parallel style %r; falling back to ReplicatedLinear.", style
        )
    vllm_linear_cls, vllm_linear_kwargs = style_map.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        **vllm_linear_kwargs,
    )
|
||||
|
||||
|
||||
def replace_rms_norm_class(
    rms_norm: nn.Module,
    hidden_size: int,
) -> RMSNorm:
    """Swap a Transformers RMSNorm module for vLLM's optimized RMSNorm.

    vLLM's RMSNorm provides:
    - Fused CUDA kernels for better performance
    - Support for fused add + norm operations

    Args:
        rms_norm: The RMSNorm module to replace.
        hidden_size: The hidden size of the model.

    Returns:
        The new vLLM RMSNorm layer.
    """
    # Transformers implementations store epsilon under different attribute
    # names; take the first one present, defaulting to 1e-6.
    eps = next(
        (
            value
            for attr in ("eps", "variance_epsilon")
            if (value := getattr(rms_norm, attr, None)) is not None
        ),
        1e-6,
    )

    # When the original module carries a weight, its length overrides the
    # caller-supplied hidden size.
    weight = getattr(rms_norm, "weight", None)
    if weight is not None:
        hidden_size = weight.size(0)

    return RMSNorm(hidden_size=hidden_size, eps=eps)
|
||||
|
||||
|
||||
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    """Emit a debug-level log entry describing a module replacement."""
    old_cls = type(old_module).__name__
    new_cls = type(new_module).__name__
    # Lazy %-style args keep formatting cost off the hot path when debug
    # logging is disabled.
    logger.debug("Replaced %s: %s -> %s", name, old_cls, new_cls)
|
||||
|
||||
|
||||
def maybe_prefix(prefix: str, name: str) -> str:
    """Return ``name`` qualified by ``prefix`` with a dot, or ``name`` alone
    when ``prefix`` is empty."""
    return f"{prefix}.{name}" if prefix else name
|
||||
Reference in New Issue
Block a user