Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
135
vllm/kernels/helion/ops/silu_mul_fp8.py
Normal file
135
vllm/kernels/helion/ops/silu_mul_fp8.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
# Fail fast with an actionable message when the optional Helion dependency is
# missing; the kernel registration below cannot work without it.
if not has_helion():
    raise ImportError(
        "silu_mul_fp8 Helion kernel requires helion to be installed. "
        "Install it with: pip install helion"
    )

# Imported only after the guard above, so a missing helion install surfaces as
# the friendly ImportError rather than a raw ModuleNotFoundError.
import helion.language as hl

from vllm.kernels.helion.register import register_kernel

logger = init_logger(__name__)
|
||||
|
||||
|
||||
@register_kernel  # type: ignore[misc]
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Fused SiLU-and-multiply with FP8 quantization (Helion kernel).

    Splits the last dimension of ``input`` into two halves ``a`` and ``b``
    and computes ``(silu(a) * b) / scale`` cast to ``float8_e4m3fn``.

    Args:
        input: Tensor whose last dimension is ``2 * d``; the first ``d``
            columns are the gate (``a``), the last ``d`` the up-projection
            (``b``). Leading dimensions are arbitrary (flattened internally).
        scale: Single-element tensor; the output is multiplied by
            ``1 / scale`` before the FP8 cast.

    Returns:
        A ``float8_e4m3fn`` tensor shaped like ``input`` with the last
        dimension halved to ``d``.
    """
    original_shape = input.shape
    # hl.specialize marks the last-dim size as a compile-time constant, so
    # Helion specializes (and tunes) per distinct hidden size — NOTE(review):
    # per Helion's documented semantics; confirm against the helion docs.
    two_d = hl.specialize(original_shape[-1])
    d = two_d // 2
    output_shape = original_shape[:-1] + (d,)

    # Flatten all leading dims into one "token" dim so tiling is 2-D.
    input_2d = input.view(-1, original_shape[-1])
    m = input_2d.shape[0]

    # TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming
    out = torch.empty((m, d), device=input.device, dtype=torch.float8_e4m3fn)

    # Views of the two halves of the last dimension: gate and up-projection.
    input_part_a = input_2d[:, :d]
    input_part_b = input_2d[:, d:]

    assert scale.numel() == 1, "Scale must be a scalar Tensor"

    # Helion tiles the (m, d) iteration space; tile sizes come from autotuning.
    for tile_m, tile_n in hl.tile([m, d]):
        a_vals = input_part_a[tile_m, tile_n]
        silu_result = torch.nn.functional.silu(a_vals)
        b_vals = input_part_b[tile_m, tile_n]
        result = silu_result * b_vals
        # Scale in float32 to limit precision loss before the FP8 cast.
        result_f32 = result.to(torch.float32)
        scale_val = hl.load(scale, [0])
        inv_scale = 1.0 / scale_val
        result_scaled = result_f32 * inv_scale
        out[tile_m, tile_n] = result_scaled.to(out.dtype)

    # Restore the original leading dimensions (last dim is now d, not 2*d).
    return out.view(output_shape)
|
||||
|
||||
|
||||
@silu_mul_fp8.register_input_generator  # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
    """Build tuning inputs for every (intermediate_size, num_tokens) pair.

    Returns:
        Mapping from a config key of the form
        ``intermediate_{size}_numtokens_{tokens}`` to the ``(input, scale)``
        argument tuple for :func:`silu_mul_fp8`.
    """
    isizes = [2048, 2880, 4096, 8192, 11008, 14336]

    # Use the same num_tokens values as vLLM's default cudagraph capture sizes.
    # See vllm/config/vllm.py _set_cudagraph_sizes() for the canonical formula.
    token_counts: list[int] = [1, 2, 4]
    token_counts += range(8, 256, 8)
    token_counts += range(256, 513, 16)

    inputs: dict[str, tuple[Any, ...]] = {}
    for ntokens in token_counts:
        for isize in isizes:
            # Shape is (ntokens, 2 * isize) because silu_mul splits the last
            # dimension into two halves.
            x = torch.randn(
                ntokens,
                2 * isize,
                device="cuda",
                dtype=torch.bfloat16,
            )
            s = torch.tensor([1.0], device="cuda", dtype=torch.float32)
            inputs[f"intermediate_{isize}_numtokens_{ntokens}"] = (x, s)

    return inputs
|
||||
|
||||
|
||||
@silu_mul_fp8.register_config_picker  # type: ignore[misc]
def pick_silu_mul_fp8_config(
    args: tuple[Any, ...], config_keys: list[str]
) -> str | None:
    """Pick the best pre-tuned config for the given input shape.

    Selection strategy:
    1. Find the closest intermediate_size among available configs
       (exact match preferred).
    2. Among the num_tokens values tuned for that intermediate_size, pick
       the smallest num_tokens >= the input's num_tokens. If the input is
       larger than all available num_tokens, fall back to the largest.

    Config keys must be "default" or follow the format
    "intermediate_{int}_numtokens_{int}".
    """
    if not config_keys:
        return None

    input_tensor, _scale = args
    wanted_isize = input_tensor.shape[-1] // 2
    wanted_ntokens = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]

    # Group the tuned num_tokens values by their intermediate_size.
    by_isize: dict[int, list[int]] = {}
    for key in config_keys:
        if key == "default":
            continue
        m = re.fullmatch(r"intermediate_(\d+)_numtokens_(\d+)", key)
        if m is None:
            raise ValueError(
                f"Malformed config key '{key}', "
                f"expected format 'intermediate_{{int}}_numtokens_{{int}}'"
            )
        by_isize.setdefault(int(m.group(1)), []).append(int(m.group(2)))

    # Only "default" (or nothing usable) was present.
    if not by_isize:
        return "default" if "default" in config_keys else None

    # Closest tuned intermediate_size wins; exact match has distance zero.
    chosen_isize = min(by_isize, key=lambda s: abs(s - wanted_isize))

    # Smallest tuned num_tokens that covers the request, else the largest.
    candidates = sorted(by_isize[chosen_isize])
    chosen_ntokens = candidates[-1]
    for n in candidates:
        if n >= wanted_ntokens:
            chosen_ntokens = n
            break

    return f"intermediate_{chosen_isize}_numtokens_{chosen_ntokens}"
|
||||
|
||||
|
||||
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Reference implementation backed by vLLM's ``silu_and_mul_quant`` op.

    Produces a ``float8_e4m3fn`` tensor shaped like ``input`` with the last
    dimension halved; used to validate the Helion kernel's output.
    """
    *lead, two_d = input.shape
    out = torch.empty(
        (*lead, two_d // 2), dtype=torch.float8_e4m3fn, device=input.device
    )
    # The custom op writes its result into `out` in place.
    torch.ops._C.silu_and_mul_quant(out, input, scale)
    return out
||||
Reference in New Issue
Block a user