[Model] Support pooling models (#3122)

### What this PR does / why we need it?

Support pooling models (like `bge-reranker-v2-m3`) in vllm-ascend. This
PR covers the three types of embedding models (cls_token, mean_token,
lasttoken).

After this
[commit](17373dcd93),
vllm has provided support for adapting pooling models on the v1 engine.
This PR includes corresponding adaptations on the vllm-ascend side.

Fixes #1960

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: lianyibo <lianyibo1@kunlunit.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
lianyibo
2025-12-10 11:37:57 +08:00
committed by GitHub
parent 1a7a34c5ec
commit e32014ac1d
17 changed files with 577 additions and 338 deletions

View File

@@ -77,6 +77,7 @@ jobs:
# pytest -sv tests/e2e/singlecard/test_aclgraph.py
# pytest -sv tests/e2e/singlecard/test_quantization.py
pytest -sv tests/e2e/singlecard/test_vlm.py::test_multimodal_vl
pytest -sv tests/e2e/singlecard/pooling/test_classification.py::test_classify_correctness
- name: Run e2e test
env:
@@ -91,9 +92,7 @@ jobs:
pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py
pytest -sv tests/e2e/singlecard/test_aclgraph.py
pytest -sv tests/e2e/singlecard/test_aclgraph_mem.py
pytest -sv tests/e2e/singlecard/test_bge_model.py
pytest -sv tests/e2e/singlecard/test_camem.py
pytest -sv tests/e2e/singlecard/test_embedding.py
# pytest -sv tests/e2e/singlecard/test_embedding_aclgraph.py
pytest -sv tests/e2e/singlecard/test_guided_decoding.py
# torch 2.8 doesn't work with lora, fix me
@@ -104,6 +103,7 @@ jobs:
pytest -sv tests/e2e/singlecard/test_vlm.py
pytest -sv tests/e2e/singlecard/multi-modal/test_internvl.py
pytest -sv tests/e2e/singlecard/test_xlite.py
pytest -sv tests/e2e/singlecard/pooling/
# ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

View File

@@ -48,7 +48,8 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
|-------------------------------|-----------|----------------------------------------------------------------------|------|--------------------|------|-----------------|------------------------|------|----------------------|------------------|-----------------|-------------------|-----------------|---------------|-------------------------------|--------------------|--------------------|---------------|---------------------|-----|
| Qwen3-Embedding | ✅ | |||||||||||||||||||
| Molmo | ✅ | [1942](https://github.com/vllm-project/vllm-ascend/issues/1942) |||||||||||||||||||
| XLM-RoBERTa-based | | [1960](https://github.com/vllm-project/vllm-ascend/issues/1960) |||||||||||||||||||
| XLM-RoBERTa-based | | |||||||||||||||||||
| Bert | ✅ | |||||||||||||||||||
## Multimodal Language Models

View File

@@ -26,7 +26,7 @@ import shlex
import subprocess
import sys
import time
from typing import Any, List, Optional, Tuple, TypeVar, Union
from typing import Any, Optional, Tuple, TypeVar, Union
import httpx
import numpy as np
@@ -42,7 +42,8 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config.model import _get_and_verify_dtype
from vllm.config.model import (ConvertOption, RunnerOption,
_get_and_verify_dtype)
from vllm.inputs import TextPrompt
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@@ -67,7 +68,7 @@ from vllm.distributed.parallel_state import ( # noqa E402
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
_PromptMultiModalInput = Union[list[_M], list[list[_M]]]
PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
@@ -320,12 +321,11 @@ class VllmRunner:
def __init__(
self,
model_name: str,
runner: str = "auto",
runner: RunnerOption = "auto",
convert: ConvertOption = "auto",
tokenizer_name: Optional[str] = None,
tokenizer_mode: str = "auto",
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
max_model_len: int = 1024,
max_model_len: Optional[int] = 1024,
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
@@ -339,6 +339,7 @@ class VllmRunner:
self.model = LLM(
model=model_name,
runner=runner,
convert=convert,
tokenizer=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=True,
@@ -356,73 +357,79 @@ class VllmRunner:
def get_inputs(
self,
prompts: List[str],
prompts: Union[list[str], list[torch.Tensor], list[int]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[TextPrompt]:
if images is not None:
assert len(prompts) == len(images)
) -> list[TextPrompt]:
if videos is not None:
assert len(prompts) == len(videos)
if any(x is not None and len(x) != len(prompts)
for x in [images, videos, audios]):
raise ValueError(
"All non-None multimodal inputs must have the same length as "
"prompts")
if audios is not None:
assert len(prompts) == len(audios)
inputs = []
for i, prompt in enumerate(prompts):
multi_modal_data = {}
if images is not None and (image := images[i]) is not None:
multi_modal_data["image"] = image
if videos is not None and (video := videos[i]) is not None:
multi_modal_data["video"] = video
if audios is not None and (audio := audios[i]) is not None:
multi_modal_data["audio"] = audio
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
if image is not None:
inputs[i]["multi_modal_data"] = {"image": image}
text_prompt_kwargs: dict[str, Any] = {
"multi_modal_data": multi_modal_data or None
}
if isinstance(prompt, str):
text_prompt_kwargs["prompt"] = prompt
elif isinstance(prompt, list):
text_prompt_kwargs["prompt_token_ids"] = prompt
else:
text_prompt_kwargs["prompt_embeds"] = prompt
if videos is not None:
for i, video in enumerate(videos):
if video is not None:
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None:
for i, audio in enumerate(audios):
if audio is not None:
inputs[i]["multi_modal_data"] = {"audio": audio}
inputs.append(TextPrompt(**text_prompt_kwargs))
return inputs
def generate(
self,
prompts: List[str],
prompts: Union[list[str], list[torch.Tensor]],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
**kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)
outputs: List[Tuple[List[List[int]], List[str]]] = []
outputs: list[tuple[list[list[int]], list[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: List[List[int]] = []
req_sample_output_strs: List[str] = []
req_sample_output_ids: list[list[int]] = []
req_sample_output_strs: list[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: List[RequestOutput],
) -> List[TokensTextLogprobsPromptLogprobs]:
outputs: List[TokensTextLogprobsPromptLogprobs] = []
req_outputs: list[RequestOutput],
) -> list[TokensTextLogprobsPromptLogprobs]:
outputs: list[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
assert len(req_output.outputs) > 0
for sample in req_output.outputs:
@@ -435,20 +442,22 @@ class VllmRunner:
def generate_w_logprobs(
self,
prompts: List[str],
prompts: list[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
**kwargs: Any,
) -> Union[list[TokensTextLogprobs],
list[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
@@ -459,34 +468,37 @@ class VllmRunner:
def generate_greedy(
self,
prompts: List[str],
prompts: Union[list[str], list[torch.Tensor]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str]]:
**kwargs: Any,
) -> list[tuple[list[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
audios=audios,
**kwargs)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
def generate_greedy_logprobs(
self,
prompts: List[str],
prompts: list[str],
max_tokens: int,
num_logprobs: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
stop_token_ids: Optional[list[int]] = None,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Union[list[TokensTextLogprobs],
list[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
@@ -499,23 +511,46 @@ class VllmRunner:
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
videos=videos,
**kwargs)
def encode(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[List[float]]:
def classify(self, prompts: list[str]) -> list[list[float]]:
    """Classify ``prompts`` with ``self.model`` and collect the results.

    Returns the ``outputs.probs`` value of every request output, in the
    order ``self.model.classify`` yields them.
    """
    probs_per_prompt = []
    for result in self.model.classify(prompts):
        probs_per_prompt.append(result.outputs.probs)
    return probs_per_prompt
def embed(self,
prompts: list[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
*args,
**kwargs) -> list[list[float]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.embed(inputs)
req_outputs = self.model.embed(inputs, *args, **kwargs)
return [req_output.outputs.embedding for req_output in req_outputs]
def encode(self, prompts: list[str]) -> list[list[float]]:
    """Encode ``prompts`` with ``self.model`` and collect the results.

    Returns the ``outputs.data`` value of every request output, in the
    order ``self.model.encode`` yields them.
    """
    data_per_prompt = []
    for result in self.model.encode(prompts):
        data_per_prompt.append(result.outputs.data)
    return data_per_prompt
def reward(self, prompts: list[str]) -> list[list[float]]:
    """Run ``self.model.reward`` on ``prompts`` and collect the results.

    Returns the ``outputs.data`` value of every request output, in the
    order ``self.model.reward`` yields them.
    """
    data_per_prompt = []
    for result in self.model.reward(prompts):
        data_per_prompt.append(result.outputs.data)
    return data_per_prompt
def score(
    self,
    text_1: Union[str, list[str]],
    text_2: Union[str, list[str]],
    *args,
    **kwargs,
) -> list[float]:
    """Score ``text_1`` against ``text_2`` with ``self.model.score``.

    Extra positional and keyword arguments are forwarded unchanged.
    Returns the ``outputs.score`` value of every request output.
    """
    results = self.model.score(text_1, text_2, *args, **kwargs)
    scores = []
    for result in results:
        scores.append(result.outputs.score)
    return scores
def __enter__(self):
return self
@@ -635,10 +670,79 @@ class HfRunner:
if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer
def get_inputs(
    self,
    prompts: list[str],
    images: Optional[PromptImageInput] = None,
    videos: Optional[PromptVideoInput] = None,
    audios: Optional[PromptAudioInput] = None,
) -> list[Union[BatchFeature, BatchEncoding]]:
    """Run ``self.processor`` once per prompt and return the results.

    Every multimodal list that is given must have one entry per prompt;
    an entry of ``None`` means "no data of that kind for this prompt".
    """
    for modality in (images, videos, audios):
        if modality is not None:
            assert len(prompts) == len(modality)

    processed: list[Union[BatchFeature, BatchEncoding]] = []
    for idx, text in enumerate(prompts):
        call_kwargs: dict[str, Any] = {
            "text": text,
            "return_tensors": "pt",
        }

        image = None if images is None else images[idx]
        if image is not None:
            call_kwargs["images"] = image

        video = None if videos is None else videos[idx]
        if video is not None:
            call_kwargs["videos"] = video

        audio_entry = None if audios is None else audios[idx]
        if audio_entry is not None:
            # HACK - not all processors take sampling_rate; we should
            # clean this up in the future.
            if len(audio_entry) == 2:
                waveform, rate = audio_entry
                call_kwargs["audio"] = waveform
                call_kwargs["sampling_rate"] = rate
            else:
                call_kwargs["audio"] = audio_entry

        features = self.processor(**call_kwargs)
        # Only BatchFeature results are cast to the runner dtype.
        if isinstance(features, BatchFeature):
            features = features.to(dtype=self.dtype)
        processed.append(features)

    return processed
def classify(self, prompts: list[str]) -> list[list[float]]:
    """Return one list of classification scores per prompt.

    Each prompt is passed through ``self.model`` on its own. The first
    row of the resulting ``logits`` is post-processed according to
    ``self.config.problem_type``:

    * ``"regression"``: the raw logits are returned;
    * ``"multi_label_classification"``: element-wise sigmoid;
    * anything else (including a missing attribute): softmax over the
      last dimension.
    """
    all_inputs = self.get_inputs(prompts)
    problem_type = getattr(self.config, "problem_type", "")

    outputs: list[list[float]] = []
    for inputs in all_inputs:
        logits = self.model(**self.wrap_device(inputs)).logits
        if problem_type == "regression":
            scores = logits
        elif problem_type == "multi_label_classification":
            scores = logits.sigmoid()
        else:
            scores = logits.softmax(dim=-1)
        # Each call handles a single prompt, so only row 0 is kept.
        outputs.append(scores[0].tolist())
    return outputs
def encode(self, prompts: list[str], *args,
           **kwargs) -> list[list[torch.Tensor]]:
    """Forward ``prompts`` and any extra arguments to ``self.model.encode``."""
    encoded = self.model.encode(prompts, *args, **kwargs)
    return encoded
def predict(self, prompts: list[list[str]], *args,
            **kwargs) -> torch.Tensor:
    """Call ``self.model.predict`` with ``convert_to_tensor=True``.

    Extra positional and keyword arguments are forwarded unchanged.
    """
    predictions = self.model.predict(prompts,
                                     *args,
                                     convert_to_tensor=True,
                                     **kwargs)
    return predictions
def __enter__(self):
return self
@@ -652,7 +756,7 @@ def ilama_lora_files():
return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider")
def qwen_prompt(questions: List[str]) -> List[str]:
def qwen_prompt(questions: list[str]) -> list[str]:
placeholder = "<|image_pad|>"
return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"

View File

View File

@@ -0,0 +1,34 @@
import torch
from modelscope import snapshot_download # type: ignore[import-untyped]
from transformers import AutoModelForSequenceClassification
from tests.e2e.conftest import HfRunner, VllmRunner
def test_classify_correctness() -> None:
    """Compare vLLM classification outputs with the HF reference outputs."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is what",
    ]
    model_name = snapshot_download("Howeee/Qwen2.5-1.5B-apeach")

    with VllmRunner(
            model_name,
            runner="pooling",
            max_model_len=None,
            cudagraph_capture_sizes=[4],
    ) as vllm_runner:
        vllm_outputs = vllm_runner.classify(prompts)

    with HfRunner(model_name,
                  dtype="float32",
                  auto_cls=AutoModelForSequenceClassification) as hf_runner:
        hf_outputs = hf_runner.classify(prompts)

    # Outputs are compared prompt by prompt with a relative tolerance of 1e-2.
    for expected, actual in zip(hf_outputs, vllm_outputs):
        expected_t = torch.tensor(expected)
        actual_t = torch.tensor(actual)
        assert torch.allclose(expected_t, actual_t, rtol=1e-2)

View File

@@ -16,22 +16,32 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
import pytest
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.e2e.conftest import HfRunner, VllmRunner
from tests.e2e.utils import check_embeddings_close
MODELS = [
"Qwen/Qwen3-Embedding-0.6B", # lasttoken
"BAAI/bge-small-en-v1.5", # cls_token
"intfloat/multilingual-e5-small" # mean_tokens
]
def test_embed_models_correctness():
@pytest.mark.parametrize("model", MODELS)
def test_embed_models_correctness(model: str):
queries = ['What is the capital of China?', 'Explain gravity']
model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B")
model_name = snapshot_download(model)
with VllmRunner(
model_name,
runner="pooling",
enforce_eager=False,
max_model_len=None,
cudagraph_capture_sizes=[4],
) as vllm_runner:
vllm_outputs = vllm_runner.encode(queries)
vllm_outputs = vllm_runner.embed(queries)
with HfRunner(
model_name,

View File

@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.e2e.conftest import HfRunner, VllmRunner
CROSS_ENCODER_MODELS = [
"dengcao/ms-marco-MiniLM-L6-v2", # Bert
"BAAI/bge-reranker-v2-m3", # Roberta
]
EMBEDDING_MODELS = [
"sentence-transformers/all-MiniLM-L12-v2",
]
TEXTS_1 = [
"What is the capital of France?",
"What is the capital of Germany?",
]
TEXTS_2 = [
"The capital of France is Paris.",
"The capital of Germany is Berlin.",
]
DTYPE = "half"
@pytest.fixture(scope="module", params=CROSS_ENCODER_MODELS)
def model_name(request):
    """Yield the ``snapshot_download`` result for each cross-encoder model."""
    local_model = snapshot_download(request.param)
    yield local_model
def test_cross_encoder_1_to_1(model_name):
    """Score a single text pair and compare vLLM with the HF cross-encoder."""
    query, document = TEXTS_1[0], TEXTS_2[0]

    with HfRunner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict([[query, document]]).tolist()

    with VllmRunner(model_name,
                    runner="pooling",
                    dtype=DTYPE,
                    cudagraph_capture_sizes=[4],
                    max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(query, document)

    assert len(vllm_outputs) == 1
    assert len(hf_outputs) == 1

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
def test_cross_encoder_1_to_N(model_name):
    """Score one query against two documents; compare vLLM with HF."""
    query = TEXTS_1[0]
    text_pairs = [[query, TEXTS_2[0]], [query, TEXTS_2[1]]]

    with HfRunner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict(text_pairs).tolist()

    with VllmRunner(model_name,
                    runner="pooling",
                    dtype=DTYPE,
                    cudagraph_capture_sizes=[4],
                    max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(query, TEXTS_2)

    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

    # Both lists have exactly two entries here, so zip covers index 0 and 1.
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.01)
def test_cross_encoder_N_to_N(model_name):
    """Score two aligned text pairs; compare vLLM with the HF cross-encoder."""
    text_pairs = [[TEXTS_1[i], TEXTS_2[i]] for i in range(2)]

    with HfRunner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model:
        hf_outputs = hf_model.predict(text_pairs).tolist()

    with VllmRunner(model_name,
                    runner="pooling",
                    dtype=DTYPE,
                    cudagraph_capture_sizes=[4],
                    max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)

    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

    # Both lists have exactly two entries here, so zip covers index 0 and 1.
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.01)
@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
def emb_model_name(request):
    """Yield the ``snapshot_download`` result for each embedding model."""
    local_model = snapshot_download(request.param)
    yield local_model
def test_embedding_1_to_1(emb_model_name):
    """Compare vLLM's score for one text pair with HF cosine similarity."""
    query, document = TEXTS_1[0], TEXTS_2[0]

    with HfRunner(emb_model_name, dtype=DTYPE,
                  is_sentence_transformer=True) as hf_model:
        hf_embeddings = hf_model.encode([query, document])
        # The reference score is the cosine similarity of the HF embeddings.
        as_tensors = [torch.tensor(emb) for emb in hf_embeddings]
        hf_outputs = [F.cosine_similarity(*as_tensors, dim=0)]

    with VllmRunner(emb_model_name,
                    runner="pooling",
                    dtype=DTYPE,
                    cudagraph_capture_sizes=[4],
                    max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(query, document)

    assert len(vllm_outputs) == 1
    assert len(hf_outputs) == 1

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
def test_embedding_1_to_N(emb_model_name):
    """Compare vLLM's one-query/two-document scores with HF cosine similarity."""
    query = TEXTS_1[0]
    text_pairs = [[query, TEXTS_2[0]], [query, TEXTS_2[1]]]

    with HfRunner(emb_model_name, dtype=DTYPE,
                  is_sentence_transformer=True) as hf_model:
        hf_outputs = []
        for text_pair in text_pairs:
            # The reference score is the cosine similarity of the HF embeddings.
            as_tensors = [torch.tensor(e) for e in hf_model.encode(text_pair)]
            hf_outputs.append(F.cosine_similarity(*as_tensors, dim=0))

    with VllmRunner(emb_model_name,
                    runner="pooling",
                    dtype=DTYPE,
                    cudagraph_capture_sizes=[4],
                    max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(query, TEXTS_2)

    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

    # Both lists have exactly two entries here, so zip covers index 0 and 1.
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.01)
def test_embedding_N_to_N(emb_model_name):
    """Compare vLLM's scores for two aligned pairs with HF cosine similarity."""
    text_pairs = [[TEXTS_1[i], TEXTS_2[i]] for i in range(2)]

    with HfRunner(emb_model_name, dtype=DTYPE,
                  is_sentence_transformer=True) as hf_model:
        hf_outputs = []
        for text_pair in text_pairs:
            # The reference score is the cosine similarity of the HF embeddings.
            as_tensors = [torch.tensor(e) for e in hf_model.encode(text_pair)]
            hf_outputs.append(F.cosine_similarity(*as_tensors, dim=0))

    with VllmRunner(emb_model_name,
                    runner="pooling",
                    dtype=DTYPE,
                    cudagraph_capture_sizes=[4],
                    max_model_len=None) as vllm_model:
        vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2)

    assert len(vllm_outputs) == 2
    assert len(hf_outputs) == 2

    # Both lists have exactly two entries here, so zip covers index 0 and 1.
    for hf_score, vllm_score in zip(hf_outputs, vllm_outputs):
        assert hf_score == pytest.approx(vllm_score, rel=0.01)

View File

@@ -1,49 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.e2e.conftest import HfRunner, VllmRunner
from tests.e2e.utils import check_embeddings_close
def test_bge_model_correctness():
    """Check vLLM eager-mode embeddings of BAAI/bge-m3 against HF outputs."""
    model_name = snapshot_download("BAAI/bge-m3")
    queries = ['What is the capital of China?', 'Explain gravity']

    vllm_kwargs = dict(runner="pooling", enforce_eager=True)
    with VllmRunner(model_name, **vllm_kwargs) as vllm_runner:
        vllm_outputs = vllm_runner.encode(queries)

    hf_kwargs = dict(dtype="float32", is_sentence_transformer=True)
    with HfRunner(model_name, **hf_kwargs) as hf_runner:
        hf_outputs = hf_runner.encode(queries)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )

View File

@@ -1,55 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
import os
import pytest
from tests.e2e.conftest import VllmRunner
from tests.e2e.utils import check_embeddings_close
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = ["BAAI/bge-m3"]
@pytest.mark.parametrize("model_name", MODELS)
def test_aclgrpah_embed_models_correctness(model_name):
    """Check that aclgraph and eager vLLM runs produce close embeddings.

    The same queries are encoded twice with ``VllmRunner``: once with
    ``enforce_eager=False`` and once with ``enforce_eager=True``. The two
    result lists are then compared with ``check_embeddings_close``.
    """
    queries = ['What is the capital of China?', 'Explain gravity']

    with VllmRunner(
            model_name,
            runner="pooling",
            enforce_eager=False,
    ) as vllm_aclgraph_runner:
        vllm_aclgraph_outputs = vllm_aclgraph_runner.encode(queries)

    with VllmRunner(
            model_name,
            runner="pooling",
            enforce_eager=True,
    ) as vllm_runner:
        vllm_outputs = vllm_runner.encode(queries)

    # Both sides are vLLM runs, so the labels name the execution mode
    # instead of the previous, misleading "hf"/"vllm".
    check_embeddings_close(
        embeddings_0_lst=vllm_outputs,
        embeddings_1_lst=vllm_aclgraph_outputs,
        name_0="vllm_eager",
        name_1="vllm_aclgraph",
        tol=1e-2,
    )

View File

@@ -35,7 +35,6 @@ class AttentionMaskBuilder:
self.attn_mask_cache = None
self._seq_len_cached = 0
self.device = device
self.pooling_mask = None
self.mla_mask = None
self.chunked_prefill_attn_mask = None
self.pcp_mla_mask = None
@@ -50,14 +49,6 @@ class AttentionMaskBuilder:
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(self.device, non_blocking=True)
def get_pooling_mask(self):
    """Return the cached 2048x2048 boolean pooling mask, building it once.

    The mask is ``True`` strictly above the main diagonal and is moved to
    ``self.device`` when it is first created.
    """
    if self.pooling_mask is not None:
        return self.pooling_mask

    # the compressed attention mask for npu_fusion_attention sparse mode 4
    upper = torch.triu(torch.ones(2048, 2048), diagonal=1)
    bool_mask = upper.to(torch.bool)
    self.pooling_mask = bool_mask.to(self.device, non_blocking=True)
    return self.pooling_mask
def get_splitfuse_attn_mask(self) -> torch.Tensor:
if self.chunked_prefill_attn_mask is None:
self.chunked_prefill_attn_mask = torch.triu(

View File

@@ -221,6 +221,10 @@ class AscendMetadata:
# dcp
decode_meta: Optional[AscendMetadataForDecode] = None
# Whether is the pooling model with causal attention,
# used to guide the attention computation for pooling models.
is_causal_pooling: Optional[bool] = None
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
@@ -319,6 +323,10 @@ class AscendAttentionMetadataBuilder:
query_start_loc = query_start_loc_cpu.pin_memory().to(
self.device, non_blocking=True)
is_causal_pooling = None
if self.model_config.runner_type == "pooling":
is_causal_pooling = common_attn_metadata.causal if hasattr(
common_attn_metadata, 'causal') else True
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
@@ -336,7 +344,8 @@ class AscendAttentionMetadataBuilder:
attn_mask=attn_mask,
attn_state=attn_state,
num_prefills=num_prefills,
num_decodes=num_decodes)
num_decodes=num_decodes,
is_causal_pooling=is_causal_pooling)
return attn_metadata
def build_for_graph_capture(
@@ -597,30 +606,39 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
return output
def _forward_encode(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_metadata: AscendMetadata,
    output: torch.Tensor,
) -> torch.Tensor:
    """Run ``npu_fusion_attention`` (TND layout, sparse mode 4) and return it.

    ``output`` is accepted but never written to; the first element of the
    fused-attention result is returned instead.
    """
    # The leading entry of query_start_loc is dropped; the rest is passed
    # as both the query and the key/value sequence lengths.
    seq_ends = attn_metadata.query_start_loc[1:].tolist()
    window = attn_metadata.max_query_len
    attn_result = torch_npu.npu_fusion_attention(
        query,
        key,
        value,
        head_num=self.num_heads,
        input_layout="TND",
        scale=self.scale,
        sparse_mode=4,
        atten_mask=attn_metadata.attn_mask,
        pre_tockens=window,
        next_tockens=window,
        actual_seq_qlen=seq_ends,
        actual_seq_kvlen=seq_ends,
    )
    return attn_result[0]
def _forward_encoder_attention(self, query: torch.Tensor,
                               key: torch.Tensor, value: torch.Tensor,
                               attn_metadata: AscendMetadata,
                               _: torch.Tensor) -> torch.Tensor:
    """Compute attention for pooling models via ``npu_fusion_attention``.

    ``attn_metadata.is_causal_pooling`` selects the variant: when true the
    call uses ``sparse_mode=3`` together with ``attn_metadata.attn_mask``;
    otherwise neither argument is passed. The last parameter is unused.
    """
    assert attn_metadata is not None
    assert attn_metadata.is_causal_pooling is not None

    seq_lens = attn_metadata.actual_seq_lengths_q
    fusion_kwargs = dict(
        query=query,
        key=key,
        value=value,
        head_num=self.num_heads,
        input_layout="TND",
        scale=self.scale,
        actual_seq_qlen=seq_lens,
        actual_seq_kvlen=seq_lens,
    )
    if attn_metadata.is_causal_pooling:
        # use sparse_mode 3 in causal scenario
        fusion_kwargs.update(sparse_mode=3,
                             atten_mask=attn_metadata.attn_mask)
    # Otherwise the default sparse_mode 0 is used, which means no mask
    # works on it.
    return torch_npu.npu_fusion_attention(**fusion_kwargs)[0]
def reshape_and_cache(
self,
@@ -697,18 +715,22 @@ class AscendAttentionBackendImpl(AttentionImpl):
" for AscendAttentionBackendImpl")
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
if self.attn_type != AttentionType.DECODER and self.attn_type != AttentionType.ENCODER_ONLY:
raise NotImplementedError("Encoder/decoder cross-attention "
"are not implemented for "
attn_type = self.attn_type
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/Decoder cross-attention "
"is not implemented for "
"PallasAttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
key, value = self.reshape_and_cache(key, value, kv_cache,
attn_metadata)
if self.attn_type == AttentionType.ENCODER_ONLY:
attn_output = self._forward_encode(query, key, value,
attn_metadata, output)
# pooling model branch
if isinstance(attn_metadata.is_causal_pooling, bool):
attn_output = self._forward_encoder_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
output = self.forward_impl(query, key, value, kv_cache, attn_metadata,

View File

@@ -106,16 +106,7 @@
#
# ** File: worker/patch_roberta.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.roberta.RobertaEmbedding.forward`
# Why:
# shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode
# How
# Replace shift operation with multiplication and division.
# Related PR (if no, explain why):
# No, this need CANN add an aclnn shift operation
# Future Plan:
# Revert this when CANN support shift aclnn operation
# 2. `vllm.model_executor.models.roberta.RobertaForSequenceClassification.forward `
# 1. `vllm.model_executor.models.bert`
# Why:
# shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode
# How

View File

@@ -22,9 +22,9 @@ if HAS_TRITON:
# isort: off
import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.worker.patch_bert # noqa
import vllm_ascend.patch.worker.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_deepseek # noqa
import vllm_ascend.patch.worker.patch_roberta # noqa
import vllm_ascend.patch.worker.patch_weight_loader # noqa
import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
import vllm_ascend.patch.worker.patch_minicpm # noqa

View File

@@ -0,0 +1,45 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.model_executor.models import bert
# aclgraph does not support shift operator for now
# TODO: revert me when aclgraph supports shift operator
# Token type ids are packed into ``input_ids`` by multiplying them with
# TOKEN_TYPE_MULTIPLIER; the bits below TOKEN_TYPE_SHIFT keep the token id.
TOKEN_TYPE_SHIFT = 30
# Derived from TOKEN_TYPE_SHIFT so the two constants cannot drift apart.
# This is a plain Python int computed once at import time.
TOKEN_TYPE_MULTIPLIER = 1 << TOKEN_TYPE_SHIFT
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1


def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    """Pack ``token_type_ids`` into the high bits of ``input_ids`` in place.

    Only the first ``token_type_ids.shape[0]`` entries of ``input_ids`` are
    modified. Multiplication is used where a left shift would normally be.
    """
    # input_ids can be padded to the right
    prefix = input_ids[:token_type_ids.shape[0]]
    prefix.bitwise_or_(token_type_ids * TOKEN_TYPE_MULTIPLIER)


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Extract packed token type ids and restore ``input_ids`` in place.

    Returns ``input_ids // TOKEN_TYPE_MULTIPLIER`` and then masks
    ``input_ids`` with ``TOKEN_MASK`` so only the original ids remain.
    Floor division is used where a right shift would normally be.
    """
    token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER
    input_ids.bitwise_and_(TOKEN_MASK)
    return token_type_ids
bert._encode_token_type_ids = _encode_token_type_ids
bert._decode_token_type_ids = _decode_token_type_ids

View File

@@ -1,91 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Union
import torch
from vllm.model_executor.models.roberta import (
RobertaEmbedding, RobertaForSequenceClassification,
replace_roberta_positions)
from vllm.sequence import IntermediateTensors
# aclgraph does not support shift operator for now, so the helpers below use
# multiplication and floor division in place of `<<` / `>>`.
# TODO: revert me when aclgraph supports shift operator
TOKEN_TYPE_SHIFT = 30
# Derived from TOKEN_TYPE_SHIFT so the bit position is defined in one place.
TOKEN_TYPE_MULTIPLIER = 1 << TOKEN_TYPE_SHIFT
# Selects the bits of a packed id that lie below TOKEN_TYPE_SHIFT.
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1
def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    """Pack ``token_type_ids`` into the high bits of ``input_ids`` in place.

    Each token type id is multiplied by ``TOKEN_TYPE_MULTIPLIER`` and OR-ed
    into the matching leading element of ``input_ids``. Nothing is returned;
    ``input_ids`` is mutated.
    """
    typed_len = token_type_ids.shape[0]
    scaled_types = token_type_ids * TOKEN_TYPE_MULTIPLIER
    # input_ids can be padded to the right, so only the leading slice is touched
    input_ids[:typed_len].bitwise_or_(scaled_types)
def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
    """Split packed ids into token type ids and plain input ids.

    Returns ``input_ids // TOKEN_TYPE_MULTIPLIER`` as a new tensor, then masks
    ``input_ids`` in place with ``TOKEN_MASK`` so only the low bits remain.
    """
    # Read the high bits first; the in-place mask below discards them.
    unpacked_types = input_ids // TOKEN_TYPE_MULTIPLIER
    input_ids.bitwise_and_(TOKEN_MASK)
    return unpacked_types
def roberta_for_sequence_classification_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Replacement forward for ``RobertaForSequenceClassification``.

    Calls ``replace_roberta_positions`` with the model's ``padding_idx``,
    packs ``token_type_ids`` (when given) into ``input_ids`` in place, and
    returns the output of ``self.roberta``. ``token_type_ids`` is not
    forwarded as an argument; it travels inside ``input_ids``.
    """
    # NOTE(review): return value is unused, so this presumably rewrites
    # `positions` in place -- confirm against vLLM's implementation.
    replace_roberta_positions(input_ids=input_ids,
                              position_ids=positions,
                              padding_idx=self.padding_idx)

    has_token_types = token_type_ids is not None
    if has_token_types:
        # Every vocabulary id must fit below the bits used for the token type.
        assert self.roberta.config.vocab_size < TOKEN_TYPE_MULTIPLIER
        assert input_ids is not None
        _encode_token_type_ids(input_ids, token_type_ids)

    encoder_output = self.roberta(
        input_ids=input_ids,
        positions=positions,
        inputs_embeds=inputs_embeds,
        intermediate_tensors=intermediate_tensors,
    )
    return encoder_output
def roberta_embedding_forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    inputs_embeds: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """Replacement forward for ``RobertaEmbedding``.

    Unpacks the token type ids from ``input_ids`` (which is masked in place),
    sums word, token-type and position embeddings, and returns the result of
    ``self.LayerNorm`` applied to that sum. When ``inputs_embeds`` is given it
    is used instead of looking up word embeddings.
    """
    # Decode before any lookup: this strips the packed bits from input_ids.
    type_ids = _decode_token_type_ids(input_ids)

    word_embeds = inputs_embeds
    if word_embeds is None:
        word_embeds = self.word_embeddings(input_ids)

    pos_embeds = self.position_embeddings(position_ids)
    type_embeds = self.token_type_embeddings(type_ids)

    summed = word_embeds + type_embeds + pos_embeds
    return self.LayerNorm(summed)
# Monkeypatch: install the forwards defined in this file as the `forward`
# methods of vLLM's RobertaEmbedding and RobertaForSequenceClassification.
RobertaEmbedding.forward = roberta_embedding_forward
RobertaForSequenceClassification.forward = roberta_for_sequence_classification_forward

View File

@@ -377,6 +377,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self._set_up_drafter()
@@ -1029,8 +1030,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask()
if self.model_config.runner_type == "pooling":
return self.attn_mask_builder.get_attn_mask(2048, torch.bool)
if self.vllm_config.model_config.use_mla:
if self.pcp_size > 1:
@@ -1933,8 +1934,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
common_prefix_len = 0
extra_attn_metadata_args = {}
builder = attn_group.get_metadata_builder()
if isinstance(builder, GDNAttentionMetadataBuilder
) or self.model_config.runner_type == "pooling":
if isinstance(builder, GDNAttentionMetadataBuilder):
if use_spec_decode:
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
@@ -1946,6 +1946,11 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
elif self.model_config.runner_type == "pooling":
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
else:
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
@@ -1968,18 +1973,52 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
input_ids, inputs_embeds, intermediate_tensors,
max_num_scheduled_tokens)
def _init_model_kwargs(self):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
num_pooling_reqs = len(self.input_batch.pooling_params)
if num_pooling_reqs == 0:
return model_kwargs
pooling_params = self.input_batch.get_pooling_params()
assert num_pooling_reqs == num_reqs
token_type_id_requests = dict[int, Any]()
for i, param in enumerate(pooling_params):
if param.extra_kwargs is not None and \
(token_types := param.extra_kwargs.get(
"compressed_token_type_ids")) is not None:
token_type_id_requests[i] = token_types
if len(token_type_id_requests) == 0:
return model_kwargs
seq_lens = self.seq_lens[:num_reqs]
token_type_ids = []
for i in range(num_reqs):
pos = token_type_id_requests.get(i, seq_lens[i])
ids = (torch.arange(seq_lens[i]) >= pos).int()
token_type_ids.append(ids)
model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
device=self.device)
return model_kwargs
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
maybe_padded_num_tokens,
input_ids, positions,
intermediate_tensors,
inputs_embeds):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
hidden_states = self.model(input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**self._init_model_kwargs())
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
@@ -2022,7 +2061,14 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
if self.model_config.runner_type == "pooling":
if isinstance(
self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
EncoderOnlyAttentionSpec):
attn_state = AscendAttentionState.PrefillNoCache
else:
attn_state = AscendAttentionState.PrefillCacheHit
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
@@ -2251,7 +2297,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
" a batch must be pooling request"
hidden_states = hidden_states[:num_scheduled_tokens]
pooling_metadata = self.input_batch.pooling_metadata
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device)
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
@@ -4049,6 +4095,15 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
desc="Capturing ACL graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
aclgraph_runtime_mode.name))
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
# When the kv cache spec is empty, PiecewiseBackend is not initialized, and
# compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
if not self.get_kv_cache_spec():
self._dummy_run(2,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
@@ -4057,7 +4112,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
self._dummy_run(num_tokens,
aclgraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,

View File

@@ -793,17 +793,12 @@ class InputBatch:
logitsprocs=self.logitsprocs,
)
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note, for now this assumes that all request in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
def get_pooling_params(self) -> list[PoolingParams]:
    """Return the pooling params of every request, ordered like ``req_ids``.

    Asserts that ``req_ids`` and ``pooling_params`` have the same length,
    i.e. that each request in the batch has pooling params.
    """
    request_ids = self.req_ids
    params_by_request = self.pooling_params
    assert len(request_ids) == len(params_by_request)
    return [params_by_request[request_id] for request_id in request_ids]
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
return PoolingMetadata(
prompt_lens=torch.from_numpy(