diff --git a/tests/utils.py b/tests/utils.py
index d2439ee..b84b39a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -42,13 +42,19 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import FlexibleArgumentParser, GB_bytes, get_open_port
 
+from vllm_ascend.utils import vllm_version_is
+
 from .model_utils import TextTextLogprobs
 
+if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+    from vllm.model_executor.model_loader.loader import get_model_loader  # type: ignore[import]  # isort: skip
+else:
+    from vllm.model_executor.model_loader import get_model_loader
+
 VLLM_PATH = Path(__file__).parent.parent
 """Path to root of the vLLM repository."""
diff --git a/vllm_ascend/ops/attention.py b/vllm_ascend/ops/attention.py
index 4d38255..f21c03e 100644
--- a/vllm_ascend/ops/attention.py
+++ b/vllm_ascend/ops/attention.py
@@ -131,6 +131,7 @@ def vanilla_chunked_prefill(
 
     attn_output = (attn_output[q_mask].view([-1, num_query_heads,
                                              head_dim]).to(output.dtype))
+    output = output.view_as(attn_output)
     output.copy_(attn_output)
     return attn_output
 
diff --git a/vllm_ascend/worker/model_runner.py b/vllm_ascend/worker/model_runner.py
index 779ac17..e58f55a 100644
--- a/vllm_ascend/worker/model_runner.py
+++ b/vllm_ascend/worker/model_runner.py
@@ -64,6 +64,8 @@ from vllm.worker.model_runner_base import (
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
 
+from vllm_ascend.utils import vllm_version_is
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 
@@ -1007,7 +1009,10 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
         pattern: Optional[str] = None,
         max_size: Optional[int] = None,
     ) -> None:
-        from vllm.model_executor.model_loader.loader import ShardedStateLoader
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            from vllm.model_executor.model_loader.loader import ShardedStateLoader  # type: ignore[import]  # isort: skip  # noqa
+        else:
+            from vllm.model_executor.model_loader import ShardedStateLoader
         ShardedStateLoader.save_model(
             self.model,
             path,
@@ -1019,7 +1024,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
         self,
         tensorizer_config: TensorizerConfig,
     ) -> None:
-        from vllm.model_executor.model_loader.loader import TensorizerLoader
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            from vllm.model_executor.model_loader.loader import \
+                TensorizerLoader  # type: ignore  # noqa
+        else:
+            from vllm.model_executor.model_loader import \
+                TensorizerLoader  # type: ignore  # noqa
         TensorizerLoader.save_model(
             self.model,
             tensorizer_config=tensorizer_config,
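# ---------------------------------------------------------------------------
# Note on the version gate used throughout this patch: vLLM relocated
# get_model_loader, ShardedStateLoader, and TensorizerLoader out of
# vllm.model_executor.model_loader.loader after 0.8.5.post1, re-exporting
# them from vllm.model_executor.model_loader, so both import paths must be
# supported. A minimal standalone sketch of the gating pattern follows; the
# real helper lives in vllm_ascend.utils, and this importlib.metadata-based
# equivalent is an assumption for illustration only.
from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target: str) -> bool:
    # True if the installed vLLM release string matches `target` exactly.
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        return False


if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
    # Old layout: the symbol lives in the `loader` submodule.
    from vllm.model_executor.model_loader.loader import get_model_loader
else:
    # New layout: re-exported from the package root.
    from vllm.model_executor.model_loader import get_model_loader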
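# ---------------------------------------------------------------------------
# Note on output.view_as(attn_output) in vanilla_chunked_prefill:
# Tensor.copy_ requires the source to be broadcastable to the destination's
# shape, so copying a [tokens, heads, head_dim] result into a flattened
# output buffer raises a RuntimeError even when the element counts match.
# Viewing the buffer as the source's shape shares the same storage, so the
# copy still lands in the caller's tensor. Toy repro below with made-up
# dimensions, for illustration only.
import torch

num_tokens, num_query_heads, head_dim = 4, 2, 8

flat = torch.empty(num_tokens, num_query_heads * head_dim)   # caller's buffer
attn_output = torch.randn(num_tokens, num_query_heads, head_dim)

# flat.copy_(attn_output) here would fail: [4, 16] vs [4, 2, 8].
out_view = flat.view_as(attn_output)  # same storage, matching shape
out_view.copy_(attn_output)           # writes land in `flat` as well
assert torch.equal(flat.view_as(attn_output), attn_output)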