Sync from v0.13
This commit is contained in:
53
examples/pooling/classify/openai_classification_client.py
Normal file
53
examples/pooling/classify/openai_classification_client.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for classification API using vLLM API server
|
||||
NOTE:
|
||||
start a supported classification model server with `vllm serve`, e.g.
|
||||
vllm serve jason9693/Qwen2.5-1.5B-apeach
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def post_http_request(payload: dict, api_url: str) -> requests.Response:
|
||||
headers = {"User-Agent": "Test Client"}
|
||||
response = requests.post(api_url, headers=headers, json=payload)
|
||||
return response
|
||||
|
||||
|
||||
def parse_args():
|
||||
parse = argparse.ArgumentParser()
|
||||
parse.add_argument("--host", type=str, default="localhost")
|
||||
parse.add_argument("--port", type=int, default=8000)
|
||||
parse.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
|
||||
return parse.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
host = args.host
|
||||
port = args.port
|
||||
model_name = args.model
|
||||
|
||||
api_url = f"http://{host}:{port}/classify"
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"input": prompts,
|
||||
}
|
||||
|
||||
classify_response = post_http_request(payload=payload, api_url=api_url)
|
||||
pprint.pprint(classify_response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
60
examples/pooling/embed/embed_jina_embeddings_v3.py
Normal file
60
examples/pooling/embed/embed_jina_embeddings_v3.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="jinaai/jina-embeddings-v3",
|
||||
runner="pooling",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Follow the white rabbit.", # English
|
||||
"Sigue al conejo blanco.", # Spanish
|
||||
"Suis le lapin blanc.", # French
|
||||
"跟着白兔走。", # Chinese
|
||||
"اتبع الأرنب الأبيض.", # Arabic
|
||||
"Folge dem weißen Kaninchen.", # German
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
# You should pass runner="pooling" for embedding models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
|
||||
# Only text matching task is supported for now. See #16120
|
||||
outputs = llm.embed(prompts)
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:")
|
||||
print("Only text matching task is supported for now. See #16120")
|
||||
print("-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
embeds_trimmed = (
|
||||
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
|
||||
)
|
||||
print(
|
||||
f"Prompt: {prompt!r} \n"
|
||||
f"Embeddings for text matching: {embeds_trimmed} "
|
||||
f"(size={len(embeds)})"
|
||||
)
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
54
examples/pooling/embed/embed_matryoshka_fy.py
Normal file
54
examples/pooling/embed/embed_matryoshka_fy.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs, PoolingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="jinaai/jina-embeddings-v3",
|
||||
runner="pooling",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Follow the white rabbit.", # English
|
||||
"Sigue al conejo blanco.", # Spanish
|
||||
"Suis le lapin blanc.", # French
|
||||
"跟着白兔走。", # Chinese
|
||||
"اتبع الأرنب الأبيض.", # Arabic
|
||||
"Folge dem weißen Kaninchen.", # German
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
# You should pass runner="pooling" for embedding models
|
||||
llm = LLM(**vars(args))
|
||||
|
||||
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
|
||||
outputs = llm.embed(prompts, pooling_params=PoolingParams(dimensions=32))
|
||||
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:")
|
||||
print("-" * 60)
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
embeds = output.outputs.embedding
|
||||
embeds_trimmed = (
|
||||
(str(embeds[:16])[:-1] + ", ...]") if len(embeds) > 16 else embeds
|
||||
)
|
||||
print(f"Prompt: {prompt!r} \nEmbeddings: {embeds_trimmed} (size={len(embeds)})")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
64
examples/pooling/embed/embedding_requests_base64_client.py
Normal file
64
examples/pooling/embed/embedding_requests_base64_client.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for embedding API using vLLM API server
|
||||
NOTE:
|
||||
start a supported embeddings model server with `vllm serve`, e.g.
|
||||
vllm serve intfloat/e5-small
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from vllm.utils.serial_utils import (
|
||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||
ENDIANNESS,
|
||||
binary2tensor,
|
||||
)
|
||||
|
||||
|
||||
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
|
||||
headers = {"User-Agent": "Test Client"}
|
||||
response = requests.post(api_url, headers=headers, json=prompt)
|
||||
return response
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--model", type=str, default="intfloat/e5-small")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
api_url = f"http://{args.host}:{args.port}/v1/embeddings"
|
||||
model_name = args.model
|
||||
|
||||
# The OpenAI client does not support the embed_dtype and endianness parameters.
|
||||
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
|
||||
for endianness in ENDIANNESS:
|
||||
prompt = {
|
||||
"model": model_name,
|
||||
"input": "vLLM is great!",
|
||||
"encoding_format": "base64",
|
||||
"embed_dtype": embed_dtype,
|
||||
"endianness": endianness,
|
||||
}
|
||||
response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
|
||||
embedding = []
|
||||
for data in response.json()["data"]:
|
||||
binary = base64.b64decode(data["embedding"])
|
||||
tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
|
||||
embedding.append(tensor.to(torch.float32))
|
||||
embedding = torch.cat(embedding)
|
||||
print(embed_dtype, endianness, embedding.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
99
examples/pooling/embed/embedding_requests_bytes_client.py
Normal file
99
examples/pooling/embed/embedding_requests_bytes_client.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for embedding API using vLLM API server
|
||||
NOTE:
|
||||
start a supported embeddings model server with `vllm serve`, e.g.
|
||||
vllm serve intfloat/e5-small
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from vllm.utils.serial_utils import (
|
||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||
ENDIANNESS,
|
||||
MetadataItem,
|
||||
build_metadata_items,
|
||||
decode_pooling_output,
|
||||
)
|
||||
|
||||
|
||||
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
|
||||
headers = {"User-Agent": "Test Client"}
|
||||
response = requests.post(api_url, headers=headers, json=prompt)
|
||||
return response
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--model", type=str, default="intfloat/e5-small")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
api_url = f"http://{args.host}:{args.port}/v1/embeddings"
|
||||
model_name = args.model
|
||||
embedding_size = 0
|
||||
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
] * 2
|
||||
|
||||
# The OpenAI client does not support the bytes encoding_format.
|
||||
# The OpenAI client does not support the embed_dtype and endianness parameters.
|
||||
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
|
||||
for endianness in ENDIANNESS:
|
||||
prompt = {
|
||||
"model": model_name,
|
||||
"input": input_texts,
|
||||
"encoding_format": "bytes",
|
||||
"embed_dtype": embed_dtype,
|
||||
"endianness": endianness,
|
||||
}
|
||||
response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
metadata = json.loads(response.headers["metadata"])
|
||||
body = response.content
|
||||
items = [MetadataItem(**x) for x in metadata["data"]]
|
||||
|
||||
embedding = decode_pooling_output(items=items, body=body)
|
||||
embedding = [x.to(torch.float32) for x in embedding]
|
||||
embedding = torch.stack(embedding)
|
||||
embedding_size = embedding.shape[-1]
|
||||
print(embed_dtype, endianness, embedding.shape)
|
||||
|
||||
# The vllm server always sorts the returned embeddings in the order of input. So
|
||||
# returning metadata is not necessary. You can set encoding_format to bytes_only
|
||||
# to let the server not return metadata.
|
||||
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
|
||||
for endianness in ENDIANNESS:
|
||||
prompt = {
|
||||
"model": model_name,
|
||||
"input": input_texts,
|
||||
"encoding_format": "bytes_only",
|
||||
"embed_dtype": embed_dtype,
|
||||
"endianness": endianness,
|
||||
}
|
||||
response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
body = response.content
|
||||
|
||||
items = build_metadata_items(
|
||||
embed_dtype=embed_dtype,
|
||||
endianness=endianness,
|
||||
shape=(embedding_size,),
|
||||
n_request=len(input_texts),
|
||||
)
|
||||
embedding = decode_pooling_output(items=items, body=body)
|
||||
embedding = [x.to(torch.float32) for x in embedding]
|
||||
embedding = torch.stack(embedding)
|
||||
print(embed_dtype, endianness, embedding.shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -0,0 +1,293 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
"""Example Python client for multimodal embedding API using vLLM API server.
|
||||
|
||||
Refer to each `run_*` function for the command to run the server for that model.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import io
|
||||
from typing import Literal
|
||||
|
||||
from openai import OpenAI
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.create_embedding_response import CreateEmbeddingResponse
|
||||
from PIL import Image
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
image_url = "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
|
||||
|
||||
def create_chat_embeddings(
|
||||
client: OpenAI,
|
||||
*,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: str,
|
||||
encoding_format: Literal["base64", "float"] | NotGiven = NOT_GIVEN,
|
||||
) -> CreateEmbeddingResponse:
|
||||
"""
|
||||
Convenience function for accessing vLLM's Chat Embeddings API,
|
||||
which is an extension of OpenAI's existing Embeddings API.
|
||||
"""
|
||||
return client.post(
|
||||
"/embeddings",
|
||||
cast_to=CreateEmbeddingResponse,
|
||||
body={"messages": messages, "model": model, "encoding_format": encoding_format},
|
||||
)
|
||||
|
||||
|
||||
def run_clip(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve openai/clip-vit-base-patch32 \
|
||||
--runner pooling
|
||||
"""
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "a photo of a cat"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_dse_qwen2_vl(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve MrLight/dse-qwen2-2b-mrl-v1 \
|
||||
--runner pooling \
|
||||
--trust-remote-code \
|
||||
--max-model-len 8192 \
|
||||
--chat-template examples/template_dse_qwen2_vl.jinja
|
||||
"""
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url,
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
# MrLight/dse-qwen2-2b-mrl-v1 requires a placeholder image
|
||||
# of the minimum input size
|
||||
buffer = io.BytesIO()
|
||||
image_placeholder = Image.new("RGB", (56, 56))
|
||||
image_placeholder.save(buffer, "png")
|
||||
buffer.seek(0)
|
||||
image_placeholder = base64.b64encode(buffer.read()).decode("utf-8")
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_placeholder}",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Query: What is the weather like today?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_siglip(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve google/siglip-base-patch16-224 \
|
||||
--runner pooling \
|
||||
--chat-template template_basic.jinja
|
||||
"""
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "a photo of a cat"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
def run_vlm2vec(client: OpenAI, model: str):
|
||||
"""
|
||||
Start the server using:
|
||||
|
||||
vllm serve TIGER-Lab/VLM2Vec-Full \
|
||||
--runner pooling \
|
||||
--trust-remote-code \
|
||||
--max-model-len 4096 \
|
||||
--chat-template examples/template_vlm2vec_phi3v.jinja
|
||||
"""
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{"type": "text", "text": "Represent the given image."},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image embedding output:", response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": image_url}},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Represent the given image with the following question: What is in the image.",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Image+Text embedding output:", response.data[0].embedding)
|
||||
|
||||
response = create_chat_embeddings(
|
||||
client,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "A cat and a dog"},
|
||||
],
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
print("Text embedding output:", response.data[0].embedding)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"clip": run_clip,
|
||||
"dse_qwen2_vl": run_dse_qwen2_vl,
|
||||
"siglip": run_siglip,
|
||||
"vlm2vec": run_vlm2vec,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
"Script to call a specified VLM through the API. Make sure to serve "
|
||||
"the model with `--runner pooling` before running this."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
choices=model_example_map.keys(),
|
||||
required=True,
|
||||
help="The name of the embedding model.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model_id = models.data[0].id
|
||||
|
||||
model_example_map[args.model](client, model_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
40
examples/pooling/embed/openai_embedding_client.py
Normal file
40
examples/pooling/embed/openai_embedding_client.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for embedding API using vLLM API server
|
||||
NOTE:
|
||||
start a supported embeddings model server with `vllm serve`, e.g.
|
||||
vllm serve intfloat/e5-small
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
|
||||
def main():
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
responses = client.embeddings.create(
|
||||
# ruff: noqa: E501
|
||||
input=[
|
||||
"Hello my name is",
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
],
|
||||
model=model,
|
||||
)
|
||||
|
||||
for data in responses.data:
|
||||
print(data.embedding) # List of float of len 4096
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
186
examples/pooling/embed/openai_embedding_long_text/README.md
Normal file
186
examples/pooling/embed/openai_embedding_long_text/README.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Long Text Embedding with Chunked Processing
|
||||
|
||||
This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length.
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Start the Server
|
||||
|
||||
Use the provided script to start a vLLM server with chunked processing enabled:
|
||||
|
||||
```bash
|
||||
# Basic usage (supports very long texts up to ~3M tokens)
|
||||
./service.sh
|
||||
|
||||
# Custom configuration with different models
|
||||
MODEL_NAME="jinaai/jina-embeddings-v3" \
|
||||
MAX_EMBED_LEN=1048576 \
|
||||
./service.sh
|
||||
|
||||
# For extremely long documents
|
||||
MODEL_NAME="intfloat/multilingual-e5-large" \
|
||||
MAX_EMBED_LEN=3072000 \
|
||||
./service.sh
|
||||
```
|
||||
|
||||
### Test Long Text Embedding
|
||||
|
||||
Run the comprehensive test client:
|
||||
|
||||
```bash
|
||||
python client.py
|
||||
```
|
||||
|
||||
## 📁 Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `service.sh` | Server startup script with chunked processing enabled |
|
||||
| `client.py` | Comprehensive test client for long text embedding |
|
||||
|
||||
## ⚙️ Configuration
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The key parameters for chunked processing are in the `--pooler-config`:
|
||||
|
||||
```json
|
||||
{
|
||||
"pooling_type": "auto",
|
||||
"normalize": true,
|
||||
"enable_chunked_processing": true,
|
||||
"max_embed_len": 3072000
|
||||
}
|
||||
```
|
||||
|
||||
!!! note
|
||||
`pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length.
|
||||
|
||||
#### Chunked Processing Behavior
|
||||
|
||||
Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length:
|
||||
|
||||
| Component | Behavior | Description |
|
||||
|-----------|----------|-------------|
|
||||
| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy |
|
||||
| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts |
|
||||
| **Performance** | Optimal | All chunks processed for complete semantic coverage |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
|
||||
| `PORT` | `31090` | Server port |
|
||||
| `GPU_COUNT` | `1` | Number of GPUs to use |
|
||||
| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) |
|
||||
| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) |
|
||||
| `API_KEY` | `EMPTY` | API key for authentication |
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables
|
||||
2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity
|
||||
3. **Unified Processing**: All chunks processed separately through the model using its configured pooling strategy
|
||||
4. **MEAN Aggregation**: When input exceeds model's native length, results combined using token count-based weighted averaging across all chunks
|
||||
5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing
|
||||
|
||||
### Input Length Handling
|
||||
|
||||
- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens)
|
||||
- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered
|
||||
- **Exceeds max_embed_len**: Input is rejected with clear error message
|
||||
- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN`
|
||||
|
||||
### Extreme Long Text Support
|
||||
|
||||
With `MAX_EMBED_LEN=3072000`, you can process:
|
||||
|
||||
- **Academic papers**: Full research papers with references
|
||||
- **Legal documents**: Complete contracts and legal texts
|
||||
- **Books**: Entire chapters or small books
|
||||
- **Code repositories**: Large codebases and documentation
|
||||
|
||||
## 📊 Performance Characteristics
|
||||
|
||||
### Chunked Processing Performance
|
||||
|
||||
| Aspect | Behavior | Performance |
|
||||
|--------|----------|-------------|
|
||||
| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length |
|
||||
| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead |
|
||||
| **Memory Usage** | Proportional to number of chunks | Moderate, scalable |
|
||||
| **Semantic Quality** | Complete text coverage | Optimal for long documents |
|
||||
|
||||
## 🧪 Test Cases
|
||||
|
||||
The test client demonstrates:
|
||||
|
||||
- ✅ **Short text**: Normal processing (baseline)
|
||||
- ✅ **Medium text**: Single chunk processing
|
||||
- ✅ **Long text**: Multi-chunk processing with aggregation
|
||||
- ✅ **Very long text**: Many chunks processing
|
||||
- ✅ **Extreme long text**: Document-level processing (100K+ tokens)
|
||||
- ✅ **Batch processing**: Mixed-length inputs in one request
|
||||
- ✅ **Consistency**: Reproducible results across runs
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Chunked processing not enabled**:
|
||||
|
||||
```log
|
||||
ValueError: This model's maximum position embeddings length is 4096 tokens...
|
||||
```
|
||||
|
||||
**Solution**: Ensure `enable_chunked_processing: true` in pooler config
|
||||
|
||||
2. **Input exceeds max_embed_len**:
|
||||
|
||||
```log
|
||||
ValueError: This model's maximum embedding input length is 3072000 tokens...
|
||||
```
|
||||
|
||||
**Solution**: Increase `max_embed_len` in pooler config or reduce input length
|
||||
|
||||
3. **Memory errors**:
|
||||
|
||||
```log
|
||||
RuntimeError: CUDA out of memory
|
||||
```
|
||||
|
||||
**Solution**: Reduce chunk size by adjusting model's `max_position_embeddings` or use fewer GPUs
|
||||
|
||||
4. **Slow processing**:
|
||||
**Expected**: Long text takes more time due to multiple inference calls
|
||||
|
||||
### Debug Information
|
||||
|
||||
Server logs show chunked processing activity:
|
||||
|
||||
```log
|
||||
INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing
|
||||
INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096)
|
||||
```
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
To extend chunked processing support to other embedding models:
|
||||
|
||||
1. Check model compatibility with the pooling architecture
|
||||
2. Test with various text lengths
|
||||
3. Validate embedding quality compared to single-chunk processing
|
||||
4. Submit PR with test cases and documentation updates
|
||||
|
||||
## 🆕 Enhanced Features
|
||||
|
||||
### max_embed_len Parameter
|
||||
|
||||
The new `max_embed_len` parameter provides:
|
||||
|
||||
- **Simplified Configuration**: No need for `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable
|
||||
- **Flexible Input Validation**: Accept inputs longer than `max_model_len` up to `max_embed_len`
|
||||
- **Extreme Length Support**: Process documents with millions of tokens
|
||||
- **Clear Error Messages**: Better feedback when inputs exceed limits
|
||||
- **Backward Compatibility**: Existing configurations continue to work
|
||||
366
examples/pooling/embed/openai_embedding_long_text/client.py
Normal file
366
examples/pooling/embed/openai_embedding_long_text/client.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Example script demonstrating long text embedding with chunked processing in vLLM.
|
||||
|
||||
This example shows how to use vLLM's chunked processing feature to handle text
|
||||
inputs that exceed the model's maximum token length. The feature automatically
|
||||
splits long text into chunks and handles different pooling types optimally.
|
||||
|
||||
Prerequisites:
|
||||
1. Start vLLM server with chunked processing enabled:
|
||||
|
||||
# MEAN pooling (processes all chunks, recommended for complete coverage)
|
||||
vllm serve intfloat/multilingual-e5-large \
|
||||
--pooler-config \
|
||||
'{"pooling_type": "MEAN", "normalize": true, ' \
|
||||
'"enable_chunked_processing": true, "max_embed_len": 3072000}' \
|
||||
--served-model-name multilingual-e5-large \
|
||||
--trust-remote-code \
|
||||
--port 31090 \
|
||||
--api-key your-api-key
|
||||
|
||||
# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
|
||||
vllm serve BAAI/bge-large-en-v1.5 \
|
||||
--pooler-config \
|
||||
'{"pooling_type": "CLS", "normalize": true, ' \
|
||||
'"enable_chunked_processing": true, "max_embed_len": 1048576}' \
|
||||
--served-model-name bge-large-en-v1.5 \
|
||||
--trust-remote-code \
|
||||
--port 31090 \
|
||||
--api-key your-api-key
|
||||
|
||||
2. Install required dependencies:
|
||||
pip install openai requests
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
|
||||
# Configuration
|
||||
API_KEY = "your-api-key" # Replace with your actual API key
|
||||
BASE_URL = "http://localhost:31090/v1"
|
||||
MODEL_NAME = "multilingual-e5-large"
|
||||
|
||||
|
||||
def generate_long_text(base_text: str, repeat_count: int) -> str:
|
||||
"""Generate long text by repeating base text."""
|
||||
return base_text * repeat_count
|
||||
|
||||
|
||||
def test_embedding_with_different_lengths():
    """Exercise the embeddings endpoint with inputs of increasing size."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    # (label, input text, roughly how many chunks the server should create)
    cases = [
        (
            "Short Text",
            "Hello, this is a short text for embedding.",
            1,
        ),
        (
            "Medium Text",
            generate_long_text(
                "This is a medium-length text that should fit within the "
                "model's context window. " * 20,
                2,
            ),
            1,
        ),
        (
            "Long Text (2 chunks)",
            generate_long_text(
                "This is a very long text that will exceed the model's "
                "maximum context length and trigger chunked processing. " * 50,
                5,
            ),
            2,
        ),
        (
            "Very Long Text (3+ chunks)",
            generate_long_text(
                "This text is extremely long and will definitely "
                "require multiple chunks for processing. " * 100,
                10,
            ),
            3,
        ),
    ]

    print("🧪 Testing vLLM Long Text Embedding with Chunked Processing")
    print("=" * 70)

    for idx, (label, text, expected_chunks) in enumerate(cases, 1):
        print(f"\n📝 Test {idx}: {label}")
        print(f"Text length: {len(text)} characters")

        try:
            started = time.time()

            response = client.embeddings.create(
                input=text, model=MODEL_NAME, encoding_format="float"
            )

            elapsed = time.time() - started

            # A single embedding vector comes back regardless of chunk count;
            # chunk aggregation happens server-side.
            vector = response.data[0].embedding

            print("✅ Success!")
            print(f" - Embedding dimension: {len(vector)}")
            print(f" - Processing time: {elapsed:.2f}s")
            print(f" - Expected chunks: ~{expected_chunks}")
            print(f" - First 5 values: {vector[:5]}")

        except Exception as e:
            print(f"❌ Failed: {str(e)}")
|
||||
|
||||
|
||||
def test_batch_embedding():
    """Send one embeddings request containing both short and long inputs."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print("\n🔄 Testing Batch Embedding with Mixed Lengths")
    print("=" * 50)

    # Short inputs interleaved with inputs long enough to be chunked.
    inputs = [
        "Short text 1",
        generate_long_text("Medium length text that fits in one chunk. " * 20, 1),
        "Another short text",
        generate_long_text("Long text requiring chunked processing. " * 100, 5),
    ]

    try:
        started = time.time()

        response = client.embeddings.create(
            input=inputs, model=MODEL_NAME, encoding_format="float"
        )

        elapsed = time.time() - started

        print("✅ Batch processing successful!")
        print(f" - Number of inputs: {len(inputs)}")
        print(f" - Number of embeddings: {len(response.data)}")
        print(f" - Total processing time: {elapsed:.2f}s")
        print(
            f" - Average time per input: {elapsed / len(inputs):.2f}s"
        )

        for i, item in enumerate(response.data):
            print(
                f" - Input {i + 1}: {len(inputs[i])} chars → {len(item.embedding)}D embedding"
            )

    except Exception as e:
        print(f"❌ Batch processing failed: {str(e)}")
|
||||
|
||||
|
||||
def test_multiple_long_texts_batch():
    """Test batch processing with multiple long texts to verify chunk ID uniqueness.

    Sends several distinct long documents (each requiring chunking) in a single
    batch, then checks that the returned embeddings differ from each other —
    i.e. that chunks of different prompts were not aggregated together.
    """
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print("\n🔧 Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)")
    print("=" * 70)

    # Create multiple distinct long texts that will all require chunking
    # Note: All pooling types now use MEAN aggregation across chunks:
    # - Native pooling (MEAN/CLS/LAST) is used within each chunk
    # - MEAN aggregation combines results across all chunks
    # - Full semantic coverage for all pooling types
    long_texts = [
        generate_long_text(
            "First long document about artificial intelligence and machine learning. "
            * 80,
            6,
        ),
        generate_long_text(
            "Second long document about natural language processing and transformers. "
            * 80,
            6,
        ),
        generate_long_text(
            "Third long document about computer vision and neural networks. " * 80, 6
        ),
    ]

    # Add some short texts to mix things up
    batch_inputs = [
        "Short text before long texts",
        long_texts[0],
        "Short text between long texts",
        long_texts[1],
        long_texts[2],
        "Short text after long texts",
    ]

    print("📊 Batch composition:")
    for i, text in enumerate(batch_inputs):
        length = len(text)
        # 5000 chars is a rough heuristic for "will exceed the context window".
        text_type = "Long (will be chunked)" if length > 5000 else "Short"
        print(f" - Input {i + 1}: {length} chars ({text_type})")

    try:
        start_time = time.time()

        response = client.embeddings.create(
            input=batch_inputs, model=MODEL_NAME, encoding_format="float"
        )

        end_time = time.time()
        processing_time = end_time - start_time

        print("\n✅ Multiple long texts batch processing successful!")
        print(f" - Number of inputs: {len(batch_inputs)}")
        print(f" - Number of embeddings returned: {len(response.data)}")
        print(f" - Total processing time: {processing_time:.2f}s")

        # Verify each embedding is different (no incorrect aggregation)
        embeddings = [data.embedding for data in response.data]

        if len(embeddings) >= 3:
            # numpy is imported at module level; the redundant function-local
            # `import numpy as np` that used to live here has been removed.

            # Compare embeddings of the long texts (indices 1, 3, 4)
            long_embeddings = [
                np.array(embeddings[1]),  # First long text
                np.array(embeddings[3]),  # Second long text
                np.array(embeddings[4]),  # Third long text
            ]

            print("\n🔍 Verifying embedding uniqueness:")
            for i in range(len(long_embeddings)):
                for j in range(i + 1, len(long_embeddings)):
                    cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / (
                        np.linalg.norm(long_embeddings[i])
                        * np.linalg.norm(long_embeddings[j])
                    )
                    print(
                        f" - Similarity between long text {i + 1} and {j + 1}: "
                        f"{cosine_sim:.4f}"
                    )

                    if (
                        cosine_sim < 0.9
                    ):  # Different content should have lower similarity
                        print(" ✅ Good: Embeddings are appropriately different")
                    else:
                        print(
                            " ⚠️ High similarity - may indicate chunk "
                            "aggregation issue"
                        )

        print("\n📋 Per-input results:")
        for i, data in enumerate(response.data):
            input_length = len(batch_inputs[i])
            embedding_dim = len(data.embedding)
            embedding_norm = np.linalg.norm(data.embedding)
            print(
                f" - Input {i + 1}: {input_length} chars → {embedding_dim}D "
                f"embedding (norm: {embedding_norm:.4f})"
            )

        print(
            "\n✅ This test verifies the fix for chunk ID collisions in "
            "batch processing"
        )
        print(" - Before fix: Multiple long texts would have conflicting chunk IDs")
        print(" - After fix: Each prompt's chunks have unique IDs with prompt index")

    except Exception as e:
        print(f"❌ Multiple long texts batch test failed: {str(e)}")
        print(" This might indicate the chunk ID collision bug is present!")
|
||||
|
||||
|
||||
def test_embedding_consistency():
    """Verify that repeated requests for the same long input agree."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print("\n🔍 Testing Embedding Consistency")
    print("=" * 40)

    # The same long input is embedded several times; chunked processing
    # should be deterministic, so the resulting vectors should match.
    long_text = generate_long_text(
        "Consistency test text for chunked processing validation. " * 50, 3
    )

    runs = []

    try:
        for run_idx in range(3):
            result = client.embeddings.create(
                input=long_text, model=MODEL_NAME, encoding_format="float"
            )
            runs.append(result.data[0].embedding)
            print(f" - Generated embedding {run_idx + 1}")

        if len(runs) >= 2:
            first = np.array(runs[0])
            second = np.array(runs[1])

            # Cosine similarity between the first two runs.
            cosine_sim = np.dot(first, second) / (
                np.linalg.norm(first) * np.linalg.norm(second)
            )

            print("✅ Consistency test completed!")
            print(f" - Cosine similarity between runs: {cosine_sim:.6f}")
            print(" - Expected: ~1.0 (identical embeddings)")

            if cosine_sim > 0.999:
                print(" - ✅ High consistency achieved!")
            else:
                print(" - ⚠️ Consistency may vary due to numerical precision")

    except Exception as e:
        print(f"❌ Consistency test failed: {str(e)}")
|
||||
|
||||
|
||||
def main():
    """Run every demo scenario against the configured server."""
    print("🚀 vLLM Long Text Embedding Client")
    print(f"📡 Connecting to: {BASE_URL}")
    print(f"🤖 Model: {MODEL_NAME}")

    # Show only the last four characters of the key.
    if len(API_KEY) > 4:
        masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:]
    else:
        masked_key = "****"
    print(f"🔑 API Key: {masked_key}")

    # Run all test cases
    for scenario in (
        test_embedding_with_different_lengths,
        test_batch_embedding,
        test_multiple_long_texts_batch,
        test_embedding_consistency,
    ):
        scenario()

    print("\n" + "=" * 70)
    print("🎉 All tests completed!")
    print("\n💡 Key Features Demonstrated:")
    print(" - ✅ Automatic chunked processing for long text")
    print(" - ✅ Seamless handling of mixed-length batches")
    print(" - ✅ Multiple long texts in single batch (chunk ID fix)")
    print(" - ✅ Unified chunked processing:")
    print(" • Native pooling used within each chunk")
    print(" • MEAN aggregation across all chunks")
    print(" • Complete semantic coverage for all pooling types")
    print(" - ✅ Consistent embedding generation")
    print(" - ✅ Backward compatibility with short text")
    print("\n📚 For more information, see:")
    print(
        " - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html"
    )
    print(" - Chunked Processing Guide: openai_embedding_long_text.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
136
examples/pooling/embed/openai_embedding_long_text/service.sh
Normal file
136
examples/pooling/embed/openai_embedding_long_text/service.sh
Normal file
@@ -0,0 +1,136 @@
|
||||
#!/bin/bash

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# vLLM Embedding Server with Enhanced Chunked Processing
# This script starts a vLLM server with chunked processing enabled for long text embedding.
# Now supports proper pooling type validation and model-specific configurations.

set -euo pipefail

# Configuration — every variable below can be overridden from the caller's
# environment (VAR=value ./service.sh).
MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"}
MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"}

PORT=${PORT:-31090}
GPU_COUNT=${GPU_COUNT:-1}
MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000}
API_KEY=${API_KEY:-"your-api-key"}

# Enhanced pooling configuration with model-specific defaults
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
# NOTE(review): GPUs 2-5 are hard-coded here while GPU_COUNT defaults to 1 —
# adjust CUDA_VISIBLE_DEVICES to match your machine before running.
export CUDA_VISIBLE_DEVICES=2,3,4,5

echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="

# Environment variables for optimization
export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
|
||||
# Function to determine optimal pooling type for known models
|
||||
# Echo the native pooling strategy best suited to a given model name.
get_optimal_pooling_type() {
    local model_name="$1"
    if [[ "$model_name" == *"e5-"* || "$model_name" == *"multilingual-e5"* ]]; then
        echo "MEAN" # E5 series native pooling
    elif [[ "$model_name" == *"bge-"* ]]; then
        echo "CLS" # BGE series native pooling
    elif [[ "$model_name" == *"gte-"* ]]; then
        echo "LAST" # GTE series native pooling
    elif [[ "$model_name" == *"sentence-t5"* || "$model_name" == *"st5"* ]]; then
        echo "MEAN" # Sentence-T5 native pooling
    elif [[ "$model_name" == *"jina-embeddings"* ]]; then
        echo "MEAN" # Jina embeddings native pooling
    elif [[ "$model_name" == *"Qwen"*"Embedding"* ]]; then
        echo "LAST" # Qwen embeddings native pooling
    else
        echo "MEAN" # Default native pooling for unknown models
    fi
}
|
||||
|
||||
# Auto-detect pooling type if not explicitly set
if [ "$POOLING_TYPE" = "auto" ]; then
    POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME")
    echo "🔍 Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME"
fi

# Display configuration
echo "📋 Configuration:"
echo " - Model: $MODEL_NAME"
echo " - Port: $PORT"
echo " - GPU Count: $GPU_COUNT"
echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}"
echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens"
echo " - Native Pooling Type: $POOLING_TYPE + Normalization"
echo " - Cross-chunk Aggregation: MEAN (automatic)"
echo ""

# Validate GPU availability and clamp GPU_COUNT to what the machine has.
if command -v nvidia-smi &> /dev/null; then
    gpu_count=$(nvidia-smi --list-gpus | wc -l)
    echo "🖥️ Available GPUs: $gpu_count"
    if [ "$GPU_COUNT" -gt "$gpu_count" ]; then
        echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available"
        echo " Adjusting to use $gpu_count GPUs"
        GPU_COUNT=$gpu_count
    fi
else
    echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped."
fi

# Chunked processing uses unified MEAN aggregation
echo "ℹ️ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks"
echo " - All chunks processed for complete semantic coverage"
echo " - Weighted averaging based on chunk token counts"

echo ""
echo "🔧 Starting server with enhanced chunked processing configuration..."

# Build pooler config JSON
POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}"

# Start vLLM server with enhanced chunked processing
# NOTE(review): `vllm serve` runs in the foreground, so everything below this
# command only prints after the server exits.
# NOTE(review): ${MODEL_CODE} is intentionally unquoted here — confirm the
# served model name never contains whitespace.
vllm serve "$MODEL_NAME" \
    --tensor-parallel-size "$GPU_COUNT" \
    --enforce-eager \
    --pooler-config "$POOLER_CONFIG" \
    --served-model-name ${MODEL_CODE} \
    --api-key "$API_KEY" \
    --trust-remote-code \
    --port "$PORT" \
    --host 0.0.0.0

echo ""
echo "✅ vLLM Embedding Server started successfully!"
echo ""
echo "📡 Server Information:"
echo " - Base URL: http://localhost:$PORT"
echo " - Model Code: ${MODEL_CODE}"
echo " - API Key: $API_KEY"
echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN"
echo ""
echo "🧪 Test the server with:"
echo " python examples/online_serving/openai_embedding_long_text/client.py"
echo ""
echo "📚 Enhanced features enabled:"
echo " ✅ Intelligent native pooling type detection"
echo " ✅ Unified MEAN aggregation for chunked processing"
echo " ✅ Model-specific native pooling optimization"
echo " ✅ Enhanced max embedding length (${MAX_EMBED_LEN} tokens)"
echo " ✅ Complete semantic coverage for all pooling types"
echo " ✅ OpenAI-compatible API"
echo " ✅ GPU acceleration"
echo ""
echo "🔧 Advanced usage:"
echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection"
echo " - Set MAX_EMBED_LEN to adjust maximum input length"
echo " - All pooling types use MEAN aggregation across chunks"
|
||||
37
examples/pooling/embed/openai_embedding_matryoshka_fy.py
Normal file
37
examples/pooling/embed/openai_embedding_matryoshka_fy.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for embedding API dimensions using vLLM API server
|
||||
NOTE:
|
||||
start a supported Matryoshka Embeddings model server with `vllm serve`, e.g.
|
||||
vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
|
||||
def main():
    """Request a 32-dimensional Matryoshka embedding and print it."""
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    # Use whichever model the server is currently serving.
    served_model = client.models.list().data[0].id

    response = client.embeddings.create(
        input=["Follow the white rabbit."],
        model=served_model,
        dimensions=32,
    )

    for item in response.data:
        # Each embedding is a list of 32 floats (truncated Matryoshka dims).
        print(item.embedding)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
56
examples/pooling/plugin/prithvi_geospatial_mae_client.py
Normal file
56
examples/pooling/plugin/prithvi_geospatial_mae_client.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
# This example shows how to perform an online inference that generates
|
||||
# multimodal data. In this specific case this example will take a geotiff
|
||||
# image as input, process it using the multimodal data processor, and
|
||||
# perform inference.
|
||||
# Requirements :
|
||||
# - install TerraTorch v1.1 (or later):
|
||||
# pip install terratorch>=v1.1
|
||||
# - start vllm in serving mode with the below args
|
||||
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
|
||||
# --model-impl terratorch
|
||||
# --trust-remote-code
|
||||
# --skip-tokenizer-init --enforce-eager
|
||||
# --io-processor-plugin terratorch_segmentation
|
||||
# --enable-mm-embeds
|
||||
|
||||
|
||||
def main():
    """POST a geotiff URL to the vLLM pooling endpoint and save the result."""
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501
    endpoint = "http://localhost:8000/pooling"

    # The IO-processor plugin accepts the image by URL and returns the
    # prediction base64-encoded.
    payload = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    }

    resp = requests.post(endpoint, json=payload)

    print(f"response.status_code: {resp.status_code}")
    print(f"response.reason:{resp.reason}")

    body = resp.json()

    # Decode the base64 tiff returned under data.data.
    image_bytes = base64.b64decode(body["data"]["data"])

    destination = os.path.join(os.getcwd(), "online_prediction.tiff")

    with open(destination, "wb") as out_file:
        out_file.write(image_bytes)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,58 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
# This example shows how to perform an offline inference that generates
|
||||
# multimodal data. In this specific case this example will take a geotiff
|
||||
# image as input, process it using the multimodal data processor, and
|
||||
# perform inference.
|
||||
# Requirements:
|
||||
# - install TerraTorch v1.1 (or later):
|
||||
# pip install terratorch>=v1.1
|
||||
|
||||
|
||||
def main():
    """Run the Prithvi segmentation plugin offline and save the prediction."""
    torch.set_default_dtype(torch.float16)
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

    # The IO-processor plugin accepts the image by URL and returns the
    # prediction base64-encoded.
    img_prompt = {
        "data": image_url,
        "data_format": "url",
        "image_format": "tiff",
        "out_data_format": "b64_json",
    }

    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        skip_tokenizer_init=True,
        trust_remote_code=True,
        enforce_eager=True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM.
        # The maximum number depends on the available GPU memory
        max_num_seqs=32,
        io_processor_plugin="terratorch_segmentation",
        model_impl="terratorch",
        enable_mm_embeds=True,
    )

    result = llm.encode(img_prompt, pooling_task="plugin")[0].outputs

    print(result)

    prediction_bytes = base64.b64decode(result.data)

    file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
    with open(file_path, "wb") as f:
        f.write(prediction_bytes)

    print(f"Output file path: {file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
419
examples/pooling/plugin/prithvi_geospatial_mae_offline.py
Normal file
419
examples/pooling/plugin/prithvi_geospatial_mae_offline.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import albumentations
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import regex as re
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
# Run the whole pipeline in half precision.
torch.set_default_dtype(torch.float16)

# Sentinel values used by the Sen1Floods11 geotiffs.
NO_DATA = -9999  # raw no-data marker in the source rasters
NO_DATA_FLOAT = 0.0001  # no-data replacement after normalization
OFFSET = 0  # lower bound used when rescaling for visualization
PERCENTILE = 99  # upper percentile used when rescaling for visualization

# Configuration for the Sen1Floods11 datamodule (bands, batching, transforms).
# `data_root` points at a cluster path — override for local runs.
datamodule_config = {
    "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    "batch_size": 16,
    "constant_scale": 0.0001,
    "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
    "drop_last": True,
    "no_data_replace": 0.0,
    "no_label_replace": -1,
    "num_workers": 8,
    "test_transform": [
        albumentations.Resize(
            always_apply=False, height=448, interpolation=1, p=1, width=448
        ),
        albumentations.pytorch.ToTensorV2(
            transpose_mask=False, always_apply=True, p=1.0
        ),
    ],
}
|
||||
|
||||
|
||||
class PrithviMAE:
    """Thin wrapper running the Prithvi MAE model through vLLM's encode API."""

    def __init__(self, model):
        # skip_tokenizer_init / enforce_eager: this model is served via the
        # terratorch implementation and consumes pixel data, not tokens.
        self.model = LLM(
            model=model,
            skip_tokenizer_init=True,
            dtype="float16",
            enforce_eager=True,
            model_impl="terratorch",
            enable_mm_embeds=True,
        )

    def run(self, input_data, location_coords):
        """Encode one window of pixel data and return the raw pooler output.

        Args:
            input_data: pixel tensor; float32 inputs are cast to float16 and
                the leading dimension is dropped. (assumes the caller passes a
                batch of size 1 — TODO confirm)
            location_coords: (lon, lat) tensor for the window.
        """
        # merge the inputs into one data structure
        if input_data is not None and input_data.dtype == torch.float32:
            input_data = input_data.to(torch.float16)
            input_data = input_data[0]

        mm_data = {
            "pixel_values": input_data,
            "location_coords": location_coords,
        }

        # Token ids are unused by this model; a single dummy id suffices.
        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)

        return outputs[0].outputs.data
|
||||
|
||||
|
||||
def generate_datamodule():
    """Build the Sen1Floods11 datamodule from the module-level config."""
    # Only a subset of datamodule_config is consumed by the constructor.
    wanted = (
        "data_root",
        "batch_size",
        "num_workers",
        "bands",
        "drop_last",
        "test_transform",
    )
    kwargs = {name: datamodule_config[name] for name in wanted}
    return Sen1Floods11NonGeoDataModule(**kwargs)
|
||||
|
||||
|
||||
def process_channel_group(orig_img, channels):
    """
    Select the given channels from *orig_img* and rescale them to [0, 1].

    Args:
        orig_img: torch.Tensor representing original image (reference)
            with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width)
        for original image
    """

    selected = orig_img[channels, ...]

    # No-data pixels must neither drive the percentile nor survive scaling.
    valid_mask = torch.ones_like(selected, dtype=torch.bool)
    valid_mask[selected == NO_DATA_FLOAT] = False

    # Rescale (enhancing contrast): clip at the PERCENTILE-th percentile of
    # valid pixels, with a floor of 3000.
    upper = max(3000, np.percentile(selected[valid_mask], PERCENTILE))
    lower = OFFSET

    scaled = torch.clamp((selected - lower) / (upper - lower), 0, 1)

    # No data as zeros
    scaled[~valid_mask] = 0

    return scaled
|
||||
|
||||
|
||||
def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return image + meta info.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
        (lon, lat) coordinates of the raster, or None if unavailable
    """

    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.lnglat()
        except Exception:
            # Cannot read coords
            coords = None

    return img, meta, coords
|
||||
|
||||
|
||||
def save_geotiff(image, output_path: str, meta: dict):
    """Save multi-band image in Geotiff file.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info.
    """

    with rasterio.open(output_path, "w", **meta) as dest:
        # rasterio band indices are 1-based.
        for band_index, band in enumerate(image, start=1):
            dest.write(band, band_index)
|
||||
|
||||
|
||||
def _convert_np_uint8(float_image: torch.Tensor):
|
||||
image = float_image.numpy() * 255.0
|
||||
image = image.astype(dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def load_example(
    file_paths: list[str],
    mean: list[float] | None = None,
    std: list[float] | None = None,
    indices: list[int] | None = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths .
        mean: list containing mean values for each band in the
            images in *file_paths*.
        std: list containing std values for each band in the
            images in *file_paths*.
        indices: optional band indices to select from each image.

    Returns:
        np.array containing created example
        temporal coords parsed from the file names
        location coords read from the geotiffs
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        # Best-effort: parse an acquisition timestamp (YYYYJJJTHHMMSS or
        # YYYYMMDDTHHMMSS) out of the file name; skip the file on failure.
        try:
            match = re.search(r"(\d{7,8}T\d{6})", file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split("T")[0][4:]
                if len(julian_day) == 3:
                    # Already a Julian day-of-year.
                    julian_day = int(julian_day)
                else:
                    # Month+day — convert to day-of-year.
                    julian_day = (
                        datetime.datetime.strptime(julian_day, "%m%d")
                        .timetuple()
                        .tm_yday
                    )
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f"Could not extract timestamp for {file} ({e})")

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dimension

    return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
def run_model(
    input_data,
    temporal_coords,
    location_coords,
    model,
    datamodule,
    img_size,
    lightning_model=None,  # NOTE(review): unused in this function
):
    """Tile *input_data* into img_size windows, run *model* per window, and
    stitch the per-window argmax predictions back into one full-size map.
    """
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - (original_h % img_size)) % img_size
    pad_w = (img_size - (original_w % img_size)) % img_size
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window

    batch_size = 1
    # batch = torch.tensor(input_data, device="cpu")
    batch = torch.tensor(input_data)
    # Non-overlapping img_size x img_size tiles over the two spatial dims.
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into batches if number of windows > batch_size
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

    if temporal_coords:
        temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
        # Only the first image's coordinates are used.
        location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
    else:
        location_coords = None

    # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
    pred_imgs = []
    for x in windows:
        # Apply standardization
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)["image"]

        with torch.no_grad():
            pred = model.run(x, location_coords=location_coords)
            # Class index per pixel.
            y_hat = pred.argmax(dim=1)

        # Resize the prediction back to the window size.
        y_hat = torch.nn.functional.interpolate(
            y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
        )

        pred_imgs.append(y_hat)

    pred_imgs = torch.concat(pred_imgs, dim=0)

    # Build images from patches
    pred_imgs = rearrange(
        pred_imgs,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size
    pred_imgs = pred_imgs[..., :original_h, :original_w]

    # Squeeze (batch size 1)
    pred_imgs = pred_imgs[0]

    return pred_imgs
|
||||
|
||||
|
||||
def main(
    data_file: str,
    model: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] | None = None,
):
    """Run Prithvi-EO flood-segmentation inference on one GeoTIFF file.

    Args:
        data_file: Path to the input GeoTIFF image.
        model: Model name or checkpoint path passed to ``PrithviMAE``.
        output_dir: Directory where output GeoTIFFs are written; created if
            missing.
        rgb_outputs: If True, additionally save the original RGB image.
        input_indices: Optional 0-based channel indices to select from the
            input file; forwarded to ``load_example``.
    """
    os.makedirs(output_dir, exist_ok=True)

    model_obj = PrithviMAE(model=model)
    datamodule = generate_datamodule()
    img_size = 512  # Size of Sen1Floods11

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file],
        indices=input_indices,
    )

    meta_data = meta_data[0]  # only one image

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    channels = [
        datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
    ]  # BGR -> RGB

    pred = run_model(
        input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
    )

    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(
        output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
    )
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # Save image + pred
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )
    rgb_orig = rgb_orig.to(torch.float32)

    # Blend the prediction over the RGB image; NaN marks "no prediction"
    # pixels, which fall back to the original RGB values.
    pred[pred == 0.0] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]

    img_pred_file = os.path.join(
        output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
    )
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # Save image rgb
    if rgb_outputs:
        name_suffix = os.path.splitext(os.path.basename(data_file))[0]
        rgb_file = os.path.join(
            output_dir,
            f"original_rgb_{name_suffix}.tiff",
        )
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=rgb_file,
            meta=meta_data,
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: gather paths/options and forward them to main().
    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
    parser.add_argument(
        "--data_file", type=str, default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--model", type=str,
        default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        help="Path to a checkpoint file to load from.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices", default=[1, 2, 3, 8, 11, 12], type=int, nargs="+",
        help="""
        0-based indices of the six Prithvi channels to be selected from the input.
        By default selects [1,2,3,8,11,12] for S2L1C data.
        """,
    )
    parser.add_argument(
        "--rgb_outputs", action="store_true",
        help="If present, output files will only contain RGB channels. "
        "Otherwise, all bands will be saved.",
    )

    main(**vars(parser.parse_args()))
|
||||
63
examples/pooling/pooling/openai_pooling_client.py
Normal file
63
examples/pooling/pooling/openai_pooling_client.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Pooling API.

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve internlm/internlm2-1_8b-reward --trust-remote-code
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST *prompt* as JSON to *api_url* and return the raw response."""
    return requests.post(api_url, headers={"User-Agent": "Test Client"}, json=prompt)


def parse_args():
    """Parse host/port/model options for this client."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="internlm/internlm2-1_8b-reward")

    return parser.parse_args()


def main(args):
    """Send one Completions-style and one Chat-style pooling request."""
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Input like Completions API
    completion_style = {"model": model_name, "input": "vLLM is great!"}
    pooling_response = post_http_request(prompt=completion_style, api_url=api_url)
    print("-" * 50)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print("-" * 50)

    # Input like Chat API
    chat_style = {
        "model": model_name,
        "messages": [
            {
                "role": "user",
                "content": [{"type": "text", "text": "vLLM is great!"}],
            }
        ],
    }
    pooling_response = post_http_request(prompt=chat_style, api_url=api_url)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print("-" * 50)


if __name__ == "__main__":
    main(parse_args())
|
||||
410
examples/pooling/pooling/vision_language_pooling.py
Normal file
410
examples/pooling/pooling/vision_language_pooling.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference with
|
||||
the correct prompt format on vision language models for multimodal pooling.
|
||||
|
||||
For most models, the prompt format should follow corresponding examples
|
||||
on HuggingFace model repository.
|
||||
"""
|
||||
|
||||
from argparse import Namespace
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
from typing import Literal, NamedTuple, TypeAlias, TypedDict, get_args
|
||||
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.entrypoints.score_utils import ScoreMultiModalParam
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Repository root (three directory levels above this file) and its examples dir.
ROOT_DIR = Path(__file__).parents[2]
EXAMPLES_DIR = ROOT_DIR / "examples"
|
||||
|
||||
|
||||
class TextQuery(TypedDict):
|
||||
modality: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ImageQuery(TypedDict):
|
||||
modality: Literal["image"]
|
||||
image: Image
|
||||
|
||||
|
||||
class TextImageQuery(TypedDict):
|
||||
modality: Literal["text+image"]
|
||||
text: str
|
||||
image: Image
|
||||
|
||||
|
||||
class TextImagesQuery(TypedDict):
|
||||
modality: Literal["text+images"]
|
||||
text: str
|
||||
image: ScoreMultiModalParam
|
||||
|
||||
|
||||
QueryModality = Literal["text", "image", "text+image", "text+images"]
|
||||
Query: TypeAlias = TextQuery | ImageQuery | TextImageQuery | TextImagesQuery
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
engine_args: EngineArgs
|
||||
prompt: str | None = None
|
||||
image: Image | None = None
|
||||
query: str | None = None
|
||||
documents: ScoreMultiModalParam | None = None
|
||||
|
||||
|
||||
def run_clip(query: Query) -> ModelRequestData:
    """Build the embedding request for openai/clip-vit-base-patch32."""
    modality = query["modality"]
    if modality == "text":
        prompt, image = query["text"], None
    elif modality == "image":
        # For image input, make sure that the prompt text is empty
        prompt, image = "", query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="openai/clip-vit-base-patch32",
        runner="pooling",
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(engine_args=engine_args, prompt=prompt, image=image)
|
||||
|
||||
|
||||
def run_e5_v(query: Query) -> ModelRequestData:
    """Build the embedding request for royokong/e5-v."""
    llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n"  # noqa: E501

    modality = query["modality"]
    if modality == "text":
        caption = query["text"]
        prompt = llama3_template.format(
            f"{caption}\nSummary above sentence in one word: "
        )
        image = None
    elif modality == "image":
        prompt = llama3_template.format("<image>\nSummary above image in one word: ")
        image = query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="royokong/e5-v",
        runner="pooling",
        max_model_len=4096,
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(engine_args=engine_args, prompt=prompt, image=image)
|
||||
|
||||
|
||||
def run_jinavl_reranker(query: Query) -> ModelRequestData:
    """Build the scoring request for jinaai/jina-reranker-m0."""
    modality = query["modality"]
    if modality != "text+images":
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="jinaai/jina-reranker-m0",
        runner="pooling",
        max_model_len=32768,
        trust_remote_code=True,
        mm_processor_kwargs={
            "min_pixels": 3136,
            "max_pixels": 602112,
        },
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        query=query["text"],
        documents=query["image"],
    )
|
||||
|
||||
|
||||
def run_siglip(query: Query) -> ModelRequestData:
    """Build the embedding request for google/siglip-base-patch16-224."""
    modality = query["modality"]
    if modality == "text":
        prompt, image = query["text"], None
    elif modality == "image":
        # For image input, make sure that the prompt text is empty
        prompt, image = "", query["image"]
    else:
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine_args = EngineArgs(
        model="google/siglip-base-patch16-224",
        runner="pooling",
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(engine_args=engine_args, prompt=prompt, image=image)
|
||||
|
||||
|
||||
def _get_vlm2vec_prompt_image(query: Query, image_token: str):
|
||||
if query["modality"] == "text":
|
||||
text = query["text"]
|
||||
prompt = f"Find me an everyday image that matches the given caption: {text}"
|
||||
image = None
|
||||
elif query["modality"] == "image":
|
||||
prompt = f"{image_token} Find a day-to-day image that looks similar to the provided image." # noqa: E501
|
||||
image = query["image"]
|
||||
elif query["modality"] == "text+image":
|
||||
text = query["text"]
|
||||
prompt = f"{image_token} Represent the given image with the following question: {text}" # noqa: E501
|
||||
image = query["image"]
|
||||
else:
|
||||
modality = query["modality"]
|
||||
raise ValueError(f"Unsupported query modality: {modality!r}")
|
||||
|
||||
return prompt, image
|
||||
|
||||
|
||||
def run_vlm2vec_phi3v(query: Query) -> ModelRequestData:
    """Build the embedding request for TIGER-Lab/VLM2Vec-Full (Phi-3-V)."""
    prompt, image = _get_vlm2vec_prompt_image(query, "<|image_1|>")

    return ModelRequestData(
        engine_args=EngineArgs(
            model="TIGER-Lab/VLM2Vec-Full",
            runner="pooling",
            max_model_len=4096,
            trust_remote_code=True,
            mm_processor_kwargs={"num_crops": 4},
            limit_mm_per_prompt={"image": 1},
        ),
        prompt=prompt,
        image=image,
    )
|
||||
|
||||
|
||||
def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
    """Build the embedding request for TIGER-Lab/VLM2Vec-Qwen2VL-2B.

    vLLM does not support LoRA adapters on the multi-modal encoder, so the
    LoRA weights are merged into the base model and exported to disk first.
    """
    from huggingface_hub.constants import HF_HUB_CACHE
    from peft import PeftConfig, PeftModel
    from transformers import AutoModelForImageTextToText, AutoProcessor

    from vllm.entrypoints.chat_utils import load_chat_template

    model_id = "TIGER-Lab/VLM2Vec-Qwen2VL-2B"

    base_model = AutoModelForImageTextToText.from_pretrained(model_id)
    peft_config = PeftConfig.from_pretrained(model_id)
    lora_model = PeftModel.from_pretrained(base_model, model_id, config=peft_config)
    merged = lora_model.merge_and_unload().to(dtype=base_model.dtype)
    merged._hf_peft_config_loaded = False  # Needed to save the merged model

    processor = AutoProcessor.from_pretrained(
        model_id,
        # `min_pixels` and `max_pixels` are deprecated for
        # transformers `preprocessor_config.json`
        size={"shortest_edge": 3136, "longest_edge": 12845056},
    )
    # The original chat template is not correct
    processor.chat_template = load_chat_template(
        EXAMPLES_DIR / "template_vlm2vec_qwen2vl.jinja",
    )

    merged_path = str(
        Path(HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--") + "-vllm")
    )
    print(f"Saving merged model to {merged_path}...")
    print(
        "NOTE: This directory is not tracked by `huggingface_hub` "
        "so you have to delete this manually if you don't want it anymore."
    )
    merged.save_pretrained(merged_path)
    processor.save_pretrained(merged_path)
    print("Done!")

    prompt, image = _get_vlm2vec_prompt_image(query, "<|image_pad|>")

    return ModelRequestData(
        engine_args=EngineArgs(
            model=merged_path,
            runner="pooling",
            max_model_len=4096,
            mm_processor_kwargs={
                "min_pixels": 3136,
                "max_pixels": 12845056,
            },
            limit_mm_per_prompt={"image": 1},
        ),
        prompt=prompt,
        image=image,
    )
|
||||
|
||||
|
||||
def get_query(modality: QueryModality):
|
||||
if modality == "text":
|
||||
return TextQuery(modality="text", text="A dog sitting in the grass")
|
||||
|
||||
if modality == "image":
|
||||
return ImageQuery(
|
||||
modality="image",
|
||||
image=fetch_image(
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg" # noqa: E501
|
||||
),
|
||||
)
|
||||
|
||||
if modality == "text+image":
|
||||
return TextImageQuery(
|
||||
modality="text+image",
|
||||
text="A cat standing in the snow.",
|
||||
image=fetch_image(
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/cat_snow.jpg" # noqa: E501
|
||||
),
|
||||
)
|
||||
|
||||
if modality == "text+images":
|
||||
return TextImagesQuery(
|
||||
modality="text+images",
|
||||
text="slm markdown",
|
||||
image={
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png"
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
msg = f"Modality {modality} is not supported."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def run_encode(model: str, modality: QueryModality, seed: int):
    """Embed one example query of *modality* with the chosen *model*."""
    query = get_query(modality)
    request = model_example_map[model](query)

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    request.engine_args.limit_mm_per_prompt = default_limits | dict(
        request.engine_args.limit_mm_per_prompt or {}
    )

    llm = LLM(**(asdict(request.engine_args) | {"seed": seed}))

    mm_data = {} if request.image is None else {"image": request.image}

    outputs = llm.embed(
        {
            "prompt": request.prompt,
            "multi_modal_data": mm_data,
        }
    )

    print("-" * 50)
    for output in outputs:
        print(output.outputs.embedding)
    print("-" * 50)
|
||||
|
||||
|
||||
def run_score(model: str, modality: QueryModality, seed: int):
    """Score one example query of *modality* with the chosen *model*."""
    request = model_example_map[model](get_query(modality))

    llm = LLM(**(asdict(request.engine_args) | {"seed": seed}))

    outputs = llm.score(request.query, request.documents)

    print("-" * 30)
    print([output.outputs.score for output in outputs])
    print("-" * 30)
|
||||
|
||||
|
||||
# Registry mapping the --model-name CLI choice to its request builder.
model_example_map = dict(
    clip=run_clip,
    e5_v=run_e5_v,
    jinavl_reranker=run_jinavl_reranker,
    siglip=run_siglip,
    vlm2vec_phi3v=run_vlm2vec_phi3v,
    vlm2vec_qwen2vl=run_vlm2vec_qwen2vl,
)
|
||||
|
||||
|
||||
def parse_args():
    """CLI arguments for the multimodal pooling demo."""
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "vision language models for multimodal pooling tasks."
    )
    parser.add_argument(
        "--model-name", "-m",
        type=str, default="vlm2vec_phi3v",
        choices=model_example_map.keys(),
        help="The name of the embedding model.",
    )
    parser.add_argument(
        "--task", "-t",
        type=str, default="embedding",
        choices=["embedding", "scoring"],
        help="The task type.",
    )
    parser.add_argument(
        "--modality",
        type=str, default="image",
        choices=get_args(QueryModality),
        help="Modality of the input.",
    )
    parser.add_argument(
        "--seed",
        type=int, default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
    """Dispatch to the encode or score runner based on --task."""
    task = args.task
    if task == "embedding":
        run_encode(args.model_name, args.modality, args.seed)
        return
    if task == "scoring":
        run_score(args.model_name, args.modality, args.seed)
        return
    raise ValueError(f"Unsupported task: {task}")


if __name__ == "__main__":
    main(parse_args())
|
||||
47
examples/pooling/score/cohere_rerank_client.py
Normal file
47
examples/pooling/score/cohere_rerank_client.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
Note that `pip install cohere` is needed to run this example.

run: vllm serve BAAI/bge-reranker-base
"""

import cohere
from cohere import Client, ClientV2

model = "BAAI/bge-reranker-base"

query = "What is the capital of France?"

documents = [
    "The capital of France is Paris",
    "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving",
]


def cohere_rerank(
    client: Client | ClientV2, model: str, query: str, documents: list[str]
) -> dict:
    """Rerank *documents* against *query* via either Cohere client version."""
    return client.rerank(model=model, query=query, documents=documents)


def main():
    """Rerank the example documents with both the v1 and v2 Cohere clients."""
    # Both clients point at the local vLLM server; the API key is unused.
    clients = [
        ("rerank_v1_result", cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")),
        ("rerank_v2_result", cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")),
    ]

    print("-" * 50)
    for label, client in clients:
        result = cohere_rerank(client, model, query, documents)
        print(f"{label}:\n", result)
        print("-" * 50)


if __name__ == "__main__":
    main()
|
||||
134
examples/pooling/score/convert_model_to_seq_cls.py
Normal file
134
examples/pooling/score/convert_model_to_seq_cls.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls


def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
    """Collapse a two-token softmax head into a single-logit classifier.

    The classifier weight becomes lm_head[true] - lm_head[false], so the
    sigmoid of the single logit matches the softmax over the two tokens.
    Refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    """
    assert len(tokens) == 2

    lm_head_weights = causal_lm.lm_head.weight

    false_id = tokenizer.convert_tokens_to_ids(tokens[0])
    true_id = tokenizer.convert_tokens_to_ids(tokens[1])

    score_weight = lm_head_weights[true_id].to(device).to(
        torch.float32
    ) - lm_head_weights[false_id].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
    """Copy the lm_head rows for *tokens* directly into the classifier head."""
    lm_head_weights = causal_lm.lm_head.weight

    token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]

    score_weight = lm_head_weights[token_ids].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()


method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    """Convert a *ForCausalLM model into a *ForSequenceClassification model.

    Args:
        model_name: HF hub id or local path of the causal-LM checkpoint.
        classifier_from_tokens: Tokens whose lm_head rows seed the classifier.
        path: Output directory for the converted model and tokenizer.
        method: Key into ``method_map`` selecting the conversion strategy.
        use_pad_token: Stored on the config; `llm as reranker` models default
            to not using a pad token.
        device: Device used for the weight surgery.
    """
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    # `llm as reranker` defaults to not using pad_token
    seq_cls_model.config.use_pad_token = use_pad_token
    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    """CLI arguments for the conversion script."""
    parser = argparse.ArgumentParser(
        description="Converting *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="BAAI/bge-reranker-v2-gemma",
        help="Model name",
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["Yes"]',
        help="classifier from tokens",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="no_post_processing",
        help="Conversion method: from_2_way_softmax or no_post_processing",
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./bge-reranker-v2-gemma-seq-cls",
        help="Path to save converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
|
||||
89
examples/pooling/score/offline_reranker.py
Normal file
89
examples/pooling/score/offline_reranker.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-0.6B"

# What is the difference between the official original version and one that
# has been converted into a sequence-classification model?
# Qwen3-Reranker is a language model that reranks by using the logits of the
# "no" and "yes" tokens. That requires computing logits over all 151669
# vocabulary tokens, which is extremely inefficient and incompatible with the
# vLLM score API.
# A method for converting the original model into a sequence-classification
# model was proposed; see
# https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline with that method are more efficient, support the
# vLLM score API, and take simpler init parameters, for example:
# llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")

# If you want to load the official original version, initialize as follows.


def get_llm() -> LLM:
    """Initializes and returns the LLM model for Qwen3-Reranker."""
    return LLM(
        model=model_name,
        runner="pooling",
        hf_overrides={
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
    )


# Why hf_overrides are needed for the official original version:
# vLLM converts the model to Qwen3ForSequenceClassification when loaded, for
# better performance.
# - First, `"architectures": ["Qwen3ForSequenceClassification"]` manually
#   routes loading to Qwen3ForSequenceClassification.
# - Then, `"classifier_from_token": ["no", "yes"]` extracts the vectors for
#   those tokens from lm_head.
# - Third, `"is_original_qwen3_reranker": True` enables the logic that
#   converts those two vectors into a single one.

# Please use query_template and document_template to format the query and
# document for better reranker results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"


def main() -> None:
    """Score two example query/document pairs with the original reranker."""
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    raw_queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]

    raw_documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=q)
        for q in raw_queries
    ]
    documents = [document_template.format(doc=d, suffix=suffix) for d in raw_documents]

    llm = get_llm()
    outputs = llm.score(queries, documents)

    print("-" * 30)
    print([output.outputs.score for output in outputs])
    print("-" * 30)


if __name__ == "__main__":
    main()
|
||||
63
examples/pooling/score/openai_cross_encoder_score.py
Normal file
63
examples/pooling/score/openai_cross_encoder_score.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example online usage of Score API.

Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
"""

import argparse
import pprint

import requests


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST *prompt* as JSON to *api_url* and return the raw response."""
    return requests.post(api_url, headers={"User-Agent": "Test Client"}, json=prompt)


def parse_args():
    """Parse host/port/model options for this client."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
    return parser.parse_args()


def main(args):
    """Exercise the Score API with string/string, string/list, list/list inputs."""
    api_url = f"http://{args.host}:{args.port}/score"
    model_name = args.model

    cases = [
        (
            "What is the capital of Brazil?",
            "The capital of Brazil is Brasilia.",
            "text_1 and text_2 are both strings",
        ),
        (
            "What is the capital of France?",
            ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
            "text_1 is string and text_2 is a list",
        ),
        (
            ["What is the capital of Brazil?", "What is the capital of France?"],
            ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
            "text_1 and text_2 are both lists",
        ),
    ]

    for text_1, text_2, description in cases:
        prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
        score_response = post_http_request(prompt=prompt, api_url=api_url)
        print(f"\nPrompt when {description}:")
        pprint.pprint(prompt)
        print("\nScore Response:")
        pprint.pprint(score_response.json())


if __name__ == "__main__":
    main(parse_args())
|
||||
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Example online usage of Score API.
|
||||
|
||||
Run `vllm serve <model> --runner pooling` to start up the server in vLLM.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST *prompt* as a JSON body to *api_url* and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the multimodal score client.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``
            (use ``sys.argv[1:]``), so existing callers are unaffected.

    Returns:
        argparse.Namespace with ``host``, ``port`` and ``model`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="jinaai/jina-reranker-m0")
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main(args):
    """Score a text query against two images via the /score endpoint.

    Builds a multimodal ``text_2`` payload (a list of image_url entries),
    sends one request, and pretty-prints the payload and the response.
    """
    api_url = f"http://{args.host}:{args.port}/score"

    query = "slm markdown"
    image_urls = [
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png",
        "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png",
    ]
    documents = {
        "content": [
            {"type": "image_url", "image_url": {"url": url}} for url in image_urls
        ]
    }

    prompt = {"model": args.model, "text_1": query, "text_2": documents}
    score_response = post_http_request(prompt=prompt, api_url=api_url)
    print("\nPrompt when text_1 is string and text_2 is a image list:")
    pprint.pprint(prompt)
    print("\nScore Response:")
    pprint.pprint(score_response.json())
|
||||
|
||||
|
||||
# Script entry point: parse CLI options and run the example.
if __name__ == "__main__":
    main(parse_args())
|
||||
42
examples/pooling/score/openai_reranker.py
Normal file
42
examples/pooling/score/openai_reranker.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Example of using the OpenAI entrypoint's rerank API which is compatible with
|
||||
Jina and Cohere https://jina.ai/reranker
|
||||
|
||||
run: vllm serve BAAI/bge-reranker-base
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# Endpoint of a locally running `vllm serve BAAI/bge-reranker-base`.
url = "http://127.0.0.1:8000/rerank"

# Standard JSON request headers.
headers = {"accept": "application/json", "Content-Type": "application/json"}

# Rerank request body: the server scores each document against the query.
data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Horses and cows are both animals",
    ],
}
|
||||
|
||||
|
||||
def main():
    """Send the rerank request and pretty-print the outcome."""
    response = requests.post(url, headers=headers, json=data)

    # Anything other than HTTP 200 is reported as a failure.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
|
||||
|
||||
|
||||
# Allow running this example directly as a script.
if __name__ == "__main__":
    main()
|
||||
54
examples/pooling/token_classify/ner.py
Normal file
54
examples/pooling/token_classify/ner.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
    """Build engine CLI arguments with NER-example defaults pre-applied."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Defaults for this example; any of them can be overridden on the CLI.
    parser.set_defaults(
        model="boltuix/NeuroBERT-NER",
        runner="pooling",
        enforce_eager=True,
        trust_remote_code=True,
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
    """Run offline token-level NER with vLLM and print one label per token."""
    # Sample prompts.
    prompts = [
        "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    ]

    # Create an LLM.
    llm = LLM(**vars(args))
    tokenizer = llm.get_tokenizer()
    # id2label maps class indices to tag names, read from the HF model config.
    label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label

    # Run inference
    outputs = llm.encode(prompts, pooling_task="token_classify")

    for prompt, output in zip(prompts, outputs):
        # output.outputs.data holds one row of classification scores per token;
        # argmax over the last dim selects the predicted class index.
        logits = output.outputs.data
        predictions = logits.argmax(dim=-1)

        # Map predictions to labels
        tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
        labels = [label_map[p.item()] for p in predictions]

        # Print results
        for token, label in zip(tokens, labels):
            # Hide special tokens (e.g. [CLS]/[SEP]) from the printout.
            if token not in tokenizer.all_special_tokens:
                print(f"{token:15} → {label}")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    args = parse_args()
    main(args)
|
||||
71
examples/pooling/token_classify/ner_client.py
Normal file
71
examples/pooling/token_classify/ner_client.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
|
||||
|
||||
"""
|
||||
Example online usage of Pooling API for Named Entity Recognition (NER).
|
||||
|
||||
Run `vllm serve <model> --runner pooling`
|
||||
to start up the server in vLLM. e.g.
|
||||
|
||||
vllm serve boltuix/NeuroBERT-NER
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
|
||||
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST *prompt* as a JSON body to *api_url* and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the NER client.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``
            (use ``sys.argv[1:]``), so existing callers are unaffected.

    Returns:
        argparse.Namespace with ``host``, ``port`` and ``model`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER")

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main(args):
    """Run NER through the Pooling API and print token → label pairs."""
    from transformers import AutoConfig, AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # The tokenizer and config are loaded locally so we can decode token ids
    # and translate class indices into label names.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    label_map = AutoConfig.from_pretrained(model_name).id2label

    # Input text
    text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    request_body = {"model": model_name, "input": text}

    pooling_response = post_http_request(prompt=request_body, api_url=api_url)

    # The server returns per-token scores; pick the best class for each token.
    first_result = pooling_response.json()["data"][0]
    logits = torch.tensor(first_result["data"])
    predictions = logits.argmax(dim=-1)

    # Re-tokenize locally so token strings line up with the predictions.
    encoded = tokenizer(text, return_tensors="pt")
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    labels = [label_map[p.item()] for p in predictions]
    assert len(tokens) == len(predictions)

    # Print results, hiding special tokens.
    for token, label in zip(tokens, labels):
        if token not in tokenizer.all_special_tokens:
            print(f"{token:15} → {label}")
|
||||
|
||||
|
||||
# Script entry point: parse CLI options and run the example.
if __name__ == "__main__":
    main(parse_args())
|
||||
71
examples/pooling/token_embed/jina_embeddings_v4.py
Normal file
71
examples/pooling/token_embed/jina_embeddings_v4.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.inputs.data import TextPrompt
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
|
||||
# Initialize model. runner="pooling" selects embedding-style (non-generative)
# inference.
model = LLM(
    model="jinaai/jina-embeddings-v4-vllm-text-matching",
    runner="pooling",
    max_model_len=1024,
    gpu_memory_utilization=0.8,
)

# Create text prompts (German and Japanese descriptions of the same scene),
# each prefixed with "Query: " — presumably the prefix this text-matching
# checkpoint expects; verify against the model card.
text1 = "Ein wunderschöner Sonnenuntergang am Strand"
text1_prompt = TextPrompt(prompt=f"Query: {text1}")

text2 = "浜辺に沈む美しい夕日"
text2_prompt = TextPrompt(prompt=f"Query: {text2}")

# Create image prompt: the chat-style template embeds vision placeholder
# tokens that the model replaces with image features.
image = fetch_image(
    "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/eskimo.jpg"  # noqa: E501
)
image_prompt = TextPrompt(
    prompt="<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n",  # noqa: E501
    multi_modal_data={"image": image},
)

# Encode all prompts, requesting one embedding vector per token.
prompts = [text1_prompt, text2_prompt, image_prompt]
outputs = model.encode(prompts, pooling_task="token_embed")
|
||||
|
||||
|
||||
def get_embeddings(outputs):
    """Mean-pool and L2-normalize per-token embeddings for each output.

    For prompts containing an image, only the vision-token span (inclusive
    of the <|vision_start|>/<|vision_end|> marker tokens) is pooled; text-only
    prompts pool over every token.

    Args:
        outputs: Iterable of pooling outputs, each exposing
            ``prompt_token_ids`` (list of ints) and a per-token embedding
            matrix in ``outputs.data``.

    Returns:
        List of 1-D L2-normalized embedding tensors (pooled in float32).
    """
    VISION_START_TOKEN_ID, VISION_END_TOKEN_ID = 151652, 151653

    embeddings = []
    for output in outputs:
        # Build the token-id tensor once per output (the original built it
        # twice, once per torch.where call).
        token_ids = torch.tensor(output.prompt_token_ids)
        if VISION_START_TOKEN_ID in output.prompt_token_ids:
            # Gather only vision tokens (markers included).
            img_start_pos = torch.where(token_ids == VISION_START_TOKEN_ID)[0][0]
            img_end_pos = torch.where(token_ids == VISION_END_TOKEN_ID)[0][0]
            embeddings_tensor = output.outputs.data.detach().clone()[
                img_start_pos : img_end_pos + 1
            ]
        else:
            # Use all tokens for text-only prompts
            embeddings_tensor = output.outputs.data.detach().clone()

        # Mean-pool in float32 for numerical stability, then L2-normalize.
        pooled_output = (
            embeddings_tensor.sum(dim=0, dtype=torch.float32)
            / embeddings_tensor.shape[0]
        )
        embeddings.append(torch.nn.functional.normalize(pooled_output, dim=-1))
    return embeddings
|
||||
|
||||
|
||||
# Pool each prompt's per-token outputs into a single normalized vector.
embeddings = get_embeddings(outputs)

# Each embedding is a 1-D tensor; print its dimensionality.
for embedding in embeddings:
    print(embedding.shape)
|
||||
56
examples/pooling/token_embed/multi_vector_retrieval.py
Normal file
56
examples/pooling/token_embed/multi_vector_retrieval.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
    """Build engine CLI arguments with multi-vector-example defaults."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    # Defaults for this example; any of them can be overridden on the CLI.
    parser.set_defaults(
        model="BAAI/bge-m3",
        runner="pooling",
        enforce_eager=True,
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
    """Compare pooled embeddings with per-token (multi-vector) outputs."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an LLM.
    # You should pass runner="pooling" for embedding models
    llm = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = llm.embed(prompts)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        # Pooled representation: one flat vector per prompt.
        embeds = output.outputs.embedding
        print(len(embeds))

    # Generate embedding for each token. The output is a list of PoolingRequestOutput.
    outputs = llm.encode(prompts, pooling_task="token_embed")

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for prompt, output in zip(prompts, outputs):
        # Multi-vector representation: one vector per token (2-D tensor).
        multi_vector = output.outputs.data
        print(multi_vector.shape)
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    args = parse_args()
    main(args)
|
||||
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Example online usage of Pooling API for multi vector retrieval.
|
||||
|
||||
Run `vllm serve <model> --runner pooling`
|
||||
to start up the server in vLLM. e.g.
|
||||
|
||||
vllm serve BAAI/bge-m3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
|
||||
def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    """POST *prompt* as a JSON body to *api_url* and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=prompt,
    )
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse command-line options for the multi-vector retrieval client.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``
            (use ``sys.argv[1:]``), so existing callers are unaffected.

    Returns:
        argparse.Namespace with ``host``, ``port`` and ``model`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="BAAI/bge-m3")

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def main(args):
    """Request per-token embeddings for several prompts and print each shape."""
    api_url = f"http://{args.host}:{args.port}/pooling"

    request_body = {
        "model": args.model,
        "input": [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ],
    }

    pooling_response = post_http_request(prompt=request_body, api_url=api_url)
    # One matrix of per-token vectors comes back per prompt.
    for item in pooling_response.json()["data"]:
        print(torch.tensor(item["data"]).shape)
|
||||
|
||||
|
||||
# Script entry point: parse CLI options and run the example.
if __name__ == "__main__":
    main(parse_args())
|
||||
Reference in New Issue
Block a user