# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility methods for model layers."""

import ast
import re
from collections.abc import Callable

import ixformer.inference.functions as IXF
import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)

MOE_LAYER_ROUTER_GATE_SUFFIXES = {
    "gate",
    "router",
    "router_gate",
    "shared_expert_gate",
    "expert_gate",
}


def is_layer_moe_router_gate(prefix: str) -> bool:
    """Return True if the layer name ends with a known MoE router/gate suffix."""
    if not prefix:
        return False
    return prefix.rsplit(".", 1)[-1] in MOE_LAYER_ROUTER_GATE_SUFFIXES
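
# Example (illustrative prefixes, not tied to any particular model):
#   is_layer_moe_router_gate("model.layers.0.mlp.gate")           -> True
#   is_layer_moe_router_gate("model.layers.0.self_attn.qkv_proj") -> False
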
def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Count token occurrences per sequence and build a presence mask."""
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros(
        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask
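
# Worked example (hypothetical values): with vocab_size = 4, num_seqs = 1 and
# tokens = [[1, 1, 3, 4]] (4 being the vocab_size padding id), the padded count
# lands in the extra bin and is sliced off, giving
# bin_counts = [[0, 2, 0, 1]] and mask = [[False, True, False, True]].
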
def apply_penalties(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """Apply penalties in place to the logits tensor.

    Args:
        logits: The input logits tensor of shape [num_seqs, vocab_size].
        prompt_tokens_tensor: A tensor containing the prompt tokens. The
            prompts are padded to the maximum prompt length within the batch
            using `vocab_size` as the padding value. The value `vocab_size`
            is used for padding because it does not correspond to any valid
            token ID in the vocabulary.
        output_tokens_tensor: The output tokens tensor.
        presence_penalties: The presence penalties of shape (num_seqs, ).
        frequency_penalties: The frequency penalties of shape (num_seqs, ).
        repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(
        prompt_tokens_tensor, vocab_size, num_seqs
    )
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs
    )

    # Apply repetition penalties as a custom op.
    from vllm._custom_ops import apply_repetition_penalties

    apply_repetition_penalties(logits, prompt_mask, output_mask, repetition_penalties)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
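
# Sketch of the resulting math (not an exact kernel transcript): for sequence s
# and a token t that appeared in the output,
#   logits[s, t] -= frequency_penalties[s] * count(t) + presence_penalties[s]
# while the repetition penalty rescales logits of tokens seen in the prompt or
# output, dividing positive logits and multiplying negative ones by the penalty.
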
def default_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    # fp16/bf16 activations with fp32 weights: use ixformer's mixed-type kernel.
    if (
        bias is None
        and x.dtype in [torch.half, torch.bfloat16]
        and weight.dtype == torch.float32
    ):
        return IXF.mixed_type_linear(input=x, weight=layer.weight)
    if x.dtype == torch.float32:
        return torch.nn.functional.linear(x, weight, bias)
    return IXF.linear(x, weight, bias)


def use_aiter_triton_gemm(n, m, k, dtype):
    if (
        not rocm_aiter_ops.is_triton_gemm_enabled()
        # MI300-series parts report fp8_fnuz=True; skip the Triton path there.
        or current_platform.is_fp8_fnuz()
        or dtype not in [torch.float16, torch.bfloat16]
    ):
        return False

    # Use hipBLASLt for the larger GEMMs.
    if n > 2048 and m > 512:
        return False
    return (
        (m == 5120 and k == 2880)
        or (m == 2880 and k == 4096)
        or (m == 128 and k == 2880)
        or (m == 640 and k == 2880)
        or (m == 2880 and k == 512)
    )
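
# Example (hypothetical shape): n = 256, m = 5120, k = 2880 in bfloat16 matches
# the (5120, 2880) entry above, so the aiter Triton GEMM is used whenever it is
# enabled and the platform is not an fp8-fnuz (MI300-class) part.
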
def rocm_unquantized_gemm_impl(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    from vllm.platforms.rocm import on_gfx9, on_gfx950

    n = x.numel() // x.size(-1)
    m = weight.shape[0]
    k = weight.shape[1]

    cu_count = num_compute_units()
    if use_aiter_triton_gemm(n, m, k, x.dtype):
        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16

        return gemm_a16w16(x, weight, bias)

    # Next power of two >= n.
    N_p2 = 1 << (n - 1).bit_length()
    # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile),
    # and each working on a 512-shard of K, how many CUs would we need?
    rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512)
    # How many of the 4 waves in a group can work on the same 16 Ms at the same
    # time? This reduces the Ms each group works on, i.e. it increases the
    # number of CUs needed.
    GrpsShrB = min(N_p2 // 16, 4)
    # Given the above, how many CUs would we need?
    CuNeeded = rndup_cus * GrpsShrB
    # Candidate for the atomic-reduce-count split-K kernel?
    fits_wvsplitkrc = CuNeeded <= cu_count
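
    # Worked example (hypothetical shape; a 304-CU GPU is assumed):
    # m = 640, k = 2880, n = 16 gives N_p2 = 16,
    # rndup_cus = ceil(640 / 64) * ceil(2880 / 512) = 10 * 6 = 60,
    # GrpsShrB = min(16 // 16, 4) = 1, so CuNeeded = 60 <= 304 and the
    # split-K reduce-counting kernel is a candidate.
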
    use_skinny_reduce_counting = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx950()
        and x.dtype in [torch.float16, torch.bfloat16]
        and (
            10 <= n <= 128
            and k % 8 == 0
            and k > 512
            and m % 16 == 0
            and fits_wvsplitkrc
            and x.is_contiguous()
        )
    )
    if use_skinny_reduce_counting:
        x_view = x.reshape(-1, x.size(-1))
        out = ops.wvSplitKrc(weight, x_view, cu_count, bias)
        return out.reshape(*x.shape[:-1], weight.shape[0])

    use_skinny = (
        envs.VLLM_ROCM_USE_SKINNY_GEMM
        and on_gfx9()
        and x.dtype in [torch.float16, torch.bfloat16]
        and k % 8 == 0
    )

    if use_skinny is not True:
        return torch.nn.functional.linear(x, weight, bias)

    x_view = x.reshape(-1, x.size(-1))
    if m > 8 and 0 < n <= 4:
        cu_count = num_compute_units()
        out = ops.wvSplitK(weight, x_view, cu_count, bias)
        return out.reshape(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
        out = ops.LLMM1(weight, x_view, 4)
        return out.reshape(*x.shape[:-1], weight.shape[0])
    return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_fake(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


def rocm_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    return torch.ops.vllm.rocm_unquantized_gemm(x, weight, bias)


direct_register_custom_op(
    op_name="rocm_unquantized_gemm",
    op_func=rocm_unquantized_gemm_impl,
    fake_impl=rocm_unquantized_gemm_fake,
)
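
# Registering the op together with its fake (meta) implementation exposes it as
# torch.ops.vllm.rocm_unquantized_gemm, so graph tracing (e.g. torch.compile)
# can derive output shapes from the fake impl without running the ROCm kernels.
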
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
    return (
        torch._C._cpu._is_amx_tile_supported()
        and (dtype in (torch.bfloat16, torch.int8))
        and k % 32 == 0
        and n % 16 == 0
    )
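
# Example (hypothetical layer shape): n = 4096, k = 11008 in bfloat16 satisfies
# the divisibility checks (11008 % 32 == 0, 4096 % 16 == 0), so the SGL kernel
# is usable on CPUs with AMX tile support.
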
def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    # Skip layers whose weights were never materialized (meta tensors).
    if layer.weight.is_meta:
        layer.cpu_linear = torch.nn.functional.linear
        return

    N, K = layer.weight.size()
    dtype = layer.weight.dtype

    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        layer.cpu_linear = lambda x, weight, bias: torch.ops._C.weight_packed_linear(
            x, packed_weight, bias_f32 if bias is not None else None, True
        )
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
        return
    elif (
        ops._supports_onednn
        and current_platform.get_cpu_architecture() != CpuArchEnum.POWERPC
    ):
        try:
            origin_weight = layer.weight
            handler = ops.create_onednn_mm(origin_weight.t(), 32)
            layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(handler, x, bias)
            if remove_weight:
                layer.weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
            return
        except RuntimeError as e:
            logger.warning_once(
                "Failed to create oneDNN linear, falling back to torch linear."
                f" Exception: {e}"
            )

    # Fallback: plain torch linear.
    layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
        x, weight, bias
    )


def cpu_unquantized_gemm(
    layer: torch.nn.Module,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
):
    return layer.cpu_linear(x, weight, bias)


def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
    if current_platform.is_rocm():
        return rocm_unquantized_gemm
    elif current_platform.is_cpu():
        return cpu_unquantized_gemm
    else:
        return default_unquantized_gemm


def weight_quant_l1(input: torch.Tensor):
    """Per-row symmetric int8 quantization using the row-wise absolute maximum."""
    qmax = 127.0
    input = input.to(device="cuda")
    abs_max = torch.abs(input).max(dim=1, keepdim=True)[0]
    scale = abs_max / qmax
    assert scale.shape == (input.shape[0], 1)
    quantized = torch.round(input / scale)
    quantized = torch.clamp(quantized, -qmax, qmax)
    return quantized.to(torch.int8), scale.to(torch.float32)
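
# Worked example (made-up row): [0.5, -1.0, 2.0] has an absolute maximum of 2.0,
# so scale = 2.0 / 127 ≈ 0.01575 and the quantized row is roughly [32, -64, 127];
# dequantizing with the returned scale approximately recovers the original values.
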
def weight_quant_l2(input: torch.Tensor, format: str = "TN"):
    """Per-row absmax int8 quantization followed by ixformer int4 repacking."""
    qmax = 127.0
    input = input.to(device="cuda")
    abs_max = torch.abs(input).max(dim=1, keepdim=True)[0]  # [rows, 1]
    scale = abs_max / qmax  # [rows, 1]
    assert scale.shape == (input.shape[0], 1)
    quantized = torch.round(input / scale)
    quantized = torch.clamp(quantized, -qmax, qmax)

    i4_weights, i8scales, i8zeros = IXF.quant_repack_int4(
        quantized.to(torch.int8).unsqueeze_(0), -1, 2, format, False
    )
    return i4_weights.squeeze(0), scale.to(torch.float32)


def parse_opt_exclude_layers(
    opt_exclude_layers_str: str,
    prefix: str,
) -> bool:
    """
    Parses the VLLM_OPT_EXCLUDE_LAYERS environment variable to determine if
    the current layer should be excluded from optimization.

    Args:
        opt_exclude_layers_str: The string value from the
            VLLM_OPT_EXCLUDE_LAYERS environment variable.
        prefix: The prefix of the current layer (e.g.,
            "model.layers.12.qkv_proj").

    Returns:
        A boolean indicating whether the layer should be excluded.
    """
    if not opt_exclude_layers_str:
        return False

    try:
        # Safely evaluate the string to a Python object.
        excluded_layers = ast.literal_eval(opt_exclude_layers_str)

        # If a single integer is provided, convert it to a set.
        if isinstance(excluded_layers, int):
            excluded_layers = {excluded_layers}
        elif not isinstance(excluded_layers, (set, tuple, list)):
            raise TypeError

        excluded_layers: set[int] = set(excluded_layers)

        # Extract the layer number from the prefix string.
        layer_match = re.search(r"\.(\d+)", prefix)
        if layer_match and int(layer_match.group(1)) in excluded_layers:
            return True  # Exclude this layer.
    except (ValueError, SyntaxError, TypeError):
        logger.warning(
            "Failed to parse VLLM_OPT_EXCLUDE_LAYERS: %s. "
            "Expected a string representation of an integer or a "
            "tuple/list/set of integers.",
            opt_exclude_layers_str,
        )

    return False  # Do not exclude this layer.
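
# Illustrative calls (hypothetical values):
#   parse_opt_exclude_layers("(0, 5)", "model.layers.5.qkv_proj")  -> True
#   parse_opt_exclude_layers("3", "model.layers.12.mlp")           -> False
#   parse_opt_exclude_layers("oops", "model.layers.0.fc")          -> False (logs a warning)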