[Misc] Fix import error and address nits to make CI happy (#563)
1. Add `vllm_version_is` function to check vllm version.
2. `ensure_kv_transfer_initialized` and `get_kv_transfer_group` have
been moved to another place in the vLLM main branch via commit
3408e47159;
this patch fixes the resulting import error.
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -25,14 +25,13 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
|
||||
Type, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_kv_transfer_group, get_pp_group
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import logger
|
||||
@@ -62,6 +61,13 @@ from vllm.worker.model_runner_base import (
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.8.4"):
|
||||
from vllm.distributed import get_kv_transfer_group
|
||||
else:
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
@@ -25,8 +25,7 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
||||
ensure_model_parallel_initialized,
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import logger
|
||||
@@ -45,10 +44,15 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib
|
||||
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||
from vllm_ascend.worker.model_runner import NPUModelRunner
|
||||
from vllm_ascend.worker.pooling_model_runner import NPUPoolingModelRunner
|
||||
|
||||
if vllm_version_is("0.8.4"):
|
||||
from vllm.distributed import ensure_kv_transfer_initialized
|
||||
else:
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
|
||||
|
||||
class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a NPU.
|
||||
|
||||
@@ -25,8 +25,7 @@ import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
||||
ensure_model_parallel_initialized,
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import logger
|
||||
@@ -40,9 +39,14 @@ from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import try_register_lib
|
||||
from vllm_ascend.utils import try_register_lib, vllm_version_is
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if vllm_version_is("0.8.4"):
|
||||
from vllm.distributed import ensure_kv_transfer_initialized
|
||||
else:
|
||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||
|
||||
|
||||
class NPUWorker(WorkerBase):
|
||||
|
||||
@@ -241,4 +245,4 @@ class NPUWorker(WorkerBase):
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir))
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user