[V1][ModelRunner] Support pooling model for v1 engine (#1359)
### What this PR does / why we need it?

Add support for the v1 pooling task while changing as little existing code as possible. Note that `vllm.v1.worker.gpu_input_batch` has been moved down into vllm-ascend: because the upstream interfaces change frequently, it is kept here in order to decouple vllm-ascend from those changes.

### How was this patch tested?

CI passed with the newly added and the existing tests. In addition, a simple test adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B was first run locally, as shown below:

```python
import os

import torch
from vllm import LLM

os.environ["VLLM_USE_MODELSCOPE"]="True"


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
```

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <858794774@qq.com>
Co-authored-by: wangli <858794774@qq.com>
This commit is contained in:
@@ -19,18 +19,23 @@
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
from typing import List, Optional, Tuple, TypeVar, Union
|
||||
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
BatchEncoding, BatchFeature)
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TaskOption
|
||||
from vllm.config import TaskOption, _get_and_verify_dtype
|
||||
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.utils import maybe_model_redirect
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
|
||||
@@ -45,6 +50,7 @@ adapt_patch(True)
|
||||
from vllm.distributed.parallel_state import ( # noqa E402
|
||||
destroy_distributed_environment, destroy_model_parallel)
|
||||
|
||||
# Constrained type variable: one of nn.Module, torch.Tensor, BatchEncoding,
# BatchFeature or dict. HfRunner.wrap_device uses it for both its argument and
# its return value.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
|
||||
# Unconstrained type variable used as the element type of
# _PromptMultiModalInput.
_M = TypeVar("_M")
|
||||
|
||||
# Either a flat list of _M items or a list of lists of _M items.
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
|
||||
@@ -364,3 +370,131 @@ def prompt_template(request):
|
||||
@pytest.fixture(scope="session")
def ilama_lora_files():
    """Session-scoped fixture for the ilama text2sql LoRA files.

    Returns whatever ``snapshot_download`` returns for the
    ``jeeejeee/ilama-text2sql-spider`` repository.
    """
    lora_repo_id = "jeeejeee/ilama-text2sql-spider"
    return snapshot_download(repo_id=lora_repo_id)
|
||||
|
||||
|
||||
class HfRunner:
    """Test helper that loads a Hugging Face reference model.

    Depending on the flags passed to ``__init__``, ``self.model`` is a
    ``SentenceTransformer``, a ``CrossEncoder``, or the model loaded through
    ``auto_cls.from_pretrained``. The runner is a context manager: leaving
    the ``with`` block deletes ``self.model`` and calls
    ``cleanup_dist_env_and_memory()``.
    """

    def get_default_device(self):
        """Return ``"cpu"`` on a CPU platform, else the platform device type."""
        from vllm.platforms import current_platform

        return ("cpu"
                if current_platform.is_cpu() else current_platform.device_type)

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        """Move ``x`` to ``device`` (default: ``self.device``).

        ``None`` and ``bool`` values are returned untouched, dicts are
        rebuilt with every value wrapped recursively, and an object that
        already reports the target device type is returned as is. Anything
        else is moved with ``x.to(device)``.
        """
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load the config, model, tokenizer and processor for ``model_name``.

        Args:
            model_name: Model identifier; passed through
                ``maybe_model_redirect`` before use.
            dtype: Requested dtype, resolved by ``_get_and_verify_dtype``.
            model_kwargs: Extra keyword arguments for the model constructor.
                The caller's dict is not modified; ``torch_dtype`` is added
                to a private copy when absent.
            trust_remote_code: Forwarded to every ``from_pretrained`` call
                and to the sentence-transformers constructors.
            is_sentence_transformer: Load the model as a
                ``SentenceTransformer``.
            is_cross_encoder: Load the model as a ``CrossEncoder`` (only
                consulted when ``is_sentence_transformer`` is false).
            skip_tokenizer_init: Skip ``AutoTokenizer`` and take the
                tokenizer from the processor instead.
            auto_cls: Class whose ``from_pretrained`` loads the model when
                neither sentence-transformers flag is set.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Work on a shallow copy so that adding the default ``torch_dtype``
        # never leaks into a dict the caller may reuse for another runner.
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not a bitsandbytes model and
            # its parameters live on fewer than two distinct devices.
            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        """Forward ``prompts`` and any extra arguments to ``self.model.encode``."""
        return self.model.encode(prompts, *args, **kwargs)

    def __enter__(self):
        """Return the runner itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the model reference and run the shared cleanup helper."""
        del self.model
        cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def hf_runner():
    """Session-scoped fixture that hands tests the ``HfRunner`` class.

    The class itself (not an instance) is returned, so each test
    constructs its own runner.
    """
    runner_cls = HfRunner
    return runner_cls
|
||||
|
||||
72
tests/e2e/singlecard/test_embedding.py
Normal file
72
tests/e2e/singlecard/test_embedding.py
Normal file
@@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
||||
#
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from modelscope import snapshot_download # type: ignore[import-untyped]
|
||||
|
||||
from tests.conftest import HfRunner
|
||||
from tests.utils import check_embeddings_close, matryoshka_fy
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
def run_embedding_correctness_test(
    hf_model: "HfRunner",
    inputs: list[str],
    vllm_outputs: Sequence[list[float]],
    dimensions: Optional[int] = None,
):
    """Check vLLM embeddings against the Hugging Face reference model.

    Args:
        hf_model: Reference runner; its ``encode`` produces the expected
            embeddings for ``inputs``.
        inputs: Prompts to encode with the reference model.
        vllm_outputs: Embeddings produced by vLLM for the same prompts.
        dimensions: When truthy, the reference embeddings are passed through
            ``matryoshka_fy`` with this value before comparison.
    """
    reference_outputs = hf_model.encode(inputs)
    if dimensions:
        reference_outputs = matryoshka_fy(reference_outputs, dimensions)

    comparison = {
        "embeddings_0_lst": reference_outputs,
        "embeddings_1_lst": vllm_outputs,
        "name_0": "hf",
        "name_1": "vllm",
        "tol": 1e-2,
    }
    check_embeddings_close(**comparison)
|
||||
|
||||
|
||||
# Placeholder test: guarantees pytest collects at least one test from this
# module, so the run does not end with exit code 5 ("no tests collected").
def test_dummy():
    """Always pass."""
    assert True
|
||||
|
||||
|
||||
@pytest.mark.skipif(vllm_version_is("0.9.1"),
                    reason="vLLM 0.9.1 does not support embed task for v1")
def test_embed_models_correctness(hf_runner, vllm_runner):
    """Compare vLLM ``embed`` outputs with a sentence-transformers reference.

    The same two prompts are encoded by vLLM (eager mode) and then by the
    reference runner in float32; the results are compared by
    ``run_embedding_correctness_test``.
    """
    prompts = ['What is the capital of China?', 'Explain gravity']
    model_path = snapshot_download("Qwen/Qwen3-Embedding-0.6B")

    vllm_options = {"task": "embed", "enforce_eager": True}
    with vllm_runner(model_path, **vllm_options) as vllm_model:
        vllm_outputs = vllm_model.encode(prompts)

    hf_options = {"dtype": "float32", "is_sentence_transformer": True}
    with hf_runner(model_path, **hf_options) as hf_model:
        run_embedding_correctness_test(hf_model, prompts, vllm_outputs)
|
||||
162
tests/ut/worker/test_input_batch.py
Normal file
162
tests/ut/worker/test_input_batch.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
|
||||
def mock_cached_request_state(req_id="1", prompt=None, output=None):
    """Build a minimal ``CachedRequestState`` for the tests in this module.

    Args:
        req_id: Request id stored on the state.
        prompt: Prompt token ids; a fresh ``[1, 2, 3]`` when omitted.
        output: Output token ids; a fresh ``[4, 5, 6]`` when omitted.

    The defaults are created per call rather than in the signature: tests
    extend ``output_token_ids`` in place, and a list default in the signature
    would carry that mutation over into every later call.
    """
    if prompt is None:
        prompt = [1, 2, 3]
    if output is None:
        output = [4, 5, 6]

    return CachedRequestState(
        req_id=req_id,
        prompt_token_ids=prompt,
        mm_inputs=[],
        mm_positions=[],
        sampling_params=SamplingParams(),
        pooling_params=None,
        generator=None,
        block_ids=([], ),
        num_computed_tokens=0,
        output_token_ids=output,
    )
|
||||
|
||||
|
||||
class TestInputBatch(unittest.TestCase):
    """Unit tests for ``InputBatch`` bookkeeping.

    Covers buffer shapes and default attributes, adding and removing
    requests, and condensing the batch after removals.
    """

    def setUp(self):
        """Create a small CPU-only ``InputBatch`` and one request state."""
        self.max_num_reqs = 10
        self.max_model_len = 32
        self.max_num_batched_tokens = 132
        self.vocab_size = 1000
        self.device = torch.device("cpu")
        self.block_sizes = [128]

        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_batched_tokens,
            device=self.device,
            pin_memory=False,
            vocab_size=self.vocab_size,
            block_sizes=self.block_sizes,
        )
        # Pass fresh lists explicitly: test_add_request extends
        # output_token_ids in place, which must not leak into other tests.
        self.cached_request_state = mock_cached_request_state(
            prompt=[1, 2, 3], output=[4, 5, 6])

    def test_shapes_and_defaults(self):
        """Freshly built batch exposes correctly shaped buffers and types."""
        # torch tensor shape assertions
        self.assertEqual(self.input_batch.token_ids_cpu_tensor.shape,
                         (self.max_num_reqs, self.max_model_len))
        self.assertEqual(self.input_batch.temperature.shape,
                         (self.max_num_reqs, ))
        self.assertEqual(self.input_batch.top_k.shape, (self.max_num_reqs, ))
        self.assertEqual(self.input_batch.min_p_cpu_tensor.shape,
                         (self.max_num_reqs, ))

        # numpy shape assertions
        self.assertEqual(self.input_batch.token_ids_cpu.shape,
                         (self.max_num_reqs, self.max_model_len))
        self.assertEqual(self.input_batch.num_tokens.shape,
                         (self.max_num_reqs, ))
        # The original asserted num_tokens twice; check the per-request
        # prompt-length buffer used by test_add_request instead.
        self.assertEqual(self.input_batch.num_prompt_tokens.shape,
                         (self.max_num_reqs, ))

        # type assertions
        self.assertIsInstance(self.input_batch.greedy_reqs, set)
        self.assertIsInstance(self.input_batch.req_id_to_index, dict)
        self.assertIsInstance(self.input_batch.sampling_metadata,
                              SamplingMetadata)
        self.assertIsInstance(self.input_batch.block_table,
                              MultiGroupBlockTable)
        self.assertIsNone(self.input_batch.allowed_token_ids_mask)
        self.assertIsNone(self.input_batch.allowed_token_ids_mask_cpu_tensor)

    def test_add_request(self):
        """Adding, re-adding and invalid adds behave as expected."""
        # case1: add a new req
        self.input_batch.add_request(self.cached_request_state)
        self.assertIn(self.cached_request_state.req_id,
                      self.input_batch.req_id_to_index)
        req_index = self.input_batch.req_id_to_index[
            self.cached_request_state.req_id]
        self.assertEqual(self.input_batch.num_prompt_tokens[req_index],
                         len(self.cached_request_state.prompt_token_ids))
        self.assertEqual(self.input_batch.num_tokens[req_index],
                         self.cached_request_state.num_tokens)

        # case2: add an existing req, maybe need update
        self.cached_request_state.output_token_ids.extend([7, 8, 9])
        self.cached_request_state.num_computed_tokens += 3
        cached_index = self.input_batch.req_id_to_index[
            self.cached_request_state.req_id]
        self.input_batch.add_request(self.cached_request_state, cached_index)
        # check if this index in the input_batch is updated
        # The np array "token_ids_cpu" should be filled with
        # prompt_token_ids + output_token_ids. Compare the actual values: a
        # plain truthiness check would accept any non-zero content.
        expected_token_ids = (
            list(self.cached_request_state.prompt_token_ids) +
            list(self.cached_request_state.output_token_ids))
        np.testing.assert_array_equal(
            self.input_batch.token_ids_cpu[
                cached_index, :len(expected_token_ids)],
            expected_token_ids,
            err_msg=
            f"Token IDs at index {cached_index} did not update correctly.")

        # case3: add req that greater than max_num_reqs
        with self.assertRaises(AssertionError):
            self.input_batch.add_request(self.cached_request_state,
                                         req_index=self.max_num_reqs)

        # case4: add req that out of max_model_len
        long_prompt = list(range(self.max_model_len + 1))
        long_request = mock_cached_request_state(req_id="2",
                                                 prompt=long_prompt,
                                                 output=[10])
        with self.assertRaises(ValueError) as cm:
            self.input_batch.add_request(long_request)
        self.assertIn("could not broadcast", str(cm.exception))

    def test_remove_request(self):
        """Removing a request clears its id mapping and its slot."""
        self.input_batch.add_request(self.cached_request_state)
        req_index = self.input_batch.remove_request(
            self.cached_request_state.req_id)
        self.assertIsNotNone(req_index)
        self.assertNotIn(self.cached_request_state.req_id,
                         self.input_batch.req_id_to_index)
        self.assertIsNone(self.input_batch._req_ids[req_index])

    def test_condense(self):
        """Condensing after removals leaves no holes before ``num_reqs``."""
        # Let's say we have some requests like below
        # Index    Req ID
        #   0        1
        #   1        2
        #   2        3
        #   3        4
        for i in range(4):
            request = mock_cached_request_state(req_id=str(i + 1))
            self.input_batch.add_request(request)
        removed_req_indices = []
        id_to_remove = ["2", "4"]  # IDs to remove
        for req_id in id_to_remove:
            removed_index = self.input_batch.remove_request(req_id)
            if removed_index is not None:
                removed_req_indices.append(removed_index)
        self.assertEqual(len(removed_req_indices), len(id_to_remove))
        self.input_batch.condense(sorted(removed_req_indices, reverse=True))

        # Check if the remaining requests are condensed correctly
        indices = [
            self.input_batch.req_id_to_index[req_id] for req_id in ["1", "3"]
        ]
        self.assertTrue(all(idx < self.input_batch.num_reqs
                            for idx in indices))

        for i in range(self.input_batch.num_reqs):
            self.assertIsNotNone(self.input_batch._req_ids[i])
        for i in range(self.input_batch.num_reqs,
                       len(self.input_batch._req_ids)):
            self.assertIsNone(self.input_batch._req_ids[i])

        for req_id in ["1", "3"]:
            idx = self.input_batch.req_id_to_index[req_id]
            tokens = self.input_batch.token_ids_cpu[idx]
            self.assertTrue(
                tokens.any(),
                f"Tokens at index {idx} for req {req_id} should not be all zero"
            )
|
||||
@@ -23,10 +23,13 @@ import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Callable, Optional
|
||||
|
||||
import openai
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing_extensions import ParamSpec
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
@@ -197,3 +200,37 @@ def fork_new_process_for_each_test(
|
||||
f" args {args} and kwargs {kwargs}")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
    """Truncate embeddings to ``dimensions`` components and L2-normalize.

    Args:
        tensor: Embeddings as a tensor or anything ``torch.as_tensor``
            accepts. The last axis is the embedding axis; a single 1-D
            vector and batched inputs are both supported.
        dimensions: Number of leading components to keep on the last axis.

    Returns:
        A new tensor whose last axis has at most ``dimensions`` entries and
        unit L2 norm along that axis.
    """
    # as_tensor avoids the copy-construct warning that torch.tensor() emits
    # when handed an existing tensor.
    truncated = torch.as_tensor(tensor)[..., :dimensions]
    # Normalize over the same (last) axis that was truncated, so 1-D inputs
    # work as well as 2-D batches.
    return F.normalize(truncated, p=2, dim=-1)
|
||||
|
||||
|
||||
def check_embeddings_close(
    *,
    embeddings_0_lst: Sequence[list[float]],
    embeddings_1_lst: Sequence[list[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-3,
) -> None:
    """Assert that two batches of embeddings match pairwise.

    Args:
        embeddings_0_lst: First batch, one embedding per prompt.
        embeddings_1_lst: Second batch, aligned with the first.
        name_0: Label for the first batch in failure messages.
        name_1: Label for the second batch in failure messages.
        tol: Maximum allowed shortfall of the cosine similarity from 1.

    Raises:
        AssertionError: If the batch sizes differ, a pair of embeddings has
            different lengths, or a pair's cosine similarity is below
            ``1 - tol``.
    """
    assert len(embeddings_0_lst) == len(embeddings_1_lst), (
        f"Batch size mismatch: {len(embeddings_0_lst)} vs. "
        f"{len(embeddings_1_lst)}")

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1), (
            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")

        # Bring both sides to float32 on the CPU so that lists, arrays and
        # tensors of differing dtype or device can be compared. as_tensor
        # also avoids the copy-construct warning of torch.tensor(tensor).
        vec_0 = torch.as_tensor(embeddings_0,
                                dtype=torch.float32,
                                device="cpu")
        vec_1 = torch.as_tensor(embeddings_1,
                                dtype=torch.float32,
                                device="cpu")
        sim = F.cosine_similarity(vec_0, vec_1, dim=0).item()

        fail_msg = (f"Test{prompt_idx}:"
                    f"\nCosine similarity: \t{sim:.4f}"
                    f"\n{name_0}:\t{embeddings_0[:16]!r}"
                    f"\n{name_1}:\t{embeddings_1[:16]!r}")

        assert sim >= 1 - tol, fail_msg
|
||||
|
||||
Reference in New Issue
Block a user