[Misc] Fix import error and address nits to make CI happy (#563)
1. Add `vllm_version_is` function to check vllm version.
2. `ensure_kv_transfer_initialized` and `get_kv_transfer_group` have
been moved to another place in the vllm main branch via
3408e47159
; this patch fixes the import error.
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
2
.github/workflows/mypy.yaml
vendored
2
.github/workflows/mypy.yaml
vendored
@@ -47,7 +47,7 @@ jobs:
|
|||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install -r requirements-dev.txt
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
- name: Checkout vllm-project/vllm repo
|
- name: Checkout vllm-project/vllm repo
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
8
.github/workflows/vllm_ascend_test.yaml
vendored
8
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -142,9 +142,10 @@ jobs:
|
|||||||
pytest -sv tests/ops
|
pytest -sv tests/ops
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# only run test on spec decode when the related code changed
|
||||||
- name: Check for changes in Speculative Decode
|
- name: Check for changes in Speculative Decode
|
||||||
id: filter_spec_decode
|
id: filter_spec_decode
|
||||||
uses: dorny/paths-filter@v2
|
uses: dorny/paths-filter@v3
|
||||||
with:
|
with:
|
||||||
filters: |
|
filters: |
|
||||||
speculative_tests_changed:
|
speculative_tests_changed:
|
||||||
@@ -155,10 +156,9 @@ jobs:
|
|||||||
- "vllm_ascend/patch/patch_rejection_sampler.py"
|
- "vllm_ascend/patch/patch_rejection_sampler.py"
|
||||||
- "vllm_ascend/patch/patch_spec_decode_worker.py"
|
- "vllm_ascend/patch/patch_spec_decode_worker.py"
|
||||||
- "vllm_ascend/patch/patch_multi_step_worker.py"
|
- "vllm_ascend/patch/patch_multi_step_worker.py"
|
||||||
|
|
||||||
- name: Run vllm-project/vllm-ascend Speculative Decode test
|
- name: Run vllm-project/vllm-ascend Speculative Decode test
|
||||||
env:
|
if: steps.filter_spec_decode.outputs.speculative_tests_changed == 'true'
|
||||||
HF_ENDPOINT: https://hf-mirror.com
|
|
||||||
if: steps.filter_spec_decode.outputs.speculative_tests_changed
|
|
||||||
run: |
|
run: |
|
||||||
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
||||||
pytest -sv tests/singlecard/spec_decode
|
pytest -sv tests/singlecard/spec_decode
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -38,9 +38,8 @@ from vllm.attention.backends.utils import (CommonAttentionState,
|
|||||||
is_block_tables_empty)
|
is_block_tables_empty)
|
||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from vllm_ascend.worker.model_runner import (
|
||||||
from vllm_ascend.worker.model_runner import (
|
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
||||||
ModelInputForNPUBuilder, ModelInputForNPUWithSamplingMetadata)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
|
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
|
||||||
@@ -489,7 +488,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
|||||||
128, self.input_builder.runner.model_config.dtype)
|
128, self.input_builder.runner.model_config.dtype)
|
||||||
|
|
||||||
def _add_seq_group(
|
def _add_seq_group(
|
||||||
self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup",
|
self, inter_data: ModelInputForNPUBuilder.InterDataForSeqGroup,
|
||||||
chunked_prefill_enabled: bool):
|
chunked_prefill_enabled: bool):
|
||||||
"""Add a sequence group to the metadata. Specifically update/append
|
"""Add a sequence group to the metadata. Specifically update/append
|
||||||
1. context length.
|
1. context length.
|
||||||
|
|||||||
@@ -14,14 +14,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import vllm
|
from vllm_ascend.utils import vllm_version_is
|
||||||
from packaging.version import Version
|
|
||||||
|
|
||||||
# Import common patches for all versions
|
|
||||||
from vllm_ascend.patch.platform import patch_common # noqa: F401
|
|
||||||
|
|
||||||
# Import specific patches for different versions
|
# Import specific patches for different versions
|
||||||
if Version(vllm.__version__) == Version("0.8.4"):
|
if vllm_version_is("0.8.4"):
|
||||||
from vllm_ascend.patch.platform import patch_0_8_4 # noqa: F401
|
from vllm_ascend.patch.platform import patch_0_8_4 # noqa: F401
|
||||||
|
from vllm_ascend.patch.platform import patch_common # noqa: F401
|
||||||
else:
|
else:
|
||||||
|
from vllm_ascend.patch.platform import patch_common # noqa: F401
|
||||||
from vllm_ascend.patch.platform import patch_main # noqa: F401
|
from vllm_ascend.patch.platform import patch_main # noqa: F401
|
||||||
|
|||||||
@@ -14,14 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
import vllm
|
|
||||||
from packaging.version import Version
|
|
||||||
|
|
||||||
# Import common patches for all versions
|
from vllm_ascend.utils import vllm_version_is
|
||||||
from vllm_ascend.patch.worker import patch_common # noqa: F401
|
|
||||||
|
|
||||||
# Import specific patches for different versions
|
# Import specific patches for different versions
|
||||||
if Version(vllm.__version__) == Version("0.8.4"):
|
if vllm_version_is("0.8.4"):
|
||||||
from vllm_ascend.patch.worker import patch_0_8_4 # noqa: F401
|
from vllm_ascend.patch.worker import patch_0_8_4 # noqa: F401
|
||||||
|
from vllm_ascend.patch.worker import patch_common # noqa: F401
|
||||||
else:
|
else:
|
||||||
|
from vllm_ascend.patch.worker import patch_common # noqa: F401
|
||||||
from vllm_ascend.patch.worker import patch_main # noqa: F401
|
from vllm_ascend.patch.worker import patch_main # noqa: F401
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
#
|
#
|
||||||
import torch
|
import torch
|
||||||
import torch_npu # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
|
from packaging.version import Version
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
import vllm_ascend.envs as envs
|
import vllm_ascend.envs as envs
|
||||||
@@ -83,3 +84,8 @@ def adapt_patch(is_global_patch: bool = False):
|
|||||||
from vllm_ascend.patch import platform # noqa: F401
|
from vllm_ascend.patch import platform # noqa: F401
|
||||||
else:
|
else:
|
||||||
from vllm_ascend.patch import worker # noqa: F401
|
from vllm_ascend.patch import worker # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
def vllm_version_is(version: str):
|
||||||
|
import vllm
|
||||||
|
return Version(vllm.__version__) == Version(version)
|
||||||
|
|||||||
@@ -25,14 +25,13 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
|
|||||||
Type, TypeVar, Union)
|
Type, TypeVar, Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
from vllm.distributed import get_kv_transfer_group, get_pp_group
|
from vllm.distributed import get_pp_group
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -62,6 +61,13 @@ from vllm.worker.model_runner_base import (
|
|||||||
_init_attn_metadata_from_tensor_dict,
|
_init_attn_metadata_from_tensor_dict,
|
||||||
_init_sampling_metadata_from_tensor_dict)
|
_init_sampling_metadata_from_tensor_dict)
|
||||||
|
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
|
if vllm_version_is("0.8.4"):
|
||||||
|
from vllm.distributed import get_kv_transfer_group
|
||||||
|
else:
|
||||||
|
from vllm.distributed.kv_transfer import get_kv_transfer_group
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ import torch.distributed
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
ensure_model_parallel_initialized,
|
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce)
|
set_custom_all_reduce)
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -45,10 +44,15 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
|||||||
WorkerInput)
|
WorkerInput)
|
||||||
|
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import try_register_lib
|
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||||
from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
|
from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
|
||||||
|
|
||||||
|
if vllm_version_is("0.8.4"):
|
||||||
|
from vllm.distributed import ensure_kv_transfer_initialized
|
||||||
|
else:
|
||||||
|
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||||
|
|
||||||
|
|
||||||
class NPUWorker(LocalOrDistributedWorkerBase):
|
class NPUWorker(LocalOrDistributedWorkerBase):
|
||||||
"""A worker class that executes (a partition of) the model on a NPU.
|
"""A worker class that executes (a partition of) the model on a NPU.
|
||||||
|
|||||||
@@ -25,8 +25,7 @@ import torch.nn as nn
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
ensure_model_parallel_initialized,
|
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce)
|
set_custom_all_reduce)
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -40,9 +39,14 @@ from vllm.v1.utils import bind_kv_cache
|
|||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.utils import try_register_lib
|
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
|
if vllm_version_is("0.8.4"):
|
||||||
|
from vllm.distributed import ensure_kv_transfer_initialized
|
||||||
|
else:
|
||||||
|
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||||
|
|
||||||
|
|
||||||
class NPUWorker(WorkerBase):
|
class NPUWorker(WorkerBase):
|
||||||
|
|
||||||
@@ -241,4 +245,4 @@ class NPUWorker(WorkerBase):
|
|||||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
||||||
torch_profiler_trace_dir))
|
torch_profiler_trace_dir))
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user