# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import regex as re
import torch

from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion

if not has_helion():
    raise ImportError(
        "silu_mul_fp8 Helion kernel requires helion to be installed. "
        "Install it with: pip install helion"
    )

# Import helion only after the availability check above so that a missing
# dependency surfaces as the actionable ImportError rather than a bare failure.
import helion.language as hl

from vllm.kernels.helion.register import register_kernel

logger = init_logger(__name__)


@register_kernel  # type: ignore[misc]
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """SiLU-and-mul fused with FP8 quantization.

    Splits the last dimension of ``input`` in half, computes
    ``silu(first_half) * second_half``, divides by ``scale``, and casts the
    result to ``torch.float8_e4m3fn``.
    """
    original_shape = input.shape
    # Specialize the kernel on the last-dimension width (a compile-time
    # constant for Helion).
    two_d = hl.specialize(original_shape[-1])
    d = two_d // 2
    output_shape = original_shape[:-1] + (d,)

    # Flatten all leading dimensions into a single token dimension.
    input_2d = input.view(-1, original_shape[-1])
    m = input_2d.shape[0]

    # TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming
    out = torch.empty((m, d), device=input.device, dtype=torch.float8_e4m3fn)

    input_part_a = input_2d[:, :d]
    input_part_b = input_2d[:, d:]

    assert scale.numel() == 1, "Scale must be a scalar Tensor"

    # Helion tiles the (m, d) iteration space; each tile is computed
    # elementwise and written back in FP8.
    for tile_m, tile_n in hl.tile([m, d]):
        a_vals = input_part_a[tile_m, tile_n]
        silu_result = torch.nn.functional.silu(a_vals)
        b_vals = input_part_b[tile_m, tile_n]
        result = silu_result * b_vals
        # Quantize in float32: divide by the scalar scale, then cast to FP8.
        result_f32 = result.to(torch.float32)
        scale_val = hl.load(scale, [0])
        inv_scale = 1.0 / scale_val
        result_scaled = result_f32 * inv_scale
        out[tile_m, tile_n] = result_scaled.to(out.dtype)

    return out.view(output_shape)
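

# For reference, the same numerics in plain PyTorch (an illustrative sketch,
# not used by vLLM at runtime; the function name is hypothetical and exists
# only to document what the kernel computes).
def _silu_mul_fp8_reference(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    d = input.shape[-1] // 2
    gate, up = input[..., :d], input[..., d:]
    # silu(gate) * up in the input dtype, then quantize in float32.
    result = torch.nn.functional.silu(gate) * up
    return (result.to(torch.float32) / scale).to(torch.float8_e4m3fn)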


@silu_mul_fp8.register_input_generator  # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
    """Generate benchmark inputs for every (intermediate_size, num_tokens) pair."""
    intermediate_sizes = [2048, 2880, 4096, 8192, 11008, 14336]

    # Use the same num_tokens values as vLLM's default cudagraph capture sizes.
    # See vllm/config/vllm.py _set_cudagraph_sizes() for the canonical formula.
    num_tokens_list = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, 513, 16))

    inputs: dict[str, tuple[Any, ...]] = {}
    for num_tokens in num_tokens_list:
        for intermediate_size in intermediate_sizes:
            # Input tensor has shape (num_tokens, 2 * intermediate_size)
            # because silu_mul splits it into two halves.
            input_tensor = torch.randn(
                num_tokens,
                2 * intermediate_size,
                device="cuda",
                dtype=torch.bfloat16,
            )
            scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)

            config_key = f"intermediate_{intermediate_size}_numtokens_{num_tokens}"
            inputs[config_key] = (input_tensor, scale)

    return inputs
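
# Taken together, num_tokens_list above has 3 + 31 + 17 = 51 entries, so the
# generator produces 51 * 6 = 306 named input configurations.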


@silu_mul_fp8.register_config_picker  # type: ignore[misc]
def pick_silu_mul_fp8_config(
    args: tuple[Any, ...], config_keys: list[str]
) -> str | None:
    """Pick the best pre-tuned config for the given input shape.

    Selection strategy:
    1. Find the closest intermediate_size among available configs
       (exact match preferred).
    2. Among the num_tokens values tuned for that intermediate_size, pick
       the smallest num_tokens >= the input's num_tokens. If the input is
       larger than all available num_tokens, fall back to the largest.

    Config keys must be "default" or follow the format
    "intermediate_{int}_numtokens_{int}".
    """
    if not config_keys:
        return None

    input_tensor, _scale = args
    intermediate_size = input_tensor.shape[-1] // 2
    num_tokens = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]

    # Parse the tuned keys into {intermediate_size: [num_tokens, ...]}.
    configs: dict[int, list[int]] = {}
    for key in config_keys:
        if key == "default":
            continue
        match = re.fullmatch(r"intermediate_(\d+)_numtokens_(\d+)", key)
        if not match:
            raise ValueError(
                f"Malformed config key '{key}', "
                f"expected format 'intermediate_{{int}}_numtokens_{{int}}'"
            )
        isize_str, ntokens_str = match.groups()
        configs.setdefault(int(isize_str), []).append(int(ntokens_str))

    if not configs:
        return "default" if "default" in config_keys else None

    # Closest tuned intermediate_size (an exact match has distance 0).
    best_isize = min(configs, key=lambda s: abs(s - intermediate_size))
    # Smallest tuned num_tokens that covers the request, else the largest.
    available_ntokens = sorted(configs[best_isize])
    best_ntokens = next(
        (n for n in available_ntokens if n >= num_tokens), available_ntokens[-1]
    )

    return f"intermediate_{best_isize}_numtokens_{best_ntokens}"
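
# Worked example (illustrative): with keys tuned for intermediate sizes
# {4096, 8192} and num_tokens {64, 256, 512}, an input of shape (100, 8192)
# has intermediate_size 4096 and num_tokens 100, so the smallest covering
# tuned size is picked and "intermediate_4096_numtokens_256" is returned.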


def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Baseline that calls vLLM's compiled silu_and_mul_quant CUDA op."""
    output_shape = input.shape[:-1] + (input.shape[-1] // 2,)
    out = torch.empty(output_shape, dtype=torch.float8_e4m3fn, device=input.device)
    torch.ops._C.silu_and_mul_quant(out, input, scale)
    return out
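

# Minimal smoke test (an illustrative sketch, not part of the module's API):
# assumes a CUDA device, helion, and vLLM's compiled _C ops are available.
# The tolerances are hypothetical values chosen loosely for fp8 rounding.
if __name__ == "__main__":
    x = torch.randn(16, 2 * 2048, device="cuda", dtype=torch.bfloat16)
    s = torch.tensor([1.0], device="cuda", dtype=torch.float32)
    helion_out = silu_mul_fp8(x, s)
    baseline_out = silu_mul_fp8_baseline(x, s)
    # Compare after upcasting, since float8 tensors support few elementwise ops.
    torch.testing.assert_close(
        helion_out.to(torch.float32),
        baseline_out.to(torch.float32),
        atol=0.25,
        rtol=0.25,
    )
    print("silu_mul_fp8 matches silu_mul_fp8_baseline")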