From 68fb63428b8b972dd60ad9189538909b0eb1fcc8 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 14 May 2025 19:49:09 +0800 Subject: [PATCH] [CI] Patch torch.library.infer_schema for fused moe ops to fix CI (#854) Make sure the pytorch infer_schema check is patched before the cases that use fused moe ops: 1. model register 2. quantization loading 3. fused moe ut Signed-off-by: wangxiyuan --- tests/ops/test_fused_moe.py | 3 +++ vllm_ascend/__init__.py | 4 ++++ vllm_ascend/quantization/quant_config.py | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/tests/ops/test_fused_moe.py b/tests/ops/test_fused_moe.py index 7b21307..78c0d88 100644 --- a/tests/ops/test_fused_moe.py +++ b/tests/ops/test_fused_moe.py @@ -19,6 +19,9 @@ Run `pytest tests/ops/test_fused_moe.py`. """ +# fused moe ops test will hit the infer_schema error; we need to add the patch +# here to make the test pass. +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa import pytest import torch diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 7588e70..c8f3331 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -23,5 +23,9 @@ def register(): def register_model(): + # fix pytorch schema check error, remove this line after pytorch + # is upgraded to 2.7.0 + import vllm_ascend.patch.worker.patch_common.patch_utils # noqa: F401 + from .models import register_model register_model() diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 22b61f2..499e236 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -15,6 +15,10 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # +# In the quantization case, this file is loaded before the worker patch runs; +# we need to import patch_utils here first to make sure the patch is applied.
+import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + from types import MappingProxyType from typing import Any, Callable, Dict, List, Mapping, Optional