[CI] Fix broken CI (#1915)
### What this PR does / why we need it?
Adapt to the upstream vLLM change [#21227](https://github.com/vllm-project/vllm/pull/21227) so that CI
passes again.
- vLLM version: v0.9.2
- vLLM main: 6b46c4b653
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -24,7 +24,7 @@ import types
|
|||||||
import weakref
|
import weakref
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast, get_args
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -93,7 +93,6 @@ if vllm_version_is("0.9.2"):
|
|||||||
from vllm.model_executor.models.interfaces import has_step_pooler
|
from vllm.model_executor.models.interfaces import has_step_pooler
|
||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
else:
|
else:
|
||||||
from vllm.pooling_params import PoolingTask
|
|
||||||
from vllm.v1.worker.utils import bind_kv_cache
|
from vllm.v1.worker.utils import bind_kv_cache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -408,13 +407,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
generator = None
|
generator = None
|
||||||
|
|
||||||
if not vllm_version_is("0.9.2") and pooling_params:
|
if not vllm_version_is("0.9.2") and pooling_params:
|
||||||
assert pooling_params.task is not None, (
|
assert (task := pooling_params.task) is not None, (
|
||||||
"You did not set `task` in the API")
|
"You did not set `task` in the API")
|
||||||
model = cast(VllmModelForPooling, self.model)
|
model = cast(VllmModelForPooling, self.model)
|
||||||
to_update = (model.pooler.get_pooling_updates(
|
to_update = model.pooler.get_pooling_updates(task)
|
||||||
pooling_params.task))
|
|
||||||
assert to_update is not None, (
|
|
||||||
f"{pooling_params.task=} is not supported by the model")
|
|
||||||
to_update.apply(pooling_params)
|
to_update.apply(pooling_params)
|
||||||
|
|
||||||
self.requests[req_id] = CachedRequestState(
|
self.requests[req_id] = CachedRequestState(
|
||||||
@@ -1772,7 +1768,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dummy_pooling_params = PoolingParams(task=dummy_task)
|
dummy_pooling_params = PoolingParams(task=dummy_task)
|
||||||
|
|
||||||
to_update = model.pooler.get_pooling_updates(dummy_task)
|
to_update = model.pooler.get_pooling_updates(dummy_task)
|
||||||
assert to_update is not None
|
|
||||||
to_update.apply(dummy_pooling_params)
|
to_update.apply(dummy_pooling_params)
|
||||||
|
|
||||||
dummy_metadata = PoolingMetadata(
|
dummy_metadata = PoolingMetadata(
|
||||||
@@ -2434,7 +2429,4 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if not is_pooling_model(model):
|
if not is_pooling_model(model):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [
|
return list(model.pooler.get_supported_tasks())
|
||||||
task for task in get_args(PoolingTask)
|
|
||||||
if model.pooler.get_pooling_updates(task)
|
|
||||||
]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user