Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -482,3 +482,44 @@ if is_torch_equal("2.9.0"):
|
||||
|
||||
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||
|
||||
# ===================================================
|
||||
# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
|
||||
# ===================================================
|
||||
# Patch the inductor's `constrain_to_fx_strides` to handle opaque
|
||||
# (non-tensor) arguments. The original calls `.stride()` on every FX
|
||||
# arg's meta value, which crashes on FakeScriptObject (the compile-time
|
||||
# proxy for hoisted opaque types). The patched version skips args
|
||||
# whose meta value is not a torch.Tensor.
|
||||
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
|
||||
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if is_torch_equal_or_newer("2.11.0.dev"):
|
||||
import torch._inductor.ir as _ir
|
||||
import torch._inductor.lowering as _lowering
|
||||
from torch._inductor.virtualized import V as _V
|
||||
|
||||
_orig_constrain = _lowering.constrain_to_fx_strides
|
||||
|
||||
def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs):
|
||||
def apply_constraint(arg, fx_arg):
|
||||
if isinstance(arg, _ir.IRNode):
|
||||
meta_val = fx_arg.meta.get("val")
|
||||
if isinstance(meta_val, torch.Tensor):
|
||||
stride_order = _ir.get_stride_order(
|
||||
meta_val.stride(), _V.graph.sizevars.shape_env
|
||||
)
|
||||
return _ir.ExternKernel.require_stride_order(arg, stride_order)
|
||||
return arg
|
||||
if isinstance(arg, dict):
|
||||
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
|
||||
return arg
|
||||
|
||||
args = tuple(
|
||||
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
|
||||
)
|
||||
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
|
||||
return args, kwargs
|
||||
|
||||
_lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides
|
||||
|
||||
Reference in New Issue
Block a user