[V1][ModelRunner] Support pooling model for v1 engine (#1359)

### What this PR does / why we need it?
Change as little existing code as possible to add v1 pooling task's
support, notice that i move down the `vllm.v1.worker.gpu_input_batch` to
vllm-ascend, Considering the frequent changes in upstream interfaces, in
order to decouple, so i move it here
### How was this patch tested?
CI passed with new added/existing test, and I have a simple test was
first conducted locally which is adapted from
https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B, just like
bellow:
```python
import os

import torch
from vllm import LLM


os.environ["VLLM_USE_MODELSCOPE"]="True"

def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'

# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
```
---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <858794774@qq.com>
Co-authored-by: wangli <858794774@qq.com>
This commit is contained in:
Li Wang
2025-06-30 16:31:12 +08:00
committed by GitHub
parent 790c810bf7
commit 5f8241c25c
10 changed files with 1312 additions and 43 deletions

View File

@@ -19,18 +19,23 @@
import contextlib
import gc
from typing import List, Optional, Tuple, TypeVar, Union
from typing import Any, List, Optional, Tuple, TypeVar, Union
import numpy as np
import pytest
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import is_list_of
from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
@@ -45,6 +50,7 @@ adapt_patch(True)
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
@@ -364,3 +370,131 @@ def prompt_template(request):
@pytest.fixture(scope="session")
def ilama_lora_files():
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
class HfRunner:
def get_default_device(self):
from vllm.platforms import current_platform
return ("cpu"
if current_platform.is_cpu() else current_platform.device_type)
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )):
return x
if device is None:
device = self.device
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
if hasattr(x, "device") and x.device.type == device:
return x
return x.to(device)
def __init__(
self,
model_name: str,
dtype: str = "auto",
*,
model_kwargs: Optional[dict[str, Any]] = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
) -> None:
model_name = maybe_model_redirect(model_name)
self.model_name = model_name
self.config = AutoConfig.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder,
)
model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)
if is_sentence_transformer:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device=self.device,
model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code,
)
elif is_cross_encoder:
# Lazy init required for AMD CI
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(
model_name,
device=self.device,
automodel_args=model_kwargs,
trust_remote_code=trust_remote_code,
)
else:
model = auto_cls.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
**model_kwargs,
)
# in case some unquantized custom models are not in same dtype
if (getattr(model, "quantization_method", None) is None
and any(p.dtype != self.dtype
for p in model.parameters())):
model = model.to(dtype=self.dtype)
if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
model = model.to(device=self.device)
self.model = model
if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer
def encode(self, prompts: list[str], *args,
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup_dist_env_and_memory()
@pytest.fixture(scope="session")
def hf_runner():
return HfRunner

View File

@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
from collections.abc import Sequence
from typing import Optional
import pytest
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.conftest import HfRunner
from tests.utils import check_embeddings_close, matryoshka_fy
from vllm_ascend.utils import vllm_version_is
def run_embedding_correctness_test(
hf_model: "HfRunner",
inputs: list[str],
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None,
):
hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
# dummy to avoid pytest collect nothing and exit code 5
def test_dummy():
assert True
@pytest.mark.skipif(vllm_version_is("0.9.1"),
reason="vLLM 0.9.1 does not support embed task for v1")
def test_embed_models_correctness(hf_runner, vllm_runner):
queries = ['What is the capital of China?', 'Explain gravity']
model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B")
with vllm_runner(
model_name,
task="embed",
enforce_eager=True,
) as vllm_model:
vllm_outputs = vllm_model.encode(queries)
with hf_runner(
model_name,
dtype="float32",
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, queries, vllm_outputs)

View File

@@ -0,0 +1,162 @@
import unittest
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
return CachedRequestState(
req_id=req_id,
prompt_token_ids=prompt,
mm_inputs=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,
generator=None,
block_ids=([], ),
num_computed_tokens=0,
output_token_ids=output,
)
class TestInputBatch(unittest.TestCase):
def setUp(self):
self.max_num_reqs = 10
self.max_model_len = 32
self.max_num_batched_tokens = 132
self.vocab_size = 1000
self.device = torch.device("cpu")
self.block_sizes = [128]
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_batched_tokens,
device=self.device,
pin_memory=False,
vocab_size=self.vocab_size,
block_sizes=self.block_sizes,
)
self.cached_request_state = mock_cached_request_state()
def test_shapes_and_defaults(self):
# torch tensor shape assertions
self.assertEqual(self.input_batch.token_ids_cpu_tensor.shape,
(self.max_num_reqs, self.max_model_len))
self.assertEqual(self.input_batch.temperature.shape,
(self.max_num_reqs, ))
self.assertEqual(self.input_batch.top_k.shape, (self.max_num_reqs, ))
self.assertEqual(self.input_batch.min_p_cpu_tensor.shape,
(self.max_num_reqs, ))
# numpy shape assertions
self.assertEqual(self.input_batch.token_ids_cpu.shape,
(self.max_num_reqs, self.max_model_len))
self.assertEqual(self.input_batch.num_tokens.shape,
(self.max_num_reqs, ))
self.assertEqual(self.input_batch.num_tokens.shape,
(self.max_num_reqs, ))
# type assertions
self.assertIsInstance(self.input_batch.greedy_reqs, set)
self.assertIsInstance(self.input_batch.req_id_to_index, dict)
self.assertIsInstance(self.input_batch.sampling_metadata,
SamplingMetadata)
self.assertIsInstance(self.input_batch.block_table,
MultiGroupBlockTable)
self.assertIsNone(self.input_batch.allowed_token_ids_mask)
self.assertIsNone(self.input_batch.allowed_token_ids_mask_cpu_tensor)
def test_add_request(self):
# case1: add a new req
self.input_batch.add_request(self.cached_request_state)
self.assertIn(self.cached_request_state.req_id,
self.input_batch.req_id_to_index)
req_index = self.input_batch.req_id_to_index[
self.cached_request_state.req_id]
self.assertEqual(self.input_batch.num_prompt_tokens[req_index],
len(self.cached_request_state.prompt_token_ids))
self.assertEqual(self.input_batch.num_tokens[req_index],
self.cached_request_state.num_tokens)
# case2: add an existing req, maybe need update
self.cached_request_state.output_token_ids.extend([7, 8, 9])
self.cached_request_state.num_computed_tokens += 3
cached_index = self.input_batch.req_id_to_index[
self.cached_request_state.req_id]
self.input_batch.add_request(self.cached_request_state, cached_index)
# check if this index in the input_batch is updated
# This np arrat "token_ids_cpu" should be filled with prompt_token_ids + output_token_ids
self.assertTrue(
np.all(self.input_batch.token_ids_cpu[
cached_index, :self.cached_request_state.num_tokens]),
msg=f"Token IDs at index {cached_index} did not update correctly.")
# case3: add req that greater than max_num_reqs
with self.assertRaises(AssertionError):
self.input_batch.add_request(self.cached_request_state,
req_index=self.max_num_reqs)
# case4: add req that out of max_model_len
long_prompt = list(range(self.max_model_len + 1))
long_request = mock_cached_request_state(req_id="2",
prompt=long_prompt,
output=[10])
with self.assertRaises(ValueError) as cm:
self.input_batch.add_request(long_request)
self.assertIn("could not broadcast", str(cm.exception))
def test_remove_request(self):
self.input_batch.add_request(self.cached_request_state)
req_index = self.input_batch.remove_request(
self.cached_request_state.req_id)
self.assertIsNotNone(req_index)
self.assertNotIn(self.cached_request_state.req_id,
self.input_batch.req_id_to_index)
self.assertIsNone(self.input_batch._req_ids[req_index])
def test_condense(self):
# Let's say we have some requests like below
# Index Req ID
# 0 1
# 1 2
# 2 3
# 3 4
for i in range(4):
request = mock_cached_request_state(req_id=str(i + 1))
self.input_batch.add_request(request)
removed_req_indices = []
id_to_remove = ["2", "4"] # IDs to remove
for req_id in id_to_remove:
removed_index = self.input_batch.remove_request(req_id)
if removed_index is not None:
removed_req_indices.append(removed_index)
self.assertEqual(len(removed_req_indices), len(id_to_remove))
self.input_batch.condense(sorted(removed_req_indices, reverse=True))
# Check if the remaining requests are condensed correctly
indices = [
self.input_batch.req_id_to_index[req_id] for req_id in ["1", "3"]
]
self.assertTrue(all(idx < self.input_batch.num_reqs
for idx in indices))
for i in range(self.input_batch.num_reqs):
self.assertIsNotNone(self.input_batch._req_ids[i])
for i in range(self.input_batch.num_reqs,
len(self.input_batch._req_ids)):
self.assertIsNone(self.input_batch._req_ids[i])
for req_id in ["1", "3"]:
idx = self.input_batch.req_id_to_index[req_id]
tokens = self.input_batch.token_ids_cpu[idx]
self.assertTrue(
tokens.any(),
f"Tokens at index {idx} for req {req_id} should not be all zero"
)

View File

@@ -23,10 +23,13 @@ import signal
import subprocess
import sys
import time
from collections.abc import Sequence
from typing import Callable, Optional
import openai
import requests
import torch
import torch.nn.functional as F
from typing_extensions import ParamSpec
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -197,3 +200,37 @@ def fork_new_process_for_each_test(
f" args {args} and kwargs {kwargs}")
return wrapper
def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
tensor = torch.tensor(tensor)
tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1)
return tensor
def check_embeddings_close(
*,
embeddings_0_lst: Sequence[list[float]],
embeddings_1_lst: Sequence[list[float]],
name_0: str,
name_1: str,
tol: float = 1e-3,
) -> None:
assert len(embeddings_0_lst) == len(embeddings_1_lst)
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
assert len(embeddings_0) == len(embeddings_1), (
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),
dim=0)
fail_msg = (f"Test{prompt_idx}:"
f"\nCosine similarity: \t{sim:.4f}"
f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}")
assert sim >= 1 - tol, fail_msg