diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 42007d8d..4ff1ee25 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -77,6 +77,7 @@ jobs: # pytest -sv tests/e2e/singlecard/test_aclgraph.py # pytest -sv tests/e2e/singlecard/test_quantization.py pytest -sv tests/e2e/singlecard/test_vlm.py::test_multimodal_vl + pytest -sv tests/e2e/singlecard/pooling/test_classification.py::test_classify_correctness - name: Run e2e test env: @@ -91,9 +92,7 @@ jobs: pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py pytest -sv tests/e2e/singlecard/test_aclgraph.py pytest -sv tests/e2e/singlecard/test_aclgraph_mem.py - pytest -sv tests/e2e/singlecard/test_bge_model.py pytest -sv tests/e2e/singlecard/test_camem.py - pytest -sv tests/e2e/singlecard/test_embedding.py # pytest -sv tests/e2e/singlecard/test_embedding_aclgraph.py pytest -sv tests/e2e/singlecard/test_guided_decoding.py # torch 2.8 doesn't work with lora, fix me @@ -104,6 +103,7 @@ jobs: pytest -sv tests/e2e/singlecard/test_vlm.py pytest -sv tests/e2e/singlecard/multi-modal/test_internvl.py pytest -sv tests/e2e/singlecard/test_xlite.py + pytest -sv tests/e2e/singlecard/pooling/ # ------------------------------------ v1 spec decode test ------------------------------------ # pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py diff --git a/docs/source/user_guide/support_matrix/supported_models.md b/docs/source/user_guide/support_matrix/supported_models.md index 060be366..41479335 100644 --- a/docs/source/user_guide/support_matrix/supported_models.md +++ b/docs/source/user_guide/support_matrix/supported_models.md @@ -48,7 +48,8 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160 |-------------------------------|-----------|----------------------------------------------------------------------|------|--------------------|------|-----------------|------------------------|------|----------------------|------------------|-----------------|-------------------|-----------------|---------------|-------------------------------|--------------------|--------------------|---------------|---------------------|-----| | Qwen3-Embedding | ✅ | ||||||||||||||||||| | Molmo | ✅ | [1942](https://github.com/vllm-project/vllm-ascend/issues/1942) ||||||||||||||||||| -| XLM-RoBERTa-based | ❌ | [1960](https://github.com/vllm-project/vllm-ascend/issues/1960) ||||||||||||||||||| +| XLM-RoBERTa-based | ✅ | ||||||||||||||||||| +| Bert | ✅ | ||||||||||||||||||| ## Multimodal Language Models diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 9d0e709a..a8207e3f 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -26,7 +26,7 @@ import shlex import subprocess import sys import time -from typing import Any, List, Optional, Tuple, TypeVar, Union +from typing import Any, Optional, Tuple, TypeVar, Union import httpx import numpy as np @@ -42,7 +42,8 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm import LLM, SamplingParams -from vllm.config.model import _get_and_verify_dtype +from vllm.config.model import (ConvertOption, RunnerOption, + _get_and_verify_dtype) from vllm.inputs import TextPrompt from vllm.outputs import RequestOutput from vllm.platforms import current_platform @@ -67,7 +68,7 @@ from vllm.distributed.parallel_state import ( # noqa E402 _T = TypeVar("_T", nn.Module, 
torch.Tensor, BatchEncoding, BatchFeature, dict) _M = TypeVar("_M") -_PromptMultiModalInput = Union[List[_M], List[List[_M]]] +_PromptMultiModalInput = Union[list[_M], list[list[_M]]] PromptImageInput = _PromptMultiModalInput[Image.Image] PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]] @@ -320,12 +321,11 @@ class VllmRunner: def __init__( self, model_name: str, - runner: str = "auto", + runner: RunnerOption = "auto", + convert: ConvertOption = "auto", tokenizer_name: Optional[str] = None, tokenizer_mode: str = "auto", - # Use smaller max model length, otherwise bigger model cannot run due - # to kv cache size limit. - max_model_len: int = 1024, + max_model_len: Optional[int] = 1024, dtype: str = "auto", disable_log_stats: bool = True, tensor_parallel_size: int = 1, @@ -339,6 +339,7 @@ class VllmRunner: self.model = LLM( model=model_name, runner=runner, + convert=convert, tokenizer=tokenizer_name, tokenizer_mode=tokenizer_mode, trust_remote_code=True, @@ -356,73 +357,79 @@ class VllmRunner: def get_inputs( self, - prompts: List[str], + prompts: Union[list[str], list[torch.Tensor], list[int]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[TextPrompt]: - if images is not None: - assert len(prompts) == len(images) + ) -> list[TextPrompt]: - if videos is not None: - assert len(prompts) == len(videos) + if any(x is not None and len(x) != len(prompts) + for x in [images, videos, audios]): + raise ValueError( + "All non-None multimodal inputs must have the same length as " + "prompts") - if audios is not None: - assert len(prompts) == len(audios) + inputs = [] + for i, prompt in enumerate(prompts): + multi_modal_data = {} + if images is not None and (image := images[i]) is not None: + multi_modal_data["image"] = image + if videos is not None and (video := videos[i]) is not None: + multi_modal_data["video"] = video + if audios is not None and (audio := audios[i]) is not None: + multi_modal_data["audio"] = audio - inputs = [TextPrompt(prompt=prompt) for prompt in prompts] - if images is not None: - for i, image in enumerate(images): - if image is not None: - inputs[i]["multi_modal_data"] = {"image": image} + text_prompt_kwargs: dict[str, Any] = { + "multi_modal_data": multi_modal_data or None + } + if isinstance(prompt, str): + text_prompt_kwargs["prompt"] = prompt + elif isinstance(prompt, list): + text_prompt_kwargs["prompt_token_ids"] = prompt + else: + text_prompt_kwargs["prompt_embeds"] = prompt - if videos is not None: - for i, video in enumerate(videos): - if video is not None: - inputs[i]["multi_modal_data"] = {"video": video} - - if audios is not None: - for i, audio in enumerate(audios): - if audio is not None: - inputs[i]["multi_modal_data"] = {"audio": audio} + inputs.append(TextPrompt(**text_prompt_kwargs)) return inputs def generate( self, - prompts: List[str], + prompts: Union[list[str], list[torch.Tensor]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[Tuple[List[List[int]], List[str]]]: + **kwargs: Any, + ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.model.generate(inputs, - sampling_params=sampling_params) + sampling_params=sampling_params, + **kwargs) - outputs: List[Tuple[List[List[int]], List[str]]] = [] + outputs: list[tuple[list[list[int]], 
list[str]]] = [] for req_output in req_outputs: prompt_str = req_output.prompt prompt_ids = req_output.prompt_token_ids - req_sample_output_ids: List[List[int]] = [] - req_sample_output_strs: List[str] = [] + req_sample_output_ids: list[list[int]] = [] + req_sample_output_strs: list[str] = [] for sample in req_output.outputs: output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @staticmethod def _final_steps_generate_w_logprobs( - req_outputs: List[RequestOutput], - ) -> List[TokensTextLogprobsPromptLogprobs]: - outputs: List[TokensTextLogprobsPromptLogprobs] = [] + req_outputs: list[RequestOutput], + ) -> list[TokensTextLogprobsPromptLogprobs]: + outputs: list[TokensTextLogprobsPromptLogprobs] = [] for req_output in req_outputs: assert len(req_output.outputs) > 0 for sample in req_output.outputs: @@ -435,20 +442,22 @@ class VllmRunner: def generate_w_logprobs( self, - prompts: List[str], + prompts: list[str], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - ) -> Union[List[TokensTextLogprobs], - List[TokensTextLogprobsPromptLogprobs]]: + **kwargs: Any, + ) -> Union[list[TokensTextLogprobs], + list[TokensTextLogprobsPromptLogprobs]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) req_outputs = self.model.generate(inputs, - sampling_params=sampling_params) + sampling_params=sampling_params, + **kwargs) toks_str_logsprobs_prompt_logprobs = ( self._final_steps_generate_w_logprobs(req_outputs)) @@ -459,34 +468,37 @@ class VllmRunner: def generate_greedy( self, - prompts: List[str], + prompts: Union[list[str], list[torch.Tensor]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, - ) -> List[Tuple[List[int], str]]: + **kwargs: Any, + ) -> list[tuple[list[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images, videos=videos, - audios=audios) + audios=audios, + **kwargs) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] def generate_greedy_logprobs( self, - prompts: List[str], + prompts: list[str], max_tokens: int, - num_logprobs: int, + num_logprobs: Optional[int], num_prompt_logprobs: Optional[int] = None, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, - stop_token_ids: Optional[List[int]] = None, - stop: Optional[List[str]] = None, - ) -> Union[List[TokensTextLogprobs], - List[TokensTextLogprobsPromptLogprobs]]: + stop_token_ids: Optional[list[int]] = None, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> Union[list[TokensTextLogprobs], + list[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, max_tokens=max_tokens, @@ -499,23 +511,46 @@ class VllmRunner: greedy_logprobs_params, images=images, audios=audios, - videos=videos) + videos=videos, + **kwargs) - def encode( - self, - prompts: List[str], - images: Optional[PromptImageInput] = None, - videos: Optional[PromptVideoInput] = None, - audios: Optional[PromptAudioInput] = None, - ) -> 
List[List[float]]: + def classify(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.model.classify(prompts) + return [req_output.outputs.probs for req_output in req_outputs] + + def embed(self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + *args, + **kwargs) -> list[list[float]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - req_outputs = self.model.embed(inputs) + req_outputs = self.model.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] + def encode(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.model.encode(prompts) + return [req_output.outputs.data for req_output in req_outputs] + + def reward(self, prompts: list[str]) -> list[list[float]]: + req_outputs = self.model.reward(prompts) + return [req_output.outputs.data for req_output in req_outputs] + + def score( + self, + text_1: Union[str, list[str]], + text_2: Union[str, list[str]], + *args, + **kwargs, + ) -> list[float]: + req_outputs = self.model.score(text_1, text_2, *args, **kwargs) + return [req_output.outputs.score for req_output in req_outputs] + def __enter__(self): return self @@ -635,10 +670,79 @@ class HfRunner: if skip_tokenizer_init: self.tokenizer = self.processor.tokenizer + def get_inputs( + self, + prompts: list[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> list[Union[BatchFeature, BatchEncoding]]: + if images is not None: + assert len(prompts) == len(images) + + if videos is not None: + assert len(prompts) == len(videos) + + if audios is not None: + assert len(prompts) == len(audios) + + all_inputs: list[Union[BatchFeature, BatchEncoding]] = [] + for i, prompt in enumerate(prompts): + processor_kwargs: dict[str, Any] = { + "text": prompt, + "return_tensors": "pt", + } + if images is not None and (image := images[i]) is not None: + processor_kwargs["images"] = image + if videos is not None and (video := videos[i]) is not None: + processor_kwargs["videos"] = video + if audios is not None and (audio_inputs := audios[i]) is not None: + # HACK - not all processors take sampling_rate; we should + # clean this up in the future. 
+ if len(audio_inputs) == 2: + audio, sr = audio_inputs + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + else: + processor_kwargs["audio"] = audio_inputs + + inputs = self.processor(**processor_kwargs) + if isinstance(inputs, BatchFeature): + inputs = inputs.to(dtype=self.dtype) + + all_inputs.append(inputs) + + return all_inputs + + def classify(self, prompts: list[str]) -> list[str]: + # output is final logits + all_inputs = self.get_inputs(prompts) + outputs = [] + problem_type = getattr(self.config, "problem_type", "") + + for inputs in all_inputs: + output = self.model(**self.wrap_device(inputs)) + if problem_type == "regression": + logits = output.logits[0].tolist() + elif problem_type == "multi_label_classification": + logits = output.logits.sigmoid()[0].tolist() + else: + logits = output.logits.softmax(dim=-1)[0].tolist() + outputs.append(logits) + + return outputs + def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]: return self.model.encode(prompts, *args, **kwargs) + def predict(self, prompts: list[list[str]], *args, + **kwargs) -> torch.Tensor: + return self.model.predict(prompts, + *args, + convert_to_tensor=True, + **kwargs) + def __enter__(self): return self @@ -652,7 +756,7 @@ def ilama_lora_files(): return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider") -def qwen_prompt(questions: List[str]) -> List[str]: +def qwen_prompt(questions: list[str]) -> list[str]: placeholder = "<|image_pad|>" return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" diff --git a/tests/e2e/singlecard/pooling/__init__.py b/tests/e2e/singlecard/pooling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/singlecard/pooling/test_classification.py b/tests/e2e/singlecard/pooling/test_classification.py new file mode 100644 index 00000000..e59983c1 --- /dev/null +++ b/tests/e2e/singlecard/pooling/test_classification.py @@ -0,0 +1,34 @@ +import torch +from modelscope import snapshot_download # type: ignore[import-untyped] +from transformers import AutoModelForSequenceClassification + +from tests.e2e.conftest import HfRunner, VllmRunner + + +def test_classify_correctness() -> None: + + model_name = snapshot_download("Howeee/Qwen2.5-1.5B-apeach") + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is what", + ] + with VllmRunner( + model_name, + runner="pooling", + max_model_len=None, + cudagraph_capture_sizes=[4], + ) as vllm_runner: + vllm_outputs = vllm_runner.classify(prompts) + + with HfRunner(model_name, + dtype="float32", + auto_cls=AutoModelForSequenceClassification) as hf_runner: + hf_outputs = hf_runner.classify(prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + assert torch.allclose(hf_output, vllm_output, 1e-2) diff --git a/tests/e2e/singlecard/test_embedding.py b/tests/e2e/singlecard/pooling/test_embedding.py similarity index 78% rename from tests/e2e/singlecard/test_embedding.py rename to tests/e2e/singlecard/pooling/test_embedding.py index 3ff8d341..7666dbcd 100644 --- a/tests/e2e/singlecard/test_embedding.py +++ b/tests/e2e/singlecard/pooling/test_embedding.py @@ -16,22 +16,32 @@ # This file is a part of the vllm-ascend project. 
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py # +import pytest from modelscope import snapshot_download # type: ignore[import-untyped] from tests.e2e.conftest import HfRunner, VllmRunner from tests.e2e.utils import check_embeddings_close +MODELS = [ + "Qwen/Qwen3-Embedding-0.6B", # lasttoken + "BAAI/bge-small-en-v1.5", # cls_token + "intfloat/multilingual-e5-small" # mean_tokens +] -def test_embed_models_correctness(): + +@pytest.mark.parametrize("model", MODELS) +def test_embed_models_correctness(model: str): queries = ['What is the capital of China?', 'Explain gravity'] - model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B") + model_name = snapshot_download(model) with VllmRunner( model_name, runner="pooling", enforce_eager=False, + max_model_len=None, + cudagraph_capture_sizes=[4], ) as vllm_runner: - vllm_outputs = vllm_runner.encode(queries) + vllm_outputs = vllm_runner.embed(queries) with HfRunner( model_name, diff --git a/tests/e2e/singlecard/pooling/test_scoring.py b/tests/e2e/singlecard/pooling/test_scoring.py new file mode 100644 index 00000000..c196a0bf --- /dev/null +++ b/tests/e2e/singlecard/pooling/test_scoring.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +import torch.nn.functional as F +from modelscope import snapshot_download # type: ignore[import-untyped] + +from tests.e2e.conftest import HfRunner, VllmRunner + +CROSS_ENCODER_MODELS = [ + "dengcao/ms-marco-MiniLM-L6-v2", # Bert + "BAAI/bge-reranker-v2-m3", # Roberta +] + +EMBEDDING_MODELS = [ + "sentence-transformers/all-MiniLM-L12-v2", +] + +TEXTS_1 = [ + "What is the capital of France?", + "What is the capital of Germany?", +] + +TEXTS_2 = [ + "The capital of France is Paris.", + "The capital of Germany is Berlin.", +] + +DTYPE = "half" + + +@pytest.fixture(scope="module", params=CROSS_ENCODER_MODELS) +def model_name(request): + yield snapshot_download(request.param) + + +def test_cross_encoder_1_to_1(model_name): + text_pair = [TEXTS_1[0], TEXTS_2[0]] + + with HfRunner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict([text_pair]).tolist() + + with VllmRunner(model_name, + runner="pooling", + dtype=DTYPE, + cudagraph_capture_sizes=[4], + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) + + assert len(vllm_outputs) == 1 + assert len(hf_outputs) == 1 + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + + +def test_cross_encoder_1_to_N(model_name): + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[0], TEXTS_2[1]], + ] + + with HfRunner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with VllmRunner(model_name, + runner="pooling", + dtype=DTYPE, + cudagraph_capture_sizes=[4], + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) + + +def test_cross_encoder_N_to_N(model_name): + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[1], TEXTS_2[1]], + ] + + with HfRunner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with VllmRunner(model_name, + runner="pooling", + dtype=DTYPE, + 
cudagraph_capture_sizes=[4], + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) + + +@pytest.fixture(scope="module", params=EMBEDDING_MODELS) +def emb_model_name(request): + yield snapshot_download(request.param) + + +def test_embedding_1_to_1(emb_model_name): + text_pair = [TEXTS_1[0], TEXTS_2[0]] + + with HfRunner(emb_model_name, dtype=DTYPE, + is_sentence_transformer=True) as hf_model: + hf_embeddings = hf_model.encode(text_pair) + hf_outputs = [ + F.cosine_similarity(*map(torch.tensor, hf_embeddings), dim=0) + ] + + with VllmRunner(emb_model_name, + runner="pooling", + dtype=DTYPE, + cudagraph_capture_sizes=[4], + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) + + assert len(vllm_outputs) == 1 + assert len(hf_outputs) == 1 + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + + +def test_embedding_1_to_N(emb_model_name): + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[0], TEXTS_2[1]], + ] + + with HfRunner(emb_model_name, dtype=DTYPE, + is_sentence_transformer=True) as hf_model: + hf_embeddings = [ + hf_model.encode(text_pair) for text_pair in text_pairs + ] + hf_outputs = [ + F.cosine_similarity(*map(torch.tensor, pair), dim=0) + for pair in hf_embeddings + ] + + with VllmRunner(emb_model_name, + runner="pooling", + dtype=DTYPE, + cudagraph_capture_sizes=[4], + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) + + +def test_embedding_N_to_N(emb_model_name): + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[1], TEXTS_2[1]], + ] + + with HfRunner(emb_model_name, dtype=DTYPE, + is_sentence_transformer=True) as hf_model: + hf_embeddings = [ + hf_model.encode(text_pair) for text_pair in text_pairs + ] + hf_outputs = [ + F.cosine_similarity(*map(torch.tensor, pair), dim=0) + for pair in hf_embeddings + ] + + with VllmRunner(emb_model_name, + runner="pooling", + dtype=DTYPE, + cudagraph_capture_sizes=[4], + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01) diff --git a/tests/e2e/singlecard/test_bge_model.py b/tests/e2e/singlecard/test_bge_model.py deleted file mode 100644 index 48d4bf08..00000000 --- a/tests/e2e/singlecard/test_bge_model.py +++ /dev/null @@ -1,49 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py -# -from modelscope import snapshot_download # type: ignore[import-untyped] - -from tests.e2e.conftest import HfRunner, VllmRunner -from tests.e2e.utils import check_embeddings_close - - -def test_bge_model_correctness(): - queries = ['What is the capital of China?', 'Explain gravity'] - - model_name = snapshot_download("BAAI/bge-m3") - with VllmRunner( - model_name, - runner="pooling", - enforce_eager=True, - ) as vllm_runner: - vllm_outputs = vllm_runner.encode(queries) - - with HfRunner( - model_name, - dtype="float32", - is_sentence_transformer=True, - ) as hf_runner: - hf_outputs = hf_runner.encode(queries) - - check_embeddings_close( - embeddings_0_lst=hf_outputs, - embeddings_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) diff --git a/tests/e2e/singlecard/test_embedding_aclgraph.py b/tests/e2e/singlecard/test_embedding_aclgraph.py deleted file mode 100644 index 4c164900..00000000 --- a/tests/e2e/singlecard/test_embedding_aclgraph.py +++ /dev/null @@ -1,55 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py -# -import os - -import pytest - -from tests.e2e.conftest import VllmRunner -from tests.e2e.utils import check_embeddings_close - -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - -MODELS = ["BAAI/bge-m3"] - - -@pytest.mark.parametrize("model_name", MODELS) -def test_aclgrpah_embed_models_correctness(model_name): - queries = ['What is the capital of China?', 'Explain gravity'] - - with VllmRunner( - model_name, - runner="pooling", - enforce_eager=False, - ) as vllm_aclgraph_runner: - vllm_aclgraph_outputs = vllm_aclgraph_runner.encode(queries) - - with VllmRunner( - model_name, - runner="pooling", - enforce_eager=True, - ) as vllm_runner: - vllm_outputs = vllm_runner.encode(queries) - - check_embeddings_close( - embeddings_0_lst=vllm_outputs, - embeddings_1_lst=vllm_aclgraph_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index c1322351..a291a480 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -35,7 +35,6 @@ class AttentionMaskBuilder: self.attn_mask_cache = None self._seq_len_cached = 0 self.device = device - self.pooling_mask = None self.mla_mask = None self.chunked_prefill_attn_mask = None self.pcp_mla_mask = None @@ -50,14 +49,6 @@ return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( ).to(self.device, non_blocking=True) - def get_pooling_mask(self): - if self.pooling_mask is None: - # the compressed attention mask for npu_fusion_attention sparse mode 4 - self.pooling_mask = torch.triu(torch.ones( - 2048, 2048), diagonal=1).to(torch.bool).to(self.device, - non_blocking=True) - return self.pooling_mask - def get_splitfuse_attn_mask(self) -> torch.Tensor: if self.chunked_prefill_attn_mask is None: self.chunked_prefill_attn_mask = torch.triu( diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 7ba77449..8ef50a43 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -221,6 +221,10 @@ class AscendMetadata: # dcp decode_meta: Optional[AscendMetadataForDecode] = None + # Whether this is a pooling model with causal attention, + # used to guide the attention computation for pooling models. + is_causal_pooling: Optional[bool] = None + class AscendAttentionMetadataBuilder: # Does this backend/builder support ACL Graphs for attention (default: no).
@@ -319,6 +323,10 @@ AscendAttentionMetadataBuilder: query_start_loc = query_start_loc_cpu.pin_memory().to( self.device, non_blocking=True) + is_causal_pooling = None + if self.model_config.runner_type == "pooling": + is_causal_pooling = common_attn_metadata.causal if hasattr( + common_attn_metadata, 'causal') else True attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, @@ -336,7 +344,8 @@ attn_mask=attn_mask, attn_state=attn_state, num_prefills=num_prefills, - num_decodes=num_decodes) + num_decodes=num_decodes, + is_causal_pooling=is_causal_pooling) return attn_metadata def build_for_graph_capture( @@ -597,30 +606,39 @@ class AscendAttentionBackendImpl(AttentionImpl): out=output) return output - def _forward_encode( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: AscendMetadata, - output: torch.Tensor, - ) -> torch.Tensor: - cum_seq_len = attn_metadata.query_start_loc[1:].tolist() - output = torch_npu.npu_fusion_attention( - query, - key, - value, - head_num=self.num_heads, - input_layout="TND", - scale=self.scale, - sparse_mode=4, - atten_mask=attn_metadata.attn_mask, - pre_tockens=attn_metadata.max_query_len, - next_tockens=attn_metadata.max_query_len, - actual_seq_qlen=cum_seq_len, - actual_seq_kvlen=cum_seq_len, - )[0] - return output + def _forward_encoder_attention(self, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + attn_metadata: AscendMetadata, + _: torch.Tensor) -> torch.Tensor: + assert attn_metadata is not None + assert attn_metadata.is_causal_pooling is not None + + if attn_metadata.is_causal_pooling: + # use sparse_mode 3 in the causal scenario + return torch_npu.npu_fusion_attention( + query=query, + key=key, + value=value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=3, + atten_mask=attn_metadata.attn_mask, + actual_seq_qlen=attn_metadata.actual_seq_lengths_q, + actual_seq_kvlen=attn_metadata.actual_seq_lengths_q, + )[0] + else: + # use the default sparse_mode 0 in the non-causal scenario, which means no attention mask is applied + return torch_npu.npu_fusion_attention( + query=query, + key=key, + value=value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + actual_seq_qlen=attn_metadata.actual_seq_lengths_q, + actual_seq_kvlen=attn_metadata.actual_seq_lengths_q, + )[0] def reshape_and_cache( self, @@ -697,18 +715,22 @@ class AscendAttentionBackendImpl(AttentionImpl): " for AscendAttentionBackendImpl") assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY: - raise NotImplementedError("Encoder/decoder cross-attention " "are not implemented for " + attn_type = self.attn_type + if attn_type not in [ + AttentionType.DECODER, AttentionType.ENCODER_ONLY + ]: + raise NotImplementedError("Encoder/Decoder cross-attention " + "is not implemented for " "PallasAttentionBackendImpl") num_tokens = query.shape[0] if attn_metadata is None: return output.fill_(0) key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata) - if self.attn_type == AttentionType.ENCODER_ONLY: - attn_output = self._forward_encode(query, key, value, - attn_metadata, output) + # pooling model branch + if isinstance(attn_metadata.is_causal_pooling, bool): + attn_output = self._forward_encoder_attention( + query, key, value, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] return output output = 
self.forward_impl(query, key, value, kv_cache, attn_metadata, diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 1b346de6..68aac1e7 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -106,16 +106,7 @@ # # ** File: worker/patch_roberta.py ** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# 1. `vllm.model_executor.models.roberta.RobertaEmbedding.forward` -# Why: -# shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode -# How: -# Replace shift operation with multiplication and division. -# Related PR (if no, explain why): -# No, this need CANN add an aclnn shift operation -# Future Plan: -# Revert this when CANN support shift aclnn operation -# 2. `vllm.model_executor.models.roberta.RobertaForSequenceClassification.forward ` +# 1. `vllm.model_executor.models.bert ` # Why: # shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode # How: diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index a7f9d93c..2419d197 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -22,9 +22,9 @@ if HAS_TRITON: # isort: off import vllm_ascend.patch.platform.patch_sched_yield # noqa +import vllm_ascend.patch.worker.patch_bert # noqa import vllm_ascend.patch.worker.patch_distributed # noqa import vllm_ascend.patch.worker.patch_deepseek # noqa -import vllm_ascend.patch.worker.patch_roberta # noqa import vllm_ascend.patch.worker.patch_weight_loader # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_minicpm # noqa diff --git a/vllm_ascend/patch/worker/patch_bert.py b/vllm_ascend/patch/worker/patch_bert.py new file mode 100644 index 00000000..a48b499e --- /dev/null +++ b/vllm_ascend/patch/worker/patch_bert.py @@ -0,0 +1,45 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import torch +from vllm.model_executor.models import bert + +# aclgraph does not support shift operator for now +# TODO: revert me when aclgraph supports shift operator +TOKEN_TYPE_SHIFT = 30 +TOKEN_TYPE_MULTIPLIER = 1 << 30 +TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1 + + +def _encode_token_type_ids(input_ids: torch.Tensor, + token_type_ids: torch.Tensor) -> None: + # input_ids can be padded to the right + input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids * + TOKEN_TYPE_MULTIPLIER) + + +def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: + + token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER + + input_ids.bitwise_and_(TOKEN_MASK) + + return token_type_ids + + +bert._encode_token_type_ids = _encode_token_type_ids +bert._decode_token_type_ids = _decode_token_type_ids diff --git a/vllm_ascend/patch/worker/patch_roberta.py b/vllm_ascend/patch/worker/patch_roberta.py deleted file mode 100644 index a2e74615..00000000 --- a/vllm_ascend/patch/worker/patch_roberta.py +++ /dev/null @@ -1,91 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Optional, Union - -import torch -from vllm.model_executor.models.roberta import ( - RobertaEmbedding, RobertaForSequenceClassification, - replace_roberta_positions) -from vllm.sequence import IntermediateTensors - -# aclgraph does not support shift operator for now -# TODO: revert me when aclgraph supports shift operator -TOKEN_TYPE_SHIFT = 30 -TOKEN_TYPE_MULTIPLIER = 1 << 30 -TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1 - - -def _encode_token_type_ids(input_ids: torch.Tensor, - token_type_ids: torch.Tensor) -> None: - # input_ids can be padded to the right - input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids * - TOKEN_TYPE_MULTIPLIER) - - -def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: - - token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER - - input_ids.bitwise_and_(TOKEN_MASK) - - return token_type_ids - - -def roberta_for_sequence_classification_forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, -) -> torch.Tensor: - replace_roberta_positions(input_ids=input_ids, - position_ids=positions, - padding_idx=self.padding_idx) - if token_type_ids is not None: - assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) - assert input_ids is not None - _encode_token_type_ids(input_ids, token_type_ids) - return self.roberta(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) - - -def roberta_embedding_forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - inputs_embeds: Union[torch.Tensor, None] = None, -) -> torch.Tensor: - - token_type_ids = _decode_token_type_ids(input_ids) - - 
if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - position_embeddings = self.position_embeddings(position_ids) - - token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings + position_embeddings - embeddings = self.LayerNorm(embeddings) - return embeddings - - -RobertaEmbedding.forward = roberta_embedding_forward -RobertaForSequenceClassification.forward = roberta_for_sequence_classification_forward diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8d57417c..68c36094 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -377,6 +377,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.block_size, use_mla=self.model_config.use_mla, use_sparse=self.use_sparse) + self.attn_mask_builder = AttentionMaskBuilder(self.device) self._set_up_drafter() @@ -1029,8 +1030,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): if self.attn_mask_builder is None: raise ValueError("Attn mask builder is None") # Pooling situation. - if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": - return self.attn_mask_builder.get_pooling_mask() + if self.model_config.runner_type == "pooling": + return self.attn_mask_builder.get_attn_mask(2048, torch.bool) if self.vllm_config.model_config.use_mla: if self.pcp_size > 1: @@ -1933,8 +1934,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): common_prefix_len = 0 extra_attn_metadata_args = {} builder = attn_group.get_metadata_builder() - if isinstance(builder, GDNAttentionMetadataBuilder - ) or self.model_config.runner_type == "pooling": + if isinstance(builder, GDNAttentionMetadataBuilder): if use_spec_decode: extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens. 
@@ -1946,6 +1946,11 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, **extra_attn_metadata_args) + elif self.model_config.runner_type == "pooling": + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + **extra_attn_metadata_args) else: attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, @@ -1968,18 +1973,52 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): input_ids, inputs_embeds, intermediate_tensors, max_num_scheduled_tokens) + def _init_model_kwargs(self): + model_kwargs = dict[str, Any]() + num_reqs = self.input_batch.num_reqs + + num_pooling_reqs = len(self.input_batch.pooling_params) + + if num_pooling_reqs == 0: + return model_kwargs + + pooling_params = self.input_batch.get_pooling_params() + + assert num_pooling_reqs == num_reqs + + token_type_id_requests = dict[int, Any]() + for i, param in enumerate(pooling_params): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: + token_type_id_requests[i] = token_types + + if len(token_type_id_requests) == 0: + return model_kwargs + + seq_lens = self.seq_lens[:num_reqs] + token_type_ids = [] + + for i in range(num_reqs): + pos = token_type_id_requests.get(i, seq_lens[i]) + ids = (torch.arange(seq_lens[i]) >= pos).int() + token_type_ids.append(ids) + + model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( + device=self.device) + return model_kwargs + def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill, maybe_padded_num_tokens, input_ids, positions, intermediate_tensors, inputs_embeds): assert self.model is not None - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) + hidden_states = self.model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **self._init_model_kwargs()) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ @@ -2022,7 +2061,14 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): - if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + if self.model_config.runner_type == "pooling": + if isinstance( + self.kv_cache_config.kv_cache_groups[0].kv_cache_spec, + EncoderOnlyAttentionSpec): + attn_state = AscendAttentionState.PrefillNoCache + else: + attn_state = AscendAttentionState.PrefillCacheHit + elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. 
elif np.all(num_scheduled_tokens == 1): @@ -2251,7 +2297,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): " a batch must be pooling request" hidden_states = hidden_states[:num_scheduled_tokens] - pooling_metadata = self.input_batch.pooling_metadata + pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), device=hidden_states.device) seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs] @@ -4049,6 +4095,15 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): desc="Capturing ACL graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", aclgraph_runtime_mode.name)) + + force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL) + # When the kv cache spec is empty, PiecewiseBackend is not initialized, and + # compilation_case=1 will cause the dynamic shape position to be incorrectly derived. + if not self.get_kv_cache_spec(): + self._dummy_run(2, + aclgraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: for _ in range(self.compilation_config.cudagraph_num_of_warmups): @@ -4057,7 +4112,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL) self._dummy_run(num_tokens, aclgraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index ad4f525f..a6f8e3bd 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -793,17 +793,12 @@ class InputBatch: logitsprocs=self.logitsprocs, ) - @property - def pooling_metadata(self) -> PoolingMetadata: - if len(self.pooling_params) == 0: - pooling_params = [] - else: - # Note, for now this assumes that all request in the batch - # are either sampling or pooling requests - assert len(self.req_ids) == len(self.pooling_params) - pooling_params = [ - self.pooling_params[req_id] for req_id in self.req_ids - ] + def get_pooling_params(self) -> list[PoolingParams]: + assert len(self.req_ids) == len(self.pooling_params) + return [self.pooling_params[req_id] for req_id in self.req_ids] + + def get_pooling_metadata(self) -> PoolingMetadata: + pooling_params = self.get_pooling_params() return PoolingMetadata( prompt_lens=torch.from_numpy(