Sync from v0.13

vllm/model_executor/models/transformers/utils.py (new file, 213 lines)

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers modeling backend utilities."""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
    """
    A context manager under which models are initialized with all
    parameters on the specified device. However, buffers are not
    initialized on the specified device.

    Args:
        device (`torch.device`):
            Device to initialize all parameters on.
    """

    old_register_parameter = nn.Module.register_parameter

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            # Re-wrap the freshly registered parameter on the target device,
            # preserving its subclass and attributes (e.g. requires_grad).
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs
            )

    # Intentionally empty: unlike accelerate's buffer-patching variant, no
    # torch tensor constructors are redirected, so buffers keep their
    # default device.
    tensor_constructors_to_patch = {}

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    try:
        nn.Module.register_parameter = register_empty_parameter
        for torch_function_name in tensor_constructors_to_patch:
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        nn.Module.register_parameter = old_register_parameter
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
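
A minimal usage sketch (illustrative, not from this file): inside the context manager, parameters land on the target device while buffers keep their default placement. The module choice below is arbitrary.

# Hypothetical example: BatchNorm1d has both a parameter ("weight") and a
# buffer ("running_mean"), so it shows the split behavior directly.
with init_on_device_without_buffers(torch.device("meta")):
    layer = nn.BatchNorm1d(8)

print(layer.weight.device)        # meta: parameters follow the target device
print(layer.running_mean.device)  # cpu: buffers are left where they were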

Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]


def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: "QuantizationConfig | None" = None,
    *,
    prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise".
        quant_config: Quantization config for the new linear.
        prefix: Full name of the layer in the state dict.

    Returns:
        The new linear.
    """

    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    # Unrecognized styles fall back to a replicated (non-parallel) linear.
    vllm_linear_cls, vllm_linear_kwargs = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }.get(style, (ReplicatedLinear, {}))

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )
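
A sketch of the typical call pattern (hypothetical layer sizes, prefix, and attribute name; assumes vLLM's tensor-parallel state is already initialized):

# Replace an attention projection with a column-parallel linear that
# gathers the full output on every rank ("colwise_rep").
old = nn.Linear(1024, 4096, bias=False)
new = replace_linear_class(old, style="colwise_rep", prefix="model.layers.0.q_proj")
# Callers then swap it into the parent module, e.g.:
# setattr(parent, "q_proj", new)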

def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
    """Replace a Transformers RMSNorm with vLLM's RMSNorm.

    This function assumes:

    - Weight is stored as `weight`.
    - Epsilon is stored as `eps` or `variance_epsilon`.
    - `with_scale` indicates whether the layer has a weight (Gemma3n only).
    - `var_hidden_size` is only ever used for the Intern vision encoder in
      vLLM, and Transformers doesn't appear to have the same concept.
    """
    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
    kwargs = {"hidden_size": hidden_size, "eps": eps}
    # Update hidden size if a weight is available
    weight_meta = getattr(rms_norm, "weight", None)
    if weight_meta is not None:
        kwargs["hidden_size"] = weight_meta.size(0)
    # Check if the weight is all zeros, which indicates GemmaRMSNorm.
    # We must create a new instance because rms_norm is on the meta device.
    try:
        with torch.device("cpu"):
            weight_test = getattr(rms_norm.__class__(1), "weight", None)
    except Exception:
        logger.warning(
            "Failed to determine if RMSNorm weight is centered on zero or one. "
            "Defaulting to one."
        )
        weight_test = None
    if weight_test is not None and torch.all(weight_test == 0):
        return GemmaRMSNorm(**kwargs)
    # Otherwise assume it's a regular RMSNorm
    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
    if weight_meta is not None:
        kwargs["dtype"] = weight_meta.dtype
    else:
        # No weight; fall back to a weightless RMSNorm
        kwargs["has_weight"] = False
    return RMSNorm(**kwargs)
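
The zero-weight probe works because the two norm flavors use different scale initializations; a quick illustrative check:

# vLLM's RMSNorm computes x * weight and initializes weight to ones;
# GemmaRMSNorm computes x * (1 + weight) and initializes weight to zeros,
# which is what the fresh CPU instance above is probed for.
assert torch.all(RMSNorm(hidden_size=4).weight == 1)
assert torch.all(GemmaRMSNorm(hidden_size=4).weight == 0)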

def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
    logger.debug("%s: %s -> %s", name, old_module, new_module)


def get_feature_request_tip(
    model: str,
    trust_remote_code: bool,
) -> str:
    hf_url = f"a discussion at https://huggingface.co/{model}/discussions/new"
    gh_url = "an issue at https://github.com/huggingface/transformers/issues/new/choose"
    url = hf_url if trust_remote_code else gh_url
    prefix = f"Please open {url} to request support for this feature. "
    # Local checkpoints have no upstream repo to file a request against.
    if Path(model).exists():
        prefix = ""
    doc_url = "https://docs.vllm.ai/en/latest/models/supported_models.html#writing-custom-models"
    tip = f"See {doc_url} for instructions on how to add support yourself."
    return f"{prefix}{tip}"
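
For illustration, with made-up inputs and abbreviated outputs:

# Hub model without trust_remote_code -> points at the Transformers repo:
get_feature_request_tip("org/model", trust_remote_code=False)
# -> "Please open an issue at https://github.com/huggingface/transformers/... See https://docs.vllm.ai/... yourself."
# With trust_remote_code, the tip points at the model's own Hub discussions:
get_feature_request_tip("org/custom-model", trust_remote_code=True)
# -> "Please open a discussion at https://huggingface.co/org/custom-model/discussions/new ..."
# For a local path, the request prefix is dropped and only the docs tip is returned.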

def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
    """
    Callable to be passed to `@support_torch_compile`'s `enable_if` argument.

    Defaults to `True` but is disabled in the following situations:

    - The model uses dynamic rope scaling.
    """
    text_config = vllm_config.model_config.hf_config.get_text_config()
    # Dynamic rope scaling is not compatible with torch.compile
    rope_parameters: dict = getattr(text_config, "rope_parameters", None) or {}
    if rope_parameters:
        # Nest rope_parameters if not nested already to simplify the logic
        if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
            rope_parameters = {"": rope_parameters}
        return all(rp["rope_type"] != "dynamic" for rp in rope_parameters.values())
    return True
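
Illustrative `rope_parameters` shapes (values hypothetical) showing what the nesting step normalizes:

# Flat form: one rope config for the whole model; the code wraps it as
# {"": flat} so both forms can be checked uniformly.
flat = {"rope_type": "dynamic", "factor": 2.0}   # -> compile disabled
# Per-layer-type form: keys drawn from ALLOWED_LAYER_TYPES, left as-is.
nested = {
    "full_attention": {"rope_type": "default"},
    "sliding_attention": {"rope_type": "default"},
}  # -> compile stays enabled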