[V1][ModelRunner] Support pooling model for v1 engine (#1359)

### What this PR does / why we need it?
Change as little existing code as possible to add v1 pooling task's
support, notice that i move down the `vllm.v1.worker.gpu_input_batch` to
vllm-ascend, Considering the frequent changes in upstream interfaces, in
order to decouple, so i move it here
### How was this patch tested?
CI passed with new added/existing test, and I have a simple test was
first conducted locally which is adapted from
https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B, just like
bellow:
```python
import os

import torch
from vllm import LLM


os.environ["VLLM_USE_MODELSCOPE"]="True"

def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'

# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
```
---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <858794774@qq.com>
Co-authored-by: wangli <858794774@qq.com>
This commit is contained in:
Li Wang
2025-06-30 16:31:12 +08:00
committed by GitHub
parent 790c810bf7
commit 5f8241c25c
10 changed files with 1312 additions and 43 deletions

View File

@@ -19,18 +19,23 @@
import contextlib
import gc
from typing import List, Optional, Tuple, TypeVar, Union
from typing import Any, List, Optional, Tuple, TypeVar, Union
import numpy as np
import pytest
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import is_list_of
from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
@@ -45,6 +50,7 @@ adapt_patch(True)
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
@@ -364,3 +370,131 @@ def prompt_template(request):
@pytest.fixture(scope="session")
def ilama_lora_files():
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
class HfRunner:
def get_default_device(self):
from vllm.platforms import current_platform
return ("cpu"
if current_platform.is_cpu() else current_platform.device_type)
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )):
return x
if device is None:
device = self.device
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
if hasattr(x, "device") and x.device.type == device:
return x
return x.to(device)
def __init__(
self,
model_name: str,
dtype: str = "auto",
*,
model_kwargs: Optional[dict[str, Any]] = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
) -> None:
model_name = maybe_model_redirect(model_name)
self.model_name = model_name
self.config = AutoConfig.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder,
)
model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)
if is_sentence_transformer:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device=self.device,
model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code,
)
elif is_cross_encoder:
# Lazy init required for AMD CI
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(
model_name,
device=self.device,
automodel_args=model_kwargs,
trust_remote_code=trust_remote_code,
)
else:
model = auto_cls.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
**model_kwargs,
)
# in case some unquantized custom models are not in same dtype
if (getattr(model, "quantization_method", None) is None
and any(p.dtype != self.dtype
for p in model.parameters())):
model = model.to(dtype=self.dtype)
if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
model = model.to(device=self.device)
self.model = model
if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer
def encode(self, prompts: list[str], *args,
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup_dist_env_and_memory()
@pytest.fixture(scope="session")
def hf_runner():
return HfRunner