diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index c934d22..4d24987 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -92,6 +92,10 @@ if vllm_version_is("0.9.2"):
     from vllm.model_executor.models.interfaces import has_step_pooler
     from vllm.v1.utils import bind_kv_cache
 else:
+    from vllm.model_executor.models.interfaces import supports_transcription
+    from vllm.model_executor.models.interfaces_base import \
+        is_text_generation_model
+    from vllm.tasks import GenerationTask, SupportedTask
     from vllm.v1.worker.utils import bind_kv_cache
 
 if TYPE_CHECKING:
@@ -706,6 +710,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model
 
+    def get_supported_generation_tasks(self) -> "list[GenerationTask]":
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
+    def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def _make_attention_mask(self, seq_lens, query_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 73f2d0b..a5e1a1c 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -41,9 +41,13 @@ from vllm.v1.worker.worker_base import WorkerBase
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
+from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib,
+                               vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+if not vllm_version_is("0.9.2"):
+    from vllm.tasks import SupportedTask
+
 
 class NPUWorker(WorkerBase):
 
@@ -326,3 +330,6 @@ class NPUWorker(WorkerBase):
 
     def get_supported_pooling_tasks(self):
         return self.model_runner.get_supported_pooling_tasks()
+
+    def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
+        return self.model_runner.get_supported_tasks()