update
vllm/kernels/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kernel implementations for vLLM."""
vllm/kernels/helion/__init__.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helion integration for vLLM."""

import vllm.kernels.helion.ops  # noqa: F401  (auto-registers all Helion ops)
from vllm.kernels.helion.config_manager import (
    ConfigManager,
    ConfigSet,
)
from vllm.kernels.helion.register import (
    ConfiguredHelionKernel,
    HelionKernelWrapper,
    get_kernel_by_name,
    get_registered_kernels,
    register_kernel,
    vllm_helion_lib,
)
from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name

__all__ = [
    # Config management
    "ConfigManager",
    "ConfigSet",
    # Kernel registration
    "ConfiguredHelionKernel",
    "HelionKernelWrapper",
    "get_kernel_by_name",
    "get_registered_kernels",
    "register_kernel",
    "vllm_helion_lib",
    # Utilities
    "canonicalize_gpu_name",
    "get_canonical_gpu_name",
]
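A minimal end-to-end usage sketch of the public API exported above (hypothetical shapes; assumes helion is installed, a CUDA device is available, and tuned configs exist for the detected GPU; the ConfigManager singleton must be created before the first kernel call, since ConfiguredHelionKernel resolves configs through ConfigManager.get_instance()):

    import torch
    from vllm.kernels.helion import ConfigManager
    from vllm.kernels.helion.ops.silu_mul_fp8 import silu_mul_fp8

    ConfigManager()  # default base_dir resolves to vllm/kernels/helion/configs

    x = torch.randn(32, 8192, device="cuda", dtype=torch.bfloat16)
    scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
    # First call registers and compiles torch.ops.vllm_helion.silu_mul_fp8.
    out = silu_mul_fp8(x, scale)
    assert out.shape == (32, 4096) and out.dtype == torch.float8_e4m3fn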
vllm/kernels/helion/config_manager.py (new file, 281 lines)
@@ -0,0 +1,281 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Configuration management for Helion kernels.

This module provides centralized configuration file management for Helion custom
operations, including naming conventions, directory resolution, and file I/O.

Config File Structure
---------------------
Each kernel has a single JSON config file: {kernel_name}.json

The file uses a simplified 2-layer hierarchical structure:
{
    "h100": {                           # GPU platform
        "default": { ... },             # Fallback configuration
        "batch_32_hidden_4096": { ... },
        "batch_64_hidden_8192": { ... }
    },
    "a100": {
        "default": { ... },
        "batch_16_hidden_2048": { ... }
    }
}

Example file: silu_mul_fp8.json

Config keys should be structured strings that encode the relevant
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64").

Classes
-------
- ConfigSet: In-memory collection of configs for a kernel with lookup/query APIs.
- ConfigManager: File-level operations for config persistence.
"""

import json
from pathlib import Path
from typing import Any

from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion

if not has_helion():
    raise ImportError(
        "ConfigManager requires helion to be installed. "
        "Install it with: pip install helion"
    )

import helion

logger = init_logger(__name__)


class ConfigSet:
    """In-memory collection of Helion configs with lookup/query capabilities."""

    # Type alias for nested config structure:
    # platform -> config_key -> helion.Config
    _ConfigDict = dict[str, dict[str, "helion.Config"]]

    def __init__(self, kernel_name: str):
        self._kernel_name = kernel_name
        self._configs: ConfigSet._ConfigDict = {}

    @property
    def kernel_name(self) -> str:
        return self._kernel_name

    def get_config(self, platform: str, config_key: str) -> helion.Config:
        platform_dict = self._configs.get(platform.lower())  # normalize like set_config()
        if platform_dict is None:
            avail_platforms = self.get_platforms()
            # TODO(@gmagogsfm): add a CLI/env override flag so users can
            # directly specify a platform name instead of relying on
            # auto-detection, and suggest it in this error message.
            raise KeyError(
                f"Config not found for kernel '{self._kernel_name}': "
                f"platform '{platform}' not found. "
                f"Available platforms: {avail_platforms or '(none)'}. "
                f"If your GPU is a variant of a supported platform, "
                f"consider adding a mapping in _GPU_NAME_ALIASES in "
                f"vllm/kernels/helion/utils.py, or run "
                f"scripts/autotune_helion_kernels.py to generate configs "
                f"for your platform."
            )

        config = platform_dict.get(config_key)
        if config is None:
            avail_keys = self.get_config_keys(platform)
            raise KeyError(
                f"Config not found for kernel '{self._kernel_name}': "
                f"config_key '{config_key}' not found for platform '{platform}'. "
                f"Available config_keys: {avail_keys or '(none)'}"
            )

        return config

    def get_platforms(self) -> list[str]:
        return sorted(self._configs.keys())

    def get_config_keys(self, platform: str) -> list[str]:
        platform_dict = self._configs.get(platform.lower())
        if platform_dict is None:
            return []
        return sorted(platform_dict.keys())

    def to_dict(self) -> dict[str, Any]:
        result: dict[str, Any] = {}

        for platform, config_keys_dict in self._configs.items():
            result[platform] = {}

            for config_key, config in config_keys_dict.items():
                result[platform][config_key] = json.loads(config.to_json())

        return result

    @classmethod
    def from_dict(cls, kernel_name: str, data: dict[str, Any]) -> "ConfigSet":
        config_set = cls(kernel_name)
        count = 0

        for platform, platform_data in data.items():
            if platform not in config_set._configs:
                config_set._configs[platform] = {}

            for config_key, config_data in platform_data.items():
                config = helion.Config(**config_data)
                config_set._configs[platform][config_key] = config
                count += 1

        if count > 0:
            logger.debug(
                "Loaded %d configs for kernel '%s'",
                count,
                kernel_name,
            )

        return config_set

    def set_config(
        self, platform: str, config_key: str, config: "helion.Config"
    ) -> None:
        platform = platform.lower()
        if platform not in self._configs:
            self._configs[platform] = {}
        self._configs[platform][config_key] = config
        logger.debug(
            "Set config for kernel '%s': platform='%s', key='%s'",
            self._kernel_name,
            platform,
            config_key,
        )

    def has_config(self, platform: str, config_key: str) -> bool:
        platform = platform.lower()
        platform_dict = self._configs.get(platform)
        if platform_dict is None:
            return False
        return config_key in platform_dict


class ConfigManager:
    """File-level configuration management for Helion kernels (global singleton)."""

    _instance: "ConfigManager | None" = None
    _instance_base_dir: Path | None = None

    def __new__(cls, base_dir: str | Path | None = None) -> "ConfigManager":
        resolved_base_dir = cls._resolve_base_dir(base_dir)

        if cls._instance is not None:
            if cls._instance_base_dir != resolved_base_dir:
                raise ValueError(
                    f"ConfigManager singleton already exists with base_dir "
                    f"'{cls._instance_base_dir}', cannot create with different "
                    f"base_dir '{resolved_base_dir}'"
                )
            return cls._instance

        instance = super().__new__(cls)
        cls._instance = instance
        cls._instance_base_dir = resolved_base_dir
        return instance

    def __init__(self, base_dir: str | Path | None = None):
        if hasattr(self, "_base_dir"):
            return

        self._base_dir = self._resolve_base_dir(base_dir)
        logger.debug("ConfigManager initialized with base_dir: %s", self._base_dir)

    @staticmethod
    def _resolve_base_dir(base_dir: str | Path | None) -> Path:
        if base_dir is not None:
            return Path(base_dir).resolve()
        return (Path(__file__).parent / "configs").resolve()

    @classmethod
    def get_instance(cls) -> "ConfigManager":
        if cls._instance is None:
            raise RuntimeError(
                "ConfigManager instance has not been created. "
                "Call ConfigManager(base_dir=...) first to initialize."
            )
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """For testing purposes only."""
        cls._instance = None
        cls._instance_base_dir = None

    def get_config_file_path(self, kernel_name: str) -> Path:
        return self._base_dir / f"{kernel_name}.json"

    def ensure_base_dir_exists(self) -> Path:
        self._base_dir.mkdir(parents=True, exist_ok=True)
        return self._base_dir

    def ensure_base_dir_writable(self) -> None:
        self.ensure_base_dir_exists()
        test_file = self._base_dir / ".write_test"
        try:
            test_file.write_text("test")
            test_file.unlink()
        except OSError as e:
            raise OSError(
                f"Config directory '{self._base_dir}' is not writable: {e}"
            ) from e

    def load_config_set(self, kernel_name: str) -> ConfigSet:
        config_path = self.get_config_file_path(kernel_name)
        if not config_path.exists():
            return ConfigSet.from_dict(kernel_name, {})

        try:
            with open(config_path) as f:
                data = json.load(f)
            return ConfigSet.from_dict(kernel_name, data)
        except (json.JSONDecodeError, OSError) as e:
            logger.error("Failed to load config file %s: %s", config_path, e)
            return ConfigSet.from_dict(kernel_name, {})

    def get_platform_configs(
        self, kernel_name: str, platform: str
    ) -> dict[str, helion.Config]:
        config_set = self.load_config_set(kernel_name)
        config_keys = config_set.get_config_keys(platform)

        return {
            config_key: config_set.get_config(platform, config_key)
            for config_key in config_keys
        }

    def save_config_set(self, config_set: ConfigSet) -> Path:
        config_path = self.get_config_file_path(config_set.kernel_name)
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, "w") as f:
            json.dump(config_set.to_dict(), f, indent=2)

        logger.info("Saved config to: %s", config_path)
        return config_path

    def save_configs(
        self,
        kernel_name: str,
        platform: str,
        configs: dict[str, "helion.Config"],
    ) -> Path:
        """Save configs for a kernel/platform, merging with existing."""
        config_set = self.load_config_set(kernel_name)
        for config_key, config in configs.items():
            config_set.set_config(platform, config_key, config)
        return self.save_config_set(config_set)

    def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool:
        config_set = self.load_config_set(kernel_name)
        return config_set.has_config(platform, config_key)
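To illustrate the round trip at the file level, here is a sketch (the base_dir, kernel name, and Helion config fields are illustrative, not shipped values; assumes helion is installed):

    import helion
    from vllm.kernels.helion.config_manager import ConfigManager, ConfigSet

    manager = ConfigManager(base_dir="/tmp/helion_configs")  # singleton; first base_dir wins
    config_set = ConfigSet("my_kernel")
    config_set.set_config("NVIDIA_H100", "default", helion.Config(block_sizes=[64, 64]))
    path = manager.save_config_set(config_set)  # writes /tmp/helion_configs/my_kernel.json
    loaded = manager.load_config_set("my_kernel")
    assert loaded.has_config("nvidia_h100", "default")  # platforms are stored lowercased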
vllm/kernels/helion/configs/silu_mul_fp8.json (new file, 27726 lines)
(File diff suppressed because it is too large.)
vllm/kernels/helion/ops/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Auto-import all Helion op modules to trigger kernel registration."""

import importlib
import pkgutil

# Automatically import all submodules so that @register_kernel
# decorators execute and register ops with torch.ops.vllm_helion.
for _module_info in pkgutil.iter_modules(__path__):
    importlib.import_module(f"{__name__}.{_module_info.name}")
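Because of this auto-import, importing the package is enough to populate the kernel registry (a sketch; the op modules themselves require helion to import):

    import vllm.kernels.helion.ops  # runs every @register_kernel in this package
    from vllm.kernels.helion.register import get_registered_kernels

    print(sorted(get_registered_kernels()))  # e.g. ['silu_mul_fp8']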
vllm/kernels/helion/ops/silu_mul_fp8.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

import regex as re
import torch

from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion

if not has_helion():
    raise ImportError(
        "silu_mul_fp8 Helion kernel requires helion to be installed. "
        "Install it with: pip install helion"
    )

import helion.language as hl

from vllm.kernels.helion.register import register_kernel

logger = init_logger(__name__)


@register_kernel  # type: ignore[misc]
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    original_shape = input.shape
    two_d = hl.specialize(original_shape[-1])
    d = two_d // 2
    output_shape = original_shape[:-1] + (d,)

    input_2d = input.view(-1, original_shape[-1])
    m = input_2d.shape[0]

    # TODO(gmagogsfm): Support more float8 subtypes (e4m3fnuz, e5m2).
    out = torch.empty((m, d), device=input.device, dtype=torch.float8_e4m3fn)

    input_part_a = input_2d[:, :d]
    input_part_b = input_2d[:, d:]

    assert scale.numel() == 1, "Scale must be a scalar Tensor"

    for tile_m, tile_n in hl.tile([m, d]):
        a_vals = input_part_a[tile_m, tile_n]
        silu_result = torch.nn.functional.silu(a_vals)
        b_vals = input_part_b[tile_m, tile_n]
        result = silu_result * b_vals
        result_f32 = result.to(torch.float32)
        scale_val = hl.load(scale, [0])
        inv_scale = 1.0 / scale_val
        result_scaled = result_f32 * inv_scale
        out[tile_m, tile_n] = result_scaled.to(out.dtype)

    return out.view(output_shape)


@silu_mul_fp8.register_input_generator  # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
    intermediate_sizes = [2048, 2880, 4096, 8192, 11008, 14336]

    # Use the same num_tokens values as vLLM's default cudagraph capture sizes.
    # See vllm/config/vllm.py _set_cudagraph_sizes() for the canonical formula.
    num_tokens_list = [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, 513, 16))

    inputs = {}
    for num_tokens in num_tokens_list:
        for intermediate_size in intermediate_sizes:
            # Input tensor has shape (num_tokens, 2 * intermediate_size)
            # because silu_mul splits it into two halves
            input_tensor = torch.randn(
                num_tokens,
                2 * intermediate_size,
                device="cuda",
                dtype=torch.bfloat16,
            )
            scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)

            config_key = f"intermediate_{intermediate_size}_numtokens_{num_tokens}"
            inputs[config_key] = (input_tensor, scale)

    return inputs


@silu_mul_fp8.register_config_picker  # type: ignore[misc]
def pick_silu_mul_fp8_config(
    args: tuple[Any, ...], config_keys: list[str]
) -> str | None:
    """Pick the best pre-tuned config for the given input shape.

    Selection strategy:
    1. Find the closest intermediate_size among available configs
       (exact match preferred).
    2. Among the num_tokens values tuned for that intermediate_size, pick
       the smallest num_tokens >= the input's num_tokens. If the input is
       larger than all available num_tokens, fall back to the largest.

    Config keys must be "default" or follow the format
    "intermediate_{int}_numtokens_{int}".
    """
    if not config_keys:
        return None

    input_tensor, _scale = args
    intermediate_size = input_tensor.shape[-1] // 2
    num_tokens = input_tensor.view(-1, input_tensor.shape[-1]).shape[0]
    configs: dict[int, list[int]] = {}
    for key in config_keys:
        if key == "default":
            continue
        match = re.fullmatch(r"intermediate_(\d+)_numtokens_(\d+)", key)
        if not match:
            raise ValueError(
                f"Malformed config key '{key}', "
                f"expected format 'intermediate_{{int}}_numtokens_{{int}}'"
            )
        isize_str, ntokens_str = match.groups()
        configs.setdefault(int(isize_str), []).append(int(ntokens_str))

    if not configs:
        return "default" if "default" in config_keys else None

    best_isize = min(configs, key=lambda s: abs(s - intermediate_size))
    available_ntokens = sorted(configs[best_isize])
    best_ntokens = next(
        (n for n in available_ntokens if n >= num_tokens), available_ntokens[-1]
    )

    return f"intermediate_{best_isize}_numtokens_{best_ntokens}"


def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    output_shape = input.shape[:-1] + (input.shape[-1] // 2,)
    out = torch.empty(output_shape, dtype=torch.float8_e4m3fn, device=input.device)
    torch.ops._C.silu_and_mul_quant(out, input, scale)
    return out
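The picker can be exercised directly to check the selection strategy; only the input's shape is inspected, so CPU tensors suffice (shapes below are illustrative):

    import torch
    from vllm.kernels.helion.ops.silu_mul_fp8 import pick_silu_mul_fp8_config

    keys = ["default", "intermediate_4096_numtokens_32", "intermediate_4096_numtokens_64"]
    x = torch.empty(48, 8192)  # num_tokens=48, intermediate_size=8192//2=4096
    # Exact intermediate_size match; smallest tuned num_tokens >= 48 is 64.
    assert pick_silu_mul_fp8_config((x, None), keys) == "intermediate_4096_numtokens_64"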
vllm/kernels/helion/register.py (new file, 451 lines)
@@ -0,0 +1,451 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
vLLM Helion kernel registration with pre-tuned config selection.

This module leverages Helion's internal config selection infrastructure to use
pre-tuned configs instead of runtime autotuning.

How Helion Normally Works
-------------------------
For each kernel invocation, Helion:
1. Computes a cache key from input arguments
2. Looks up the key in its internal compilation cache
3. On cache miss, runs autotuning to find the best config
4. Compiles and caches the kernel with that config

How We Override It
------------------
We override two Helion hooks to use pre-tuned configs:

1. **key**: We provide a key function (derived from config_picker) that
   computes cache keys matching our pre-tuned config keys. This ensures Helion's
   internal cache uses keys that correspond to configs we've prepared.

2. **autotuner_fn**: We provide PresetConfigSearch which, instead of autotuning,
   simply returns the pre-tuned config for the computed key. On cache miss,
   Helion calls our autotuner, which returns the author-prepared config.

Both hooks use the same config_picker logic to ensure the cache key computed
by key matches the config returned by the autotuner.

Key Classes
-----------
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops
- ConfiguredHelionKernel: Platform-specific kernel registered as a PyTorch custom op
- PresetConfigSearch: Custom autotuner that returns pre-tuned configs
"""

from collections.abc import Callable
from typing import Any, cast, overload

import torch
from torch.library import Library

from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion
from vllm.utils.torch_utils import direct_register_custom_op

if not has_helion():
    raise ImportError(
        "register module requires helion to be installed. "
        "Install it with: pip install helion"
    )

import helion
from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.config import Config
from helion.runtime.settings import default_autotuner_fn

logger = init_logger(__name__)

vllm_helion_lib = Library("vllm_helion", "FRAGMENT")  # noqa


def validate_helion_settings(
    helion_settings: "helion.Settings | None", op_name: str
) -> None:
    if helion_settings is None:
        return

    settings_dict = helion_settings.to_dict()

    if (
        "autotuner_fn" in settings_dict
        and settings_dict["autotuner_fn"] is not None
        and settings_dict["autotuner_fn"] is not default_autotuner_fn
    ):
        raise ValueError(
            f"HelionKernelWrapper for '{op_name}' uses a custom autotuner via "
            f"config picker. Remove 'autotuner_fn' from helion_settings and use "
            f"@{op_name}.register_config_picker instead."
        )

    # Warn if static_shapes is explicitly set to True since most vLLM ops need
    # dynamic shapes for variable batch sizes and sequence lengths
    if settings_dict.get("static_shapes") is True:
        logger.warning(
            "Kernel '%s' has static_shapes=True in helion_settings. "
            "Most vLLM ops require dynamic shapes for variable batch sizes "
            "and sequence lengths. Consider removing this setting.",
            op_name,
        )


def create_helion_decorated_kernel(
    raw_kernel_func: Callable,
    helion_settings: "helion.Settings | None" = None,
    extra_kwargs: dict[str, Any] | None = None,
) -> Any:
    kernel_kwargs: dict[str, Any] = {}
    if helion_settings:
        kernel_kwargs.update(helion_settings.to_dict())

    # Set static_shapes=False by default if user didn't explicitly set it.
    # This is needed for dynamic batch sizes and sequence lengths in vLLM.
    if kernel_kwargs.get("static_shapes") is not True:
        kernel_kwargs["static_shapes"] = False

    if extra_kwargs:
        kernel_kwargs.update(extra_kwargs)

    return helion.kernel(**kernel_kwargs)(raw_kernel_func)


class PresetConfigSearch(BaseAutotuner):
    """Custom autotuner that uses a preset config selector instead of autotuning."""

    def __init__(
        self,
        args: tuple[Any, ...],
        config_selector: Callable[[tuple[Any, ...]], Config],
    ):
        self.args = args
        self.config_selector = config_selector

    def autotune(self, *, skip_cache: bool = False) -> Config:
        return self.config_selector(self.args)


class ConfiguredHelionKernel:
    """A configured Helion kernel bound to a specific platform."""

    def __init__(
        self,
        op_name: str,
        config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None,
        raw_kernel_func: Callable,
        helion_settings: "helion.Settings | None" = None,
    ):
        self.op_name = op_name
        self.config_picker = config_picker
        self.raw_kernel_func = raw_kernel_func
        self.helion_settings = helion_settings
        self._decorated_kernel = self._create_decorated_kernel()

    def __call__(self, *args, **kwargs):
        return self._decorated_kernel(*args, **kwargs)

    def _create_key_computer(self):
        """
        Create a key computer function derived from the config picker.

        The returned function receives kernel arguments unpacked (*args) to match
        Helion's key signature (called as self._key_fn(*args)).
        """
        if self.config_picker is None:
            raise RuntimeError(
                f"No config picker registered for kernel '{self.op_name}'. "
                f"Use @{self.op_name}.register_config_picker to register one."
            )

        # After None check, config_picker is guaranteed to be non-None
        assert self.config_picker is not None

        def key_computer(*args):
            config_keys = list(self.configs.keys())
            # Cast is safe because we checked for None above
            config_picker = cast(
                Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
            )
            selected_key = config_picker(args, config_keys)
            if selected_key:
                return selected_key
            return "default" if "default" in self.configs else None

        return key_computer

    def _create_config_selector(self, key_computer):
        def config_selector(args):
            # args is a tuple; key_computer expects unpacked args
            selected_config_key = key_computer(*args)

            if selected_config_key is None:
                raise ValueError(
                    f"Config picker returned None for kernel '{self.op_name}' "
                    f"with available config keys: {list(self.configs.keys())}"
                )

            if selected_config_key not in self.configs:
                raise ValueError(
                    f"Config picker returned invalid config key "
                    f"'{selected_config_key}' for kernel '{self.op_name}'. "
                    f"Available keys: {list(self.configs.keys())}"
                )

            return self.configs[selected_config_key]

        return config_selector

    def _load_platform_configs(self) -> None:
        from vllm.kernels.helion.config_manager import ConfigManager
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        self.platform = get_canonical_gpu_name()
        config_manager = ConfigManager.get_instance()
        self.configs = config_manager.get_platform_configs(self.op_name, self.platform)

        if not self.configs:
            raise ValueError(
                f"No configs available for kernel '{self.op_name}' "
                f"on platform '{self.platform}'"
            )

    def _create_decorated_kernel(self) -> Callable[..., Any]:
        self._load_platform_configs()

        key_computer = self._create_key_computer()
        config_selector = self._create_config_selector(key_computer)

        extra_kwargs = {
            "autotuner_fn": lambda _, args: PresetConfigSearch(args, config_selector),
            "key": key_computer,
        }

        logger.debug(
            "Creating decorated kernel %s with custom autotuner on platform %s",
            self.op_name,
            self.platform,
        )
        return create_helion_decorated_kernel(
            self.raw_kernel_func, self.helion_settings, extra_kwargs
        )


class HelionKernelWrapper:
    """Wrapper for Helion kernels that creates config-specific PyTorch custom ops."""

    def __init__(
        self,
        raw_kernel_func: Callable,
        op_name: str,
        fake_impl: Callable,
        helion_settings: "helion.Settings | None" = None,
    ):
        # Validate helion_settings doesn't conflict with our custom autotuner
        validate_helion_settings(helion_settings, op_name)

        self.raw_kernel_func = raw_kernel_func
        self.op_name = op_name
        self._fake_impl = fake_impl
        self.helion_settings = helion_settings
        self._config_picker: (
            Callable[[tuple[Any, ...], list[str]], str | None] | None
        ) = None
        self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None

    def __call__(self, *args, **kwargs):
        configured_op = self.get_configured_op()
        return configured_op(*args, **kwargs)

    def register_config_picker(
        self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
    ) -> Callable[[tuple[Any, ...], list[str]], str | None]:
        self._config_picker = picker_func
        return picker_func

    def register_input_generator(
        self, generator_func: Callable[[], dict[str, tuple[Any, ...]]]
    ) -> Callable[[], dict[str, tuple[Any, ...]]]:
        """
        Register a function to generate inputs for autotuning and benchmarking.

        Args:
            generator_func: Function that returns dict[str, tuple] where:
                - key: Configuration identifier (e.g., "4096", "hidden_4096")
                - value: Tuple of arguments to pass to the kernel

        Returns:
            The registered function (for decorator usage)

        Example:
            @kernel_wrapper.register_input_generator
            def generate_inputs():
                return {
                    "4096": (torch.randn(4096, device="cuda"), 0.5),
                    "8192": (torch.randn(8192, device="cuda"), 0.5),
                }
        """
        self._input_generator = generator_func
        return generator_func

    def get_inputs(self) -> dict[str, tuple[Any, ...]]:
        if self._input_generator is None:
            raise NotImplementedError(
                f"No input generator registered for kernel '{self.op_name}'. "
                f"Use @{self.op_name}.register_input_generator to register one."
            )
        return self._input_generator()

    def run_autotune(
        self,
        inputs: tuple[Any, ...],
        autotune_effort: str = "quick",
    ) -> Config:
        """Run autotuning for a single input configuration."""
        extra_kwargs = {"autotune_effort": autotune_effort}
        autotune_kernel = create_helion_decorated_kernel(
            self.raw_kernel_func, self.helion_settings, extra_kwargs
        )
        return autotune_kernel.autotune(inputs)

    def get_configured_op(self) -> Any:
        assert self._config_picker is not None, (
            f"No config picker registered for kernel '{self.op_name}'. "
            f"Use @{self.op_name}.register_config_picker to register one."
        )

        if hasattr(torch.ops.vllm_helion, self.op_name):
            logger.debug("Op vllm_helion::%s already registered", self.op_name)
            return getattr(torch.ops.vllm_helion, self.op_name)

        configured_kernel = ConfiguredHelionKernel(
            op_name=self.op_name,
            config_picker=self._config_picker,
            raw_kernel_func=self.raw_kernel_func,
            helion_settings=self.helion_settings,
        )

        logger.info("Registering op: vllm_helion::%s", self.op_name)
        direct_register_custom_op(
            op_name=self.op_name,
            op_func=configured_kernel._decorated_kernel,  # Register decorated kernel
            # TODO(gmagogsfm): Implement automatic mutation/aliasing detection
            # for Helion kernels.
            mutates_args=None,
            fake_impl=self._fake_impl,
            target_lib=vllm_helion_lib,
        )
        return getattr(torch.ops.vllm_helion, self.op_name)


# Global registry for tracking all registered HelionKernelWrapper instances
_REGISTERED_KERNELS: dict[str, HelionKernelWrapper] = {}


def get_registered_kernels() -> dict[str, HelionKernelWrapper]:
    return _REGISTERED_KERNELS.copy()


def get_kernel_by_name(kernel_name: str) -> HelionKernelWrapper | None:
    return _REGISTERED_KERNELS.get(kernel_name)


def infer_fake_impl(
    kernel_func: Callable,
    helion_settings: "helion.Settings | None" = None,
) -> Callable:
    def helion_fake_kernel(*args, **kwargs):
        kernel_kwargs = {}
        if helion_settings:
            kernel_kwargs.update(helion_settings.to_dict())

        temp_decorated_kernel = helion.kernel(**kernel_kwargs)(kernel_func)

        # Bind with args to get config_spec, then get a valid default config
        bound = temp_decorated_kernel.bind(args)
        default_config = bound.config_spec.default_config()
        compiled_runner = bound.compile_config(default_config)

        return compiled_runner(*args, **kwargs, _launcher=lambda *a, **kw: None)

    return helion_fake_kernel


# Overloads are necessary for proper mypy type inference.
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
# causes mypy to complain about missing attributes when tests do:
#     wrapper = register_kernel(func)  # Should return HelionKernelWrapper
#     wrapper._fake_impl  # mypy error: "Callable has no attribute _fake_impl"
# The overloads tell mypy the exact return type based on the argument pattern.
@overload
def register_kernel(
    op_name_or_func: Callable,
    *,
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper: ...


@overload
def register_kernel(
    op_name_or_func: str | None = None,
    *,
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
) -> Callable[[Callable], HelionKernelWrapper]: ...


def register_kernel(
    op_name_or_func: str | Callable | None = None,
    *,
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]:
    """
    Decorator to register a Helion kernel function as a HelionKernelWrapper.

    Wraps the raw kernel function in a HelionKernelWrapper and registers it
    in the global kernel registry. Auto-generates fake_impl if not provided.
    """

    def decorator(kernel_func: Callable) -> HelionKernelWrapper:
        op_name = op_name_or_func if isinstance(op_name_or_func, str) else None
        final_op_name = op_name if op_name else kernel_func.__name__

        if final_op_name in _REGISTERED_KERNELS:
            raise ValueError(
                f"Helion kernel '{final_op_name}' is already registered. "
                f"Use a different op_name or check for duplicate registrations."
            )

        final_fake_impl = fake_impl
        if final_fake_impl is None:
            final_fake_impl = infer_fake_impl(kernel_func, helion_settings)
            logger.debug(
                "Auto-generated fake_impl for Helion kernel '%s'",
                kernel_func.__name__,
            )

        kernel_wrapper = HelionKernelWrapper(
            raw_kernel_func=kernel_func,
            op_name=final_op_name,
            fake_impl=final_fake_impl,
            helion_settings=helion_settings,
        )

        _REGISTERED_KERNELS[final_op_name] = kernel_wrapper

        logger.info(
            "Registered Helion kernel '%s' as HelionKernelWrapper",
            kernel_func.__name__,
        )

        return kernel_wrapper

    if callable(op_name_or_func) and not isinstance(op_name_or_func, str):
        # Bare decorator usage: @register_kernel
        return decorator(op_name_or_func)
    else:
        # Decorator with arguments: @register_kernel(...)
        return decorator
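Putting the pieces together, an offline tuning pass would look roughly like the sketch below, presumably similar to what scripts/autotune_helion_kernels.py (referenced in config_manager.py's error message) automates. It requires helion and a CUDA device, and tuning every generated input is expensive even at "quick" effort:

    from vllm.kernels.helion.config_manager import ConfigManager
    from vllm.kernels.helion.register import get_kernel_by_name
    from vllm.kernels.helion.utils import get_canonical_gpu_name

    wrapper = get_kernel_by_name("silu_mul_fp8")  # populated by the ops auto-import
    assert wrapper is not None
    tuned = {
        key: wrapper.run_autotune(args, autotune_effort="quick")
        for key, args in wrapper.get_inputs().items()
    }
    ConfigManager().save_configs("silu_mul_fp8", get_canonical_gpu_name(), tuned)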
vllm/kernels/helion/utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for Helion kernel management."""

import logging

from vllm.platforms import current_platform

logger = logging.getLogger(__name__)

# Maps known variant GPU names (after lowercase/underscore normalization)
# to their canonical form.
#
# Names that are already canonical after normalization are NOT listed here.
# For example, "NVIDIA H200" normalizes to "nvidia_h200" which needs no
# further mapping, and AMD ROCm names like "AMD_Instinct_MI300X" come from
# a controlled lookup table in rocm.py and normalize cleanly to
# "amd_instinct_mi300x". Only names with variant suffixes (form factor,
# memory size, memory type, etc.) that should be stripped need entries.
#
# To add a new GPU variant: run `canonicalize_gpu_name()` without the alias
# to see the normalized name, then add a mapping here if it contains variant
# suffixes that should be stripped (e.g. Blackwell/Rubin variants).
_GPU_NAME_ALIASES: dict[str, str] = {
    # H100 variants
    "nvidia_h100_pcie": "nvidia_h100",
    "nvidia_h100_sxm5": "nvidia_h100",
    "nvidia_h100_80gb_hbm3": "nvidia_h100",
    "nvidia_h100_nvl": "nvidia_h100",
    # H200 variants
    "nvidia_h200_nvl": "nvidia_h200",
    "nvidia_h200_141gb_hbm3e": "nvidia_h200",
    # A100 variants
    "nvidia_a100_sxm4_80gb": "nvidia_a100",
    "nvidia_a100_sxm4_40gb": "nvidia_a100",
    "nvidia_a100_pcie_80gb": "nvidia_a100",
    "nvidia_a100_pcie_40gb": "nvidia_a100",
    "nvidia_a100_80gb_pcie": "nvidia_a100",
    # V100 variants (Tesla-branded)
    "tesla_v100_sxm2_32gb": "tesla_v100",
    "tesla_v100_sxm2_16gb": "tesla_v100",
    "tesla_v100_pcie_32gb": "tesla_v100",
    "tesla_v100_pcie_16gb": "tesla_v100",
    # AMD ROCm variants (from _ROCM_DEVICE_ID_NAME_MAP in rocm.py)
    "amd_instinct_mi300x_hf": "amd_instinct_mi300x",
    # ADD MORE HERE
}


def get_gpu_name(device_id: int | None = None) -> str:
    if device_id is None:
        logger.warning(
            "get_gpu_name() called without device_id, defaulting to 0. "
            "This may return the wrong device name in multi-node setups."
        )
        device_id = 0
    return current_platform.get_device_name(device_id)


def canonicalize_gpu_name(name: str) -> str:
    """
    Canonicalize GPU name for use as a platform identifier.

    Converts to lowercase, replaces spaces and hyphens with underscores,
    and maps known variant names to their canonical form via _GPU_NAME_ALIASES.
    e.g., "NVIDIA H100 80GB HBM3" -> "nvidia_h100"
          "NVIDIA A100-SXM4-80GB" -> "nvidia_a100"
          "AMD Instinct MI300X" -> "amd_instinct_mi300x"
    """
    if not name or not name.strip():
        raise ValueError("GPU name cannot be empty")
    name = name.lower()
    name = name.replace(" ", "_")
    name = name.replace("-", "_")
    if name in _GPU_NAME_ALIASES:
        return _GPU_NAME_ALIASES[name]
    return name


def get_canonical_gpu_name(device_id: int | None = None) -> str:
    return canonicalize_gpu_name(get_gpu_name(device_id))
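The normalization and alias table can be sanity-checked without a GPU; the cases below come straight from the docstring and mapping above:

    from vllm.kernels.helion.utils import canonicalize_gpu_name

    assert canonicalize_gpu_name("NVIDIA H100 80GB HBM3") == "nvidia_h100"
    assert canonicalize_gpu_name("NVIDIA A100-SXM4-80GB") == "nvidia_a100"
    assert canonicalize_gpu_name("NVIDIA H200") == "nvidia_h200"  # already canonical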