Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -31,8 +31,8 @@ by key matches the config returned by the autotuner.
Key Classes
-----------
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops
- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
- PresetConfigSearch: Custom autotuner that returns pre-tuned configs
"""
@@ -53,10 +53,27 @@ if not has_helion():
    )

import helion
from helion._compat import requires_torch_version
from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.config import Config
from helion.runtime.settings import default_autotuner_fn

# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op,
# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11.
_HOP_AVAILABLE = requires_torch_version("2.11")
if _HOP_AVAILABLE:
    import torch.utils._pytree as pytree
    from helion._compiler._dynamo.higher_order_ops import (
        helion_kernel_side_table,
        helion_kernel_wrapper_mutation,
    )
    from helion._compiler._dynamo.variables import infer_output_spec
    from torch.fx.experimental.proxy_tensor import (
        disable_proxy_modes_tracing,
        get_proxy_mode,
    )

logger = init_logger(__name__)
vllm_helion_lib = Library("vllm_helion", "FRAGMENT")  # noqa
@@ -233,7 +250,7 @@ class ConfiguredHelionKernel:
class HelionKernelWrapper:
    """Wrapper for Helion kernels that creates config-specific PyTorch custom ops."""
    """Wrapper for Helion kernels with pre-tuned config selection and HOP support."""

    def __init__(
        self,
@@ -252,11 +269,86 @@ class HelionKernelWrapper:
        self._config_picker: (
            Callable[[tuple[Any, ...], list[str]], str | None] | None
        ) = None
        self._configured_kernel: ConfiguredHelionKernel | None = None
        self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None
    def __call__(self, *args, **kwargs):
        configured_op = self.get_configured_op()
        return configured_op(*args, **kwargs)
        # CustomOp fallback: register as torch custom op for torch.compile
        # compatibility on older PyTorch lacking HOP/EffectType support
        if not _HOP_AVAILABLE:
            custom_op = self._get_or_register_custom_op()
            return custom_op(*args, **kwargs)

        # HOP tracing: record HigherOrderOp in the FX graph
        if get_proxy_mode() is not None:
            return self._call_via_hop(args, kwargs)

        # Eager: run the configured kernel directly
        return self.get_configured_op()(*args, **kwargs)
    def _call_via_hop(
        self,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        kernel = self.get_configured_op()._decorated_kernel
        kernel_idx = helion_kernel_side_table.add_kernel(kernel)

        constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
        all_named = {**constant_args, **tensor_args}
        full_args = tuple(
            all_named.get(n, p.default)
            for n, p in kernel.signature.parameters.items()  # type: ignore[attr-defined]
            if n in all_named or p.default is not p.empty
        )
        with disable_proxy_modes_tracing():
            output_spec = infer_output_spec(kernel, full_args)

        hop_result = helion_kernel_wrapper_mutation(
            kernel_idx=kernel_idx,
            constant_args=constant_args,
            tensor_args=tensor_args,
            output_spec=output_spec,
        )

        tree_spec_str = output_spec.get("tree_spec_str")
        if tree_spec_str is None:
            return None
        tree_spec = pytree.treespec_loads(tree_spec_str)

        hop_iter = iter(hop_result)
        reconstructed = []
        for spec in output_spec["leaf_specs"]:
            is_constant_scalar = spec["type"] == "scalar" and not isinstance(
                spec.get("scalar_value"), torch.SymInt
            )
            if is_constant_scalar:
                reconstructed.append(spec["scalar_value"])
            else:
                reconstructed.append(next(hop_iter))
        return pytree.tree_unflatten(reconstructed, tree_spec)
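For reference, a minimal standalone sketch of the pytree round trip that _call_via_hop performs when rebuilding outputs; the dict below is an arbitrary stand-in for a kernel's output structure, not vLLM code:

# Standalone illustration: serialize a TreeSpec to a string, load it back, and
# unflatten the leaves, mirroring the tree_spec_str handling above.
import torch
import torch.utils._pytree as pytree

outputs = {"out": torch.zeros(4), "count": 3}   # arbitrary example structure
leaves, spec = pytree.tree_flatten(outputs)
spec_str = pytree.treespec_dumps(spec)          # analogous to tree_spec_str
rebuilt = pytree.tree_unflatten(leaves, pytree.treespec_loads(spec_str))
assert rebuilt["count"] == 3 and rebuilt["out"].shape == (4,)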
    @staticmethod
    def _partition_args(
        kernel: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        constant_args: dict[str, Any] = {}
        tensor_args: dict[str, Any] = {}
        params = list(kernel.signature.parameters.keys())
        for i, val in enumerate(args):
            name = params[i]
            if isinstance(val, torch.Tensor):
                tensor_args[name] = val
            else:
                constant_args[name] = val
        for name, val in kwargs.items():
            if isinstance(val, torch.Tensor):
                tensor_args[name] = val
            else:
                constant_args[name] = val
        return constant_args, tensor_args
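As a concrete, hypothetical illustration of the split above: for a call like kernel(x, scale=2.0, out=buf), the tensors and plain Python values land in separate dicts. The kernel signature here is assumed, not taken from the diff:

# Hypothetical inputs for illustration only.
import torch

x, buf = torch.randn(8), torch.empty(8)
args, kwargs = (x,), {"scale": 2.0, "out": buf}
# _partition_args(kernel, args, kwargs) would then return roughly:
#     constant_args = {"scale": 2.0}
#     tensor_args   = {"x": x, "out": buf}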
    def register_config_picker(
        self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
@@ -309,29 +401,32 @@ class HelionKernelWrapper:
        )
        return autotune_kernel.autotune(inputs)

    def get_configured_op(self) -> Any:
    def get_configured_op(self) -> ConfiguredHelionKernel:
        assert self._config_picker is not None, (
            f"No config picker registered for kernel '{self.op_name}'. "
            f"Use @{self.op_name}.register_config_picker to register one."
        )
        if self._configured_kernel is None:
            self._configured_kernel = ConfiguredHelionKernel(
                op_name=self.op_name,
                config_picker=self._config_picker,
                raw_kernel_func=self.raw_kernel_func,
                helion_settings=self.helion_settings,
            )
        return self._configured_kernel
    def _get_or_register_custom_op(self) -> Any:
        if hasattr(torch.ops.vllm_helion, self.op_name):
            logger.debug("Op vllm_helion::%s already registered", self.op_name)
            return getattr(torch.ops.vllm_helion, self.op_name)

        configured_kernel = ConfiguredHelionKernel(
            op_name=self.op_name,
            config_picker=self._config_picker,
            raw_kernel_func=self.raw_kernel_func,
            helion_settings=self.helion_settings,
        )
        configured_kernel = self.get_configured_op()
        logger.info("Registering op: vllm_helion::%s", self.op_name)
        direct_register_custom_op(
            op_name=self.op_name,
            op_func=configured_kernel._decorated_kernel,  # Register decorated kernel
            # TODO(gmagogsfm): Implement automatic mutation/aliasing detection
            # for Helion kernels.
            op_func=configured_kernel._decorated_kernel,
            mutates_args=None,
            fake_impl=self._fake_impl,
            target_lib=vllm_helion_lib,