diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 2451a6f..32a0684 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -47,7 +47,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install -r requirements-dev.txt + pip install -r requirements-dev.txt - name: Checkout vllm-project/vllm repo uses: actions/checkout@v4 diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index 7fef514..f5a7038 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -142,9 +142,10 @@ jobs: pytest -sv tests/ops fi + # only run test on spec decode when the related code changed - name: Check for changes in Speculative Decode id: filter_spec_decode - uses: dorny/paths-filter@v2 + uses: dorny/paths-filter@v3 with: filters: | speculative_tests_changed: @@ -155,10 +156,9 @@ jobs: - "vllm_ascend/patch/patch_rejection_sampler.py" - "vllm_ascend/patch/patch_spec_decode_worker.py" - "vllm_ascend/patch/patch_multi_step_worker.py" + - name: Run vllm-project/vllm-ascend Speculative Decode test - env: - HF_ENDPOINT: https://hf-mirror.com - if: steps.filter_spec_decode.outputs.speculative_tests_changed + if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true' run: | if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then pytest -sv tests/singlecard/spec_decode diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 28a5cbc..7a2d179 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -16,7 +16,7 @@ # from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np import torch @@ -38,9 +38,8 @@ from vllm.attention.backends.utils import (CommonAttentionState, is_block_tables_empty) from vllm.utils import async_tensor_h2d, make_tensor_with_pad -if TYPE_CHECKING: - from vllm_ascend.worker.model_runner import ( - ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) +from vllm_ascend.worker.model_runner import ( + ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata) def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None): @@ -489,7 +488,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): 128, self.input_builder.runner.model_config.dtype) def _add_seq_group( - self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup", + self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup, chunked_prefill_enabled: bool): """Add a sequence group to the metadata. Specifically update/append 1. context length. diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 0fb24d2..3bf1f68 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -14,14 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import vllm -from packaging.version import Version - -# Import common patches for all versions -from vllm_ascend.patch.platform import patch_common # noqa: F401 +from vllm_ascend.utils import vllm_version_is # Import specific patches for different versions -if Version(vllm.__version__) == Version("0.8.4"): +if vllm_version_is("0.8.4"): from vllm_ascend.patch.platform import patch_0_8_4 # noqa: F401 + from vllm_ascend.patch.platform import patch_common # noqa: F401 else: + from vllm_ascend.patch.platform import patch_common # noqa: F401 from vllm_ascend.patch.platform import patch_main # noqa: F401 diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index c2f54a7..4b6b83d 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -14,14 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import vllm -from packaging.version import Version -# Import common patches for all versions -from vllm_ascend.patch.worker import patch_common # noqa: F401 +from vllm_ascend.utils import vllm_version_is # Import specific patches for different versions -if Version(vllm.__version__) == Version("0.8.4"): +if vllm_version_is("0.8.4"): from vllm_ascend.patch.worker import patch_0_8_4 # noqa: F401 + from vllm_ascend.patch.worker import patch_common # noqa: F401 else: + from vllm_ascend.patch.worker import patch_common # noqa: F401 from vllm_ascend.patch.worker import patch_main # noqa: F401 diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index bbc012a..dfd0f68 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -18,6 +18,7 @@ # import torch import torch_npu # noqa: F401 +from packaging.version import Version from vllm.logger import logger import vllm_ascend.envs as envs @@ -83,3 +84,8 @@ def adapt_patch(is_global_patch: bool = False): from vllm_ascend.patch import platform # noqa: F401 else: from vllm_ascend.patch import worker # noqa: F401 + + +def vllm_version_is(version: str): + import vllm + return Version(vllm.__version__) == Version(version) diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py index c240c80..67a055f 100644 --- a/vllm_ascend/worker/model_runner.py +++ b/vllm_ascend/worker/model_runner.py @@ -25,14 +25,13 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Type, TypeVar, Union) import torch -import torch.distributed import torch.nn as nn import torch_npu from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.utils import CommonAttentionState from vllm.config import VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_kv_transfer_group, get_pp_group +from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import logger @@ -62,6 +61,13 @@ from vllm.worker.model_runner_base import ( _init_attn_metadata_from_tensor_dict, _init_sampling_metadata_from_tensor_dict) +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.8.4"): + from vllm.distributed import get_kv_transfer_group +else: + from vllm.distributed.kv_transfer import get_kv_transfer_group + if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 3dc1a88..595fb46 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -25,8 +25,7 @@ import torch.distributed from torch import nn from vllm import envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import logger @@ -45,10 +44,15 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import try_register_lib +from vllm_ascend.utils import try_register_lib, vllm_version_is from vllm_ascend.worker.model_runner import NPUModelRunner from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner +if vllm_version_is("0.8.4"): + from vllm.distributed import ensure_kv_transfer_initialized +else: + from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized + class NPUWorker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a NPU. diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 73dde65..f839518 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -25,8 +25,7 @@ import torch.nn as nn import torch_npu from vllm import envs from vllm.config import VllmConfig -from vllm.distributed import (ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, +from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import logger @@ -40,9 +39,14 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.worker_base import WorkerBase from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import try_register_lib +from vllm_ascend.utils import try_register_lib, vllm_version_is from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +if vllm_version_is("0.8.4"): + from vllm.distributed import ensure_kv_transfer_initialized +else: + from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized + class NPUWorker(WorkerBase): @@ -241,4 +245,4 @@ class NPUWorker(WorkerBase): on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( torch_profiler_trace_dir)) else: - return None \ No newline at end of file + return None