Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -31,8 +31,8 @@ by key matches the config returned by the autotuner.
 
 Key Classes
 -----------
-- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops
-- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op
+- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
+- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
 - PresetConfigSearch: Custom autotuner that returns pre-tuned configs
 """
 
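The config picker contract, as a standalone sketch: given the call arguments and the list of known config names, return the name of the config to run with, or None. The function body and names here are hypothetical; only the signature is taken from the `Callable[[tuple[Any, ...], list[str]], str | None]` type registered further down in this diff.

```python
from typing import Any

# Hypothetical picker; it would be attached to a wrapper instance with
# @my_kernel.register_config_picker (see the assertion message in this diff).
def pick_config(args: tuple[Any, ...], config_names: list[str]) -> str | None:
    m = getattr(args[0], "shape", (0,))[0]  # e.g. key off the leading dim
    if "large_m" in config_names and m >= 4096:
        return "large_m"
    return "default" if "default" in config_names else None
```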
@@ -53,10 +53,27 @@ if not has_helion():
     )
 
 import helion
+from helion._compat import requires_torch_version
 from helion.autotuner.base_search import BaseAutotuner
 from helion.runtime.config import Config
 from helion.runtime.settings import default_autotuner_fn
 
+# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op,
+# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11.
+_HOP_AVAILABLE = requires_torch_version("2.11")
+
+if _HOP_AVAILABLE:
+    import torch.utils._pytree as pytree
+    from helion._compiler._dynamo.higher_order_ops import (
+        helion_kernel_side_table,
+        helion_kernel_wrapper_mutation,
+    )
+    from helion._compiler._dynamo.variables import infer_output_spec
+    from torch.fx.experimental.proxy_tensor import (
+        disable_proxy_modes_tracing,
+        get_proxy_mode,
+    )
+
 logger = init_logger(__name__)
 
 vllm_helion_lib = Library("vllm_helion", "FRAGMENT")  # noqa
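`_HOP_AVAILABLE` gates every HOP-only import so the module still loads on older PyTorch. A standalone sketch of the same gating pattern (this is not Helion's `requires_torch_version` helper, just an illustration using `torch.__version__`):

```python
import torch

def torch_at_least(major: int, minor: int) -> bool:
    # torch.__version__ looks like "2.11.0a0+gitabcdef"; compare the leading pair.
    head = torch.__version__.split("+")[0].split(".")
    return (int(head[0]), int(head[1])) >= (major, minor)

if torch_at_least(2, 11):
    # Safe to import tracing internals that only exist on newer builds.
    from torch.fx.experimental.proxy_tensor import get_proxy_mode  # noqa: F401
```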
@@ -233,7 +250,7 @@ class ConfiguredHelionKernel:
 
 
 class HelionKernelWrapper:
-    """Wrapper for Helion kernels that creates config-specific PyTorch custom ops."""
+    """Wrapper for Helion kernels with pre-tuned config selection and HOP support."""
 
     def __init__(
         self,
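"Pre-tuned configs" here are ordinary Helion `Config` objects captured from a previous autotuning run and served by name rather than re-searched at startup. An illustrative sketch (the field values are made up; real config contents depend on the kernel):

```python
from helion.runtime.config import Config

# Hypothetical preset table a PresetConfigSearch-style autotuner could serve from.
PRETUNED_CONFIGS = {
    "default": Config(block_sizes=[32, 32], num_warps=4),
    "large_m": Config(block_sizes=[64, 64], num_warps=8),
}
```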
@@ -252,11 +269,86 @@ class HelionKernelWrapper:
         self._config_picker: (
             Callable[[tuple[Any, ...], list[str]], str | None] | None
         ) = None
         self._configured_kernel: ConfiguredHelionKernel | None = None
         self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None
 
     def __call__(self, *args, **kwargs):
-        configured_op = self.get_configured_op()
-        return configured_op(*args, **kwargs)
+        # CustomOp fallback: register as torch custom op for torch.compile
+        # compatibility on older PyTorch lacking HOP/EffectType support
+        if not _HOP_AVAILABLE:
+            custom_op = self._get_or_register_custom_op()
+            return custom_op(*args, **kwargs)
+        # HOP tracing: record HigherOrderOp in the FX graph
+        if get_proxy_mode() is not None:
+            return self._call_via_hop(args, kwargs)
+        # Eager: run the configured kernel directly
+        return self.get_configured_op()(*args, **kwargs)
+
+    def _call_via_hop(
+        self,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> Any:
+        kernel = self.get_configured_op()._decorated_kernel
+        kernel_idx = helion_kernel_side_table.add_kernel(kernel)
+
+        constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
+
+        all_named = {**constant_args, **tensor_args}
+        full_args = tuple(
+            all_named.get(n, p.default)
+            for n, p in kernel.signature.parameters.items()  # type: ignore[attr-defined]
+            if n in all_named or p.default is not p.empty
+        )
+
+        with disable_proxy_modes_tracing():
+            output_spec = infer_output_spec(kernel, full_args)
+
+        hop_result = helion_kernel_wrapper_mutation(
+            kernel_idx=kernel_idx,
+            constant_args=constant_args,
+            tensor_args=tensor_args,
+            output_spec=output_spec,
+        )
+
+        tree_spec_str = output_spec.get("tree_spec_str")
+        if tree_spec_str is None:
+            return None
+        tree_spec = pytree.treespec_loads(tree_spec_str)
+
+        hop_iter = iter(hop_result)
+        reconstructed = []
+        for spec in output_spec["leaf_specs"]:
+            is_constant_scalar = spec["type"] == "scalar" and not isinstance(
+                spec.get("scalar_value"), torch.SymInt
+            )
+            if is_constant_scalar:
+                reconstructed.append(spec["scalar_value"])
+            else:
+                reconstructed.append(next(hop_iter))
+        return pytree.tree_unflatten(reconstructed, tree_spec)
+
+    @staticmethod
+    def _partition_args(
+        kernel: Any,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        constant_args: dict[str, Any] = {}
+        tensor_args: dict[str, Any] = {}
+        params = list(kernel.signature.parameters.keys())
+        for i, val in enumerate(args):
+            name = params[i]
+            if isinstance(val, torch.Tensor):
+                tensor_args[name] = val
+            else:
+                constant_args[name] = val
+        for name, val in kwargs.items():
+            if isinstance(val, torch.Tensor):
+                tensor_args[name] = val
+            else:
+                constant_args[name] = val
+        return constant_args, tensor_args
+
     def register_config_picker(
         self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
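The subtle part of `_call_via_hop` is output reconstruction: the HOP returns a flat sequence of results, the serialized treespec restores the original structure, and constant scalars are filled back in from `leaf_specs` because the HOP does not return them. A minimal round trip of the treespec mechanics with toy values, using the same `torch.utils._pytree` calls:

```python
import torch.utils._pytree as pytree

out = {"y": [1, 2], "count": 3}               # toy structured kernel output
leaves, spec = pytree.tree_flatten(out)       # flat leaves plus the structure
spec_str = pytree.treespec_dumps(spec)        # string form, like tree_spec_str above

rebuilt = pytree.tree_unflatten(leaves, pytree.treespec_loads(spec_str))
assert rebuilt == out
```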
@@ -309,29 +401,32 @@ class HelionKernelWrapper:
         )
         return autotune_kernel.autotune(inputs)
 
-    def get_configured_op(self) -> Any:
+    def get_configured_op(self) -> ConfiguredHelionKernel:
         assert self._config_picker is not None, (
             f"No config picker registered for kernel '{self.op_name}'. "
             f"Use @{self.op_name}.register_config_picker to register one."
         )
 
         if self._configured_kernel is None:
             self._configured_kernel = ConfiguredHelionKernel(
                 op_name=self.op_name,
                 config_picker=self._config_picker,
                 raw_kernel_func=self.raw_kernel_func,
                 helion_settings=self.helion_settings,
             )
 
         return self._configured_kernel
 
     def _get_or_register_custom_op(self) -> Any:
         if hasattr(torch.ops.vllm_helion, self.op_name):
             logger.debug("Op vllm_helion::%s already registered", self.op_name)
             return getattr(torch.ops.vllm_helion, self.op_name)
 
-        configured_kernel = ConfiguredHelionKernel(
-            op_name=self.op_name,
-            config_picker=self._config_picker,
-            raw_kernel_func=self.raw_kernel_func,
-            helion_settings=self.helion_settings,
-        )
+        configured_kernel = self.get_configured_op()
 
         logger.info("Registering op: vllm_helion::%s", self.op_name)
         direct_register_custom_op(
             op_name=self.op_name,
-            op_func=configured_kernel._decorated_kernel,  # Register decorated kernel
+            # TODO(gmagogsfm): Implement automatic mutation/aliasing detection
+            # for Helion kernels.
+            op_func=configured_kernel._decorated_kernel,
             mutates_args=None,
             fake_impl=self._fake_impl,
             target_lib=vllm_helion_lib,
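For the fallback path, vLLM's `direct_register_custom_op` wraps the plain `torch.library` registration flow. A hedged sketch of that flow with a toy op (the namespace, op, and fake impl are invented for illustration; the real helper also derives the schema and dispatch keys):

```python
import torch
from torch.library import Library

toy_lib = Library("toy_helion", "FRAGMENT")  # noqa
toy_lib.define("scale(Tensor x, float s) -> Tensor")

def _scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

toy_lib.impl("scale", _scale, "CUDA")
toy_lib.impl("scale", _scale, "CPU")

@torch.library.register_fake("toy_helion::scale")
def _scale_fake(x, s):
    # Shape/dtype-only version used under FakeTensor tracing (the fake_impl role).
    return torch.empty_like(x)
```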