[Bugfix] Add get_supported_tasks interface to fix broken CI (#2023)
### What this PR does / why we need it?
Added a `get_supported_tasks` interface to adapt to this vLLM [change](46d81d6951 (diff-80ee7e2a62f9dcfbb8a312dc4e3948557e97ef187290daebbcae1e28596bda29)).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main: 5ac3168ee3

Signed-off-by: wangli <wangli858794774@gmail.com>
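For reviewers, a minimal runnable sketch of the delegation contract this PR implements: the worker exposes `get_supported_tasks` and simply forwards to its model runner. The `_FakeRunner`/`_FakeWorker` names and the simplified task types below are illustrative assumptions, not code from this PR or from vLLM:

```python
# Illustrative sketch only: shows the worker -> runner delegation pattern.
# Class names and the task unions are made up for this example; in vLLM,
# SupportedTask also folds in pooling tasks.
from typing import Literal

GenerationTask = Literal["generate", "transcription"]
SupportedTask = GenerationTask  # simplified for the sketch


class _FakeRunner:
    """Stands in for NPUModelRunner; returns tasks for a plain text LLM."""

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return ("generate",)


class _FakeWorker:
    """Stands in for NPUWorker, which simply delegates to its runner."""

    def __init__(self) -> None:
        self.model_runner = _FakeRunner()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.model_runner.get_supported_tasks()


assert _FakeWorker().get_supported_tasks() == ("generate",)
```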
```diff
@@ -92,6 +92,10 @@ if vllm_version_is("0.9.2"):
     from vllm.model_executor.models.interfaces import has_step_pooler
     from vllm.v1.utils import bind_kv_cache
 else:
+    from vllm.model_executor.models.interfaces import supports_transcription
+    from vllm.model_executor.models.interfaces_base import \
+        is_text_generation_model
+    from vllm.tasks import GenerationTask, SupportedTask
     from vllm.v1.worker.utils import bind_kv_cache
 
 if TYPE_CHECKING:
```
```diff
@@ -706,6 +710,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def get_model(self) -> nn.Module:
         return self.model
 
+    def get_supported_generation_tasks(self) -> "list[GenerationTask]":
+        model = self.get_model()
+        supported_tasks = list[GenerationTask]()
+
+        if is_text_generation_model(model):
+            supported_tasks.append("generate")
+
+        if supports_transcription(model):
+            if model.supports_transcription_only:
+                return ["transcription"]
+
+            supported_tasks.append("transcription")
+
+        return supported_tasks
+
+    def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
+        tasks = list[SupportedTask]()
+
+        if self.model_config.runner_type == "generate":
+            tasks.extend(self.get_supported_generation_tasks())
+        if self.model_config.runner_type == "pooling":
+            tasks.extend(self.get_supported_pooling_tasks())
+
+        return tuple(tasks)
+
     def _make_attention_mask(self, seq_lens, query_lens, position,
                              attn_state) -> torch.Tensor:
         # Chunk Prefill situation.
```
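The task-resolution logic in `get_supported_generation_tasks` can be exercised standalone. The sketch below substitutes stub model classes and predicates for vLLM's real `is_text_generation_model` / `supports_transcription` interface checks; the stub classes are assumptions for illustration only:

```python
# Standalone re-implementation of the task-resolution logic above,
# using stub predicates instead of vLLM's interface checks.
class TextModel:
    pass


class WhisperLike:
    # Mirrors the `supports_transcription_only` attribute checked above.
    supports_transcription_only = True


def is_text_generation_model(model) -> bool:
    return isinstance(model, TextModel)


def supports_transcription(model) -> bool:
    return isinstance(model, WhisperLike)


def get_supported_generation_tasks(model) -> list[str]:
    supported: list[str] = []
    if is_text_generation_model(model):
        supported.append("generate")
    if supports_transcription(model):
        # Transcription-only models short-circuit to a single task.
        if model.supports_transcription_only:
            return ["transcription"]
        supported.append("transcription")
    return supported


assert get_supported_generation_tasks(TextModel()) == ["generate"]
assert get_supported_generation_tasks(WhisperLike()) == ["transcription"]
```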
```diff
@@ -41,9 +41,13 @@ from vllm.v1.worker.worker_base import WorkerBase
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
+from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib,
+                               vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+if not vllm_version_is("0.9.2"):
+    from vllm.tasks import SupportedTask
+
 
 class NPUWorker(WorkerBase):
 
```
```diff
@@ -326,3 +330,6 @@ class NPUWorker(WorkerBase):
 
     def get_supported_pooling_tasks(self):
         return self.model_runner.get_supported_pooling_tasks()
+
+    def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
+        return self.model_runner.get_supported_tasks()
```
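Note that the return annotation `"tuple[SupportedTask, ...]"` is a string, so it is never evaluated at runtime; the method therefore stays importable on vLLM 0.9.2, where `SupportedTask` is not imported at all. A minimal sketch of that version-gating pattern follows; the `vllm_version_is` stand-in here is a simplified assumption, not vllm_ascend's actual helper:

```python
# Minimal sketch of a version-gated import: only newer vLLM ships
# vllm.tasks, so the symbol is imported (or stubbed) per installed version.
from importlib.metadata import version


def vllm_version_is(target: str) -> bool:
    # Simplified stand-in for vllm_ascend.utils.vllm_version_is.
    try:
        return version("vllm") == target
    except Exception:
        return False


if not vllm_version_is("0.9.2"):
    try:
        from vllm.tasks import SupportedTask  # noqa: F401
    except ImportError:
        SupportedTask = str  # fallback stub, for this sketch only
```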