diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 2ba9917..af19d4a 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +# patch_utils should be the first import, because it will be used by other +# patch files. +import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_utils.py b/vllm_ascend/patch/worker/patch_common/patch_utils.py new file mode 100644 index 0000000..dec618c --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_utils.py @@ -0,0 +1,38 @@ +from typing import Callable, List, Optional, Tuple + +import torch +from torch.library import Library +from vllm import utils +from vllm.utils import vllm_lib + + +def ascend_direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str], + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, + dispatch_key: str = "CUDA", + tags: Tuple[torch.Tag, ...] = (), +): + # In pytorch 2.5.1, torch.library.infer_schema requires the input function to + # have annotations supported by typing library. But in pytorch 2.7.0 which + # vllm uses, torch.library.infer_schema requires the python builtin type. In + # this case, we should revert builtin types to typing types for 2.5.1 backward + # compatibility. + for k, v in op_func.__annotations__.items(): + if v == list[int]: + op_func.__annotations__[k] = List[int] + if v == Optional[list[int]]: + op_func.__annotations__[k] = Optional[List[int]] + # TODO: add more type conversions here if needed.
+ import torch.library + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str, tags=tags) + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) + + +utils.direct_register_custom_op = ascend_direct_register_custom_op