testing dynamic register

This commit is contained in:
Chranos
2026-02-05 17:11:09 +08:00
parent b399840b8d
commit 6e38461af6
3 changed files with 657 additions and 98 deletions

View File

@@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 The vLLM team.
"""Transformers modeling backend utilities for v0.6.2.
This module provides utility functions for the Transformers backend,
including context managers for meta device initialization and
module replacement functions.
"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)
logger = init_logger(__name__)
@contextmanager
def init_on_device_without_buffers(device: Union[str, torch.device]):
"""
A context manager under which models are initialized with all
parameters on the specified device. However buffers are not
initialized on specified device.
This is useful for creating model structure without allocating
GPU memory, which is essential for memory efficiency.
Args:
device: Device to initialize all parameters on (e.g., "meta").
Example:
with init_on_device_without_buffers("meta"):
model = AutoModel.from_config(config)
# Now model is on meta device, no GPU memory allocated
"""
if isinstance(device, str):
device = torch.device(device)
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs
)
try:
nn.Module.register_parameter = register_empty_parameter
yield
finally:
nn.Module.register_parameter = old_register_parameter
# Linear replacement styles
Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]
def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: Optional["QuantizationConfig"] = None,
prefix: str = "",
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
This replacement provides:
- Memory efficiency through proper tensor allocation
- Support for quantization
- Tensor parallel support (when using ColumnParallel/RowParallel)
Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear:
- "colwise": Column parallel (split output dim)
- "colwise_rep": Column parallel with gather output
- "rowwise": Row parallel (split input dim)
- "rowwise_rep": Row parallel without parallel input
- "replicate": Replicated (no parallelism)
quant_config: Quantization config for the new linear.
prefix: The name of the layer for weight loading.
Returns:
The new vLLM linear layer.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
prefix=prefix,
**vllm_linear_kwargs,
)
def replace_rms_norm_class(
rms_norm: nn.Module,
hidden_size: int,
) -> RMSNorm:
"""
Replace a Transformers RMSNorm with vLLM's optimized RMSNorm.
vLLM's RMSNorm provides:
- Fused CUDA kernels for better performance
- Support for fused add + norm operations
Args:
rms_norm: The RMSNorm module to replace.
hidden_size: The hidden size of the model.
Returns:
The new vLLM RMSNorm layer.
"""
# Try to get epsilon from various attribute names
eps = getattr(rms_norm, "eps", None)
if eps is None:
eps = getattr(rms_norm, "variance_epsilon", None)
if eps is None:
eps = 1e-6
# Check if weight exists and get its size
weight = getattr(rms_norm, "weight", None)
if weight is not None:
hidden_size = weight.size(0)
return RMSNorm(hidden_size=hidden_size, eps=eps)
def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
"""Log module replacement for debugging."""
logger.debug("Replaced %s: %s -> %s", name, type(old_module).__name__, type(new_module).__name__)
def maybe_prefix(prefix: str, name: str) -> str:
"""Combine prefix and name with a dot separator."""
if prefix:
return f"{prefix}.{name}"
return name