Files
xc-llm-ascend/vllm_ascend/patch/worker/patch_common/patch_utils.py
wangxiyuan 857f489cbf [CI] Patch torch.library.infer_schema for torch 2.5 backward compatibility (#837)
Patch torch.library.infer_schema for torch 2.5 backward compatibility

- Introduced a new module `patch_utils` under
`vllm_ascend/patch/worker/patch_common/`.
- Added a function `ascend_direct_register_custom_op` to handle custom
operator registration with backward compatibility for PyTorch < 2.7
(such as torch 2.5.1).
- Implemented type conversion logic for annotations to ensure
compatibility across different PyTorch versions.
- Assigned the function `ascend_direct_register_custom_op` to
`utils.direct_register_custom_op`, overriding vLLM's default implementation.

- Updated `__init__.py` to include `patch_utils` as the first import.
- Ensured `patch_utils` is available before the other patch files run, and
skipped isort checks on the `patch_utils` import so it stays first.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-05-14 09:20:55 +08:00

39 lines
1.4 KiB
Python

from typing import Callable, List, Optional, Tuple
import torch
from torch.library import Library
from vllm import utils
from vllm.utils import vllm_lib
def ascend_direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: Tuple[torch.Tag, ...] = (),
):
# In pytorch 2.5.1, torch.library.infer_schema require the input function to
# have annotations supported by typing library. But in pytorch 2.7.0 which
# vllm using, torch.library.infer_schema require the python builtin type. In
# this case, we should revert built type to typing type for 2.5.1 backward
# compatibility.
for k, v in op_func.__annotations__.items():
if v == list[int]:
op_func.__annotations__[k] = List[int]
if v == Optional[list[int]]:
op_func.__annotations__[k] = Optional[List[int]]
# TODO: add more type convert here if needed.
import torch.library
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
utils.direct_register_custom_op = ascend_direct_register_custom_op