[CI] Fix broken CI (#1889)

1. vLLM commit 45badd05d0 changed the pooling check logic, which broke vLLM Ascend.
2. vLLM commit 3e04107d97 requires a higher version of transformers. The transformers version bug has been fixed by e936e401de, so it is now safe to remove the version limit.
3. vLLM commit 217937221b added a new input `enable_eplb` (along with related EPLB arguments) to the FusedMoE ops.

This PR fixes the broken CI. The pooling-related changes are gated on the installed vLLM version via `vllm_version_is("0.9.2")`, so both vLLM v0.9.2 and current main keep working.
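For reference, a minimal sketch of the version-gating pattern the pooling fix relies on (the imports are copied from the diff below; `vllm_version_is` comes from `vllm_ascend.utils` and compares against the installed vLLM release):

```python
from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.9.2"):
    # Released vLLM v0.9.2 still exposes the old step-pooler interface.
    from vllm.model_executor.models.interfaces import has_step_pooler
    from vllm.v1.utils import bind_kv_cache
else:
    # Current vLLM main uses the reworked PoolingTask-based pooling API.
    from vllm.pooling_params import PoolingTask
    from vllm.v1.worker.utils import bind_kv_cache
```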


- vLLM version: v0.9.2
- vLLM main: 6a971ed692

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-07-20 02:11:57 +08:00 (committed by GitHub)
Parent: 2ee90461d0
Commit: 2b726d8f90
6 changed files with 127 additions and 53 deletions


@@ -19,7 +19,5 @@ requires = [
     "msgpack",
     "quart",
     "numba",
-    # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
-    "transformers==4.52.4",
 ]
 build-backend = "setuptools.build_meta"


@@ -25,6 +25,3 @@ numba
 --pre
 --extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
 torch-npu==2.5.1.post1.dev20250619
-# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
-transformers==4.52.4


@@ -40,23 +40,26 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
 def forward_oot(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
-        router_logits: torch.Tensor,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        global_num_experts: Optional[int] = None,
-        expert_map: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-) -> torch.Tensor:
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        global_num_experts: Optional[int] = None,
+        expert_map: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
     if SELECT_GATING_TOPK_SOTFMAX_EXPERTS:
         topk_weights, topk_ids = select_gating_top_k_softmax_experts(


@@ -24,7 +24,7 @@ import types
 import weakref
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast, get_args
 
 import numpy as np
 import numpy.typing as npt
@@ -45,7 +45,8 @@ from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models.interfaces import has_step_pooler
+from vllm.model_executor.models.interfaces_base import (VllmModelForPooling,
+                                                        is_pooling_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -88,8 +89,10 @@ from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
 if vllm_version_is("0.9.2"):
+    from vllm.model_executor.models.interfaces import has_step_pooler
     from vllm.v1.utils import bind_kv_cache
 else:
+    from vllm.pooling_params import PoolingTask
     from vllm.v1.worker.utils import bind_kv_cache
 
 if TYPE_CHECKING:
@@ -395,6 +398,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
+            pooling_params = new_req_data.pooling_params
             if sampling_params and \
                     sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                 generator = torch.Generator(device=self.device)
@@ -402,6 +406,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             else:
                 generator = None
 
+            if not vllm_version_is("0.9.2") and pooling_params:
+                assert pooling_params.task is not None, (
+                    "You did not set `task` in the API")
+                model = cast(VllmModelForPooling, self.model)
+                to_update = (model.pooler.get_pooling_updates(
+                    pooling_params.task))
+                assert to_update is not None, (
+                    f"{pooling_params.task=} is not supported by the model")
+                to_update.apply(pooling_params)
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
@@ -1729,26 +1743,59 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         req_num_tokens = num_tokens // num_reqs
 
-        dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[PoolingParams()] * num_reqs)
+        if vllm_version_is("0.9.2"):
+            dummy_metadata = PoolingMetadata(
+                prompt_lens=torch.tensor(
+                    [h.shape[0] for h in hidden_states_list],
+                    device=self.device),
+                prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
+                                             dtype=torch.int32,
+                                             device=self.device),
+                pooling_params=[PoolingParams()] * num_reqs)
+
+            try:
+                pooler_output = self.model.pooler(
+                    hidden_states=hidden_states_list,
+                    pooling_metadata=dummy_metadata)
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    raise RuntimeError(
+                        "NPU out of memory occurred when warming up pooler with "
+                        f"{num_reqs} dummy requests. Please try lowering "
+                        "`max_num_seqs` or `gpu_memory_utilization` when "
+                        "initializing the engine.") from e
+                else:
+                    raise e
+        else:
+            model = cast(VllmModelForPooling, self.model)
+            dummy_task = self.get_supported_pooling_tasks()[0]
+            dummy_pooling_params = PoolingParams(task=dummy_task)
+
+            to_update = model.pooler.get_pooling_updates(dummy_task)
+            assert to_update is not None
+            to_update.apply(dummy_pooling_params)
+
+            dummy_metadata = PoolingMetadata(
+                prompt_lens=torch.tensor(
+                    [h.shape[0] for h in hidden_states_list],
+                    device=self.device),
+                prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
+                                             dtype=torch.int32,
+                                             device=self.device),
+                pooling_params=[dummy_pooling_params] * num_reqs)
+
+            try:
+                pooler_output = model.pooler(hidden_states=hidden_states_list,
+                                             pooling_metadata=dummy_metadata)
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    raise RuntimeError(
+                        "NPU out of memory occurred when warming up pooler with "
+                        f"{num_reqs} dummy requests. Please try lowering "
+                        "`max_num_seqs` or `gpu_memory_utilization` when "
+                        "initializing the engine.") from e
+                else:
+                    raise e
-        try:
-            pooler_output = self.model.pooler(hidden_states=hidden_states_list,
-                                              pooling_metadata=dummy_metadata)
-        except RuntimeError as e:
-            if 'out of memory' in str(e):
-                raise RuntimeError(
-                    "NPU out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
-                    "initializing the engine.") from e
-            else:
-                raise e
 
         return pooler_output
 
     def load_model(self) -> None:
@@ -1767,8 +1814,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                               QKVParallelLinear, RowParallelLinear)):
                     module.weight.data = torch_npu.npu_format_cast(
                         module.weight.data, ACL_FORMAT_FRACTAL_NZ)
-        if has_step_pooler(self.model):
-            self.input_batch.logits_processing_needs_token_ids = True
+        if vllm_version_is("0.9.2") and has_step_pooler(self.model):
+            self.input_batch.logits_processing_needs_token_ids_bool = True
+
         if self.drafter:
             logger.info("Loading drafter model...")
             if isinstance(self.drafter, EagleProposer):
@@ -2379,3 +2427,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if batch_size <= padded_batch_size < selected_batch_size:
                 selected_batch_size = padded_batch_size
         return selected_batch_size
+
+    def get_supported_pooling_tasks(self):
+        model = self.get_model()
+        if not is_pooling_model(model):
+            return []
+
+        return [
+            task for task in get_args(PoolingTask)
+            if model.pooler.get_pooling_updates(task)
+        ]
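As a usage note, the pooler warm-up path above picks its dummy task from this new helper; a minimal sketch of that call (`runner` stands for a hypothetical NPUModelRunner instance with a pooling model loaded):

```python
from vllm.pooling_params import PoolingParams

# Take any pooling task the loaded model supports and build dummy params for
# the warm-up run, as done in the non-0.9.2 branch of the diff above.
dummy_task = runner.get_supported_pooling_tasks()[0]
dummy_pooling_params = PoolingParams(task=dummy_task)
```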


@@ -35,6 +35,8 @@ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import MultiGroupBlockTable
 
+from vllm_ascend.utils import vllm_version_is
+
 _SAMPLING_EPS = 1e-5
@@ -83,7 +85,6 @@ class InputBatch:
         pin_memory: bool,
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
-        logits_processing_needs_token_ids: bool = False,
         is_spec_decode: bool = False,
     ):
         self.is_spec_decode = is_spec_decode
@@ -93,8 +94,6 @@ class InputBatch:
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
-        self.logits_processing_needs_token_ids = (
-            logits_processing_needs_token_ids)
 
         self._req_ids: list[Optional[str]] = []
         self.req_id_to_index: dict[str, int] = {}
@@ -247,6 +246,11 @@ class InputBatch:
         # req_index -> bad_words_token_ids
         self.bad_words_token_ids: dict[int, list[list[int]]] = {}
 
+        if vllm_version_is("0.9.2"):
+            self.logits_processing_needs_token_ids_bool = False
+        else:
+            self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
+                                                              dtype=bool)
 
         self.req_output_token_ids: list[Optional[list[int]]] = []
@@ -383,9 +387,15 @@ class InputBatch:
             if sampling_params.bad_words_token_ids:
                 self.bad_words_token_ids[
                     req_index] = sampling_params.bad_words_token_ids
-        else:
+        elif vllm_version_is("0.9.2"):
             assert request.pooling_params is not None
             self.pooling_params[req_id] = request.pooling_params
+        elif pooling_params := request.pooling_params:
+            self.pooling_params[req_id] = pooling_params
+            self.logits_processing_needs_token_ids[req_index] = (
+                pooling_params.requires_token_ids)
+        else:
+            raise NotImplementedError(request)
 
         # Add request lora ID
         if request.lora_request:
@@ -614,10 +624,15 @@ class InputBatch:
                    self.presence_penalties, num_reqs)
         copy_slice(self.repetition_penalties_cpu_tensor,
                    self.repetition_penalties, num_reqs)
 
-        needs_prompt_token_ids = (not self.no_penalties or
-                                  (self.num_reqs > 0
-                                   and self.logits_processing_needs_token_ids))
+        if vllm_version_is("0.9.2"):
+            needs_prompt_token_ids = (
+                not self.no_penalties
+                or (self.num_reqs > 0
+                    and self.logits_processing_needs_token_ids_bool))
+        else:
+            needs_prompt_token_ids = (
+                not self.no_penalties
+                or self.logits_processing_needs_token_ids[:num_reqs].any())
         if needs_prompt_token_ids:
             # The prompt tokens are used only for applying penalties or
             # step pooling during the sampling/pooling process.


@@ -355,3 +355,6 @@ class NPUWorker(WorkerBase):
                     torch_profiler_trace_dir))
         else:
             return None
+
+    def get_supported_pooling_tasks(self):
+        return self.model_runner.get_supported_pooling_tasks()