Add Support for Qwen2-VL Multi-modal Embedding Models (#3694)
@@ -44,6 +44,7 @@ class SeparatorStyle(IntEnum):
    CHATGLM3 = auto()
    DEEPSEEK_CHAT = auto()
    METAMATH = auto()
    QWEN2_VL_EMBED = auto()


@dataclasses.dataclass
@@ -110,6 +111,15 @@ class Conversation:
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.QWEN2_VL_EMBED:
            ret = "" if system_prompt == "" else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            ret += self.stop_str
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
@@ -366,6 +376,46 @@ def chat_template_exists(template_name: str) -> bool:
    return template_name in chat_templates


def generate_embedding_convs(
    texts: List[str], images: List[str], template_name: str
) -> List[Conversation]:
    conv_template = chat_templates[template_name].copy()
    convs = []
    for text, image in zip(texts, images):
        conv = Conversation(
            name=conv_template.name,
            system_template=conv_template.system_template,
            system_message=conv_template.system_message,
            roles=conv_template.roles,
            messages=list(conv_template.messages),  # prevent in-place modification
            offset=conv_template.offset,
            sep_style=SeparatorStyle(conv_template.sep_style),
            sep=conv_template.sep,
            sep2=conv_template.sep2,
            stop_str=conv_template.stop_str,
            image_data=[],
            modalities=[],
            image_token=conv_template.image_token,
        )
        real_content = ""

        if image is not None:
            image_token = (
                conv.image_token + "\n"
                if conv.name != "gme-qwen2-vl"
                else conv.image_token
            )
            real_content += image_token
        if text is not None:
            real_content += text
        conv.append_message(conv.roles[0], real_content)
        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        convs.append(conv)

    return convs


def generate_chat_conv(
    request: ChatCompletionRequest, template_name: str
) -> Conversation:
@@ -555,6 +605,20 @@ register_conv_template(
    )
)

# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
register_conv_template(
    Conversation(
        name="gme-qwen2-vl",
        system_message="You are a helpful assistant.",
        system_template="<|im_start|>system\n{system_message}",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.QWEN2_VL_EMBED,
        stop_str="<|endoftext|>",
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
    )
)

# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
register_conv_template(
    Conversation(
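For reference, a minimal usage sketch (not part of this diff; the image URL and caption are placeholders) of how generate_embedding_convs renders a text+image query with the gme-qwen2-vl template registered above, following the QWEN2_VL_EMBED branch of get_prompt:

# Illustrative only: exercise the new template and separator style.
from sglang.srt.conversation import generate_embedding_convs

convs = generate_embedding_convs(
    texts=["a photo of a cat"],
    images=["https://example.com/cat.jpg"],  # placeholder URL
    template_name="gme-qwen2-vl",
)
print(convs[0].get_prompt())
# Expected rendering (no "\n" after the image token for gme-qwen2-vl):
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>a photo of a cat<|im_end|>
# <|im_start|>assistant
# <|endoftext|>
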
@@ -214,13 +214,13 @@ class Engine:
    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
        image_data: Optional[Union[List[str], str]] = None,
    ) -> Dict:
        """
        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
        Please refer to `EmbeddingReqInput` for the documentation.
        """

        obj = EmbeddingReqInput(text=prompt)
        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
        loop = asyncio.get_event_loop()
        generator = self.tokenizer_manager.generate_request(obj, None)
        ret = loop.run_until_complete(generator.__anext__())
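An illustrative offline-engine sketch of the extended encode() call (the model path, server flag, prompt, and image URL below are assumptions, not part of this diff):

# Sketch: embed one text+image pair through the offline Engine API.
import sglang as sgl

llm = sgl.Engine(
    model_path="Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",  # assumed checkpoint
    is_embedding=True,
)
out = llm.encode(
    prompt="<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>a photo of a cat<|im_end|>\n"
    "<|im_start|>assistant\n<|endoftext|>",
    image_data="https://example.com/cat.jpg",  # placeholder URL
)
print(len(out["embedding"]))  # embedding dimension for a single prompt
llm.shutdown()
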
@@ -293,6 +293,8 @@ class TokenizedGenerateReqInput:
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    image_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
@@ -303,28 +305,40 @@ class EmbeddingReqInput:
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    def normalize_batch_and_arguments(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, str):
                self.is_single = True
                self.batch_size = 1
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
            else:
                self.is_single = False
                self.batch_size = len(self.text)
        else:
            if isinstance(self.input_ids[0], int):
                self.is_single = True
                self.batch_size = 1
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
            else:
                self.is_single = False
                self.batch_size = len(self.input_ids)
                self.batch_size += 1

        if self.batch_size > 1:
            self.is_single = False

        # Fill in default arguments
        if self.is_single:
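For clarity, a small illustrative sketch (not part of the diff) of what the new validation accepts and rejects:

# Illustrative only: the new checks in normalize_batch_and_arguments.
from sglang.srt.managers.io_struct import EmbeddingReqInput

# text plus an optional image is accepted
req = EmbeddingReqInput(text="a photo of a cat", image_data="https://example.com/cat.jpg")
req.normalize_batch_and_arguments()

# text and input_ids together are rejected
try:
    EmbeddingReqInput(text="a cat", input_ids=[1, 2, 3]).normalize_batch_and_arguments()
except ValueError as e:
    print(e)  # "text and input_ids cannot be provided at the same time"
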
@@ -352,6 +366,7 @@ class EmbeddingReqInput:
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
@@ -365,6 +380,8 @@ class TokenizedEmbeddingReqInput:
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams

@@ -767,6 +767,30 @@ class Scheduler:
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self.waiting_queue.append(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
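For intuition, a schematic sketch (hypothetical token ids and counts, not the real pad_input_ids_func) of what the image-token expansion above does to origin_input_ids:

# Hypothetical illustration of image-token expansion.
IMAGE_PAD_ID = 151655              # made-up id standing in for <|image_pad|>
origin_input_ids = [1001, IMAGE_PAD_ID, 1002]

num_image_tokens = 256             # depends on the processed image grid
expanded = []
for tok in origin_input_ids:
    if tok == IMAGE_PAD_ID:
        # one placeholder becomes many dummy slots that later receive image embeddings
        expanded.extend([IMAGE_PAD_ID] * num_image_tokens)
    else:
        expanded.append(tok)
# The expanded length is what the max_req_input_len check above guards against.
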
@@ -372,13 +372,12 @@ class TokenizerManager:
            )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.image_processor.process_images_async(
            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
        )
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]
        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
@@ -438,6 +437,7 @@ class TokenizerManager:
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )

@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    generate_embedding_convs,
    register_conv_template,
)
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
    FileResponse,
    FunctionResponse,
    LogProbs,
    MultimodalEmbeddingInput,
    ToolCall,
    TopLogprob,
    UsageInfo,
@@ -1556,11 +1558,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
        prompt = prompts[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            prompt_kwargs = {"text": prompt}
        elif isinstance(prompt, list) and isinstance(
            prompt[0], MultimodalEmbeddingInput
        ):
            assert (
                chat_template_name is not None
            ), "chat_template_name is required for multimodal inputs"
            texts = []
            images = []
            for item in prompt:
                texts.append(item.text if item.text is not None else None)
                images.append(item.image if item.image is not None else None)
            convs = generate_embedding_convs(texts, images, chat_template_name)
            generate_prompts = []
            for conv in convs:
                generate_prompts.append(conv.get_prompt())
            if len(generate_prompts) == 1:
                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
            else:
                prompt_kwargs = {"text": generate_prompts, "image_data": images}
        else:
            prompt_kwargs = {"input_ids": prompt}
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        elif isinstance(prompts[0], list) and isinstance(
            prompts[0][0], MultimodalEmbeddingInput
        ):
            # TODO: multiple requests
            raise NotImplementedError(
                "Multiple requests with multimodal inputs are not supported yet"
            )
        else:
            prompt_kwargs = {"input_ids": prompts}
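A client-side sketch of a /v1/embeddings request that exercises this path (host, port, model name, and image URLs are placeholders; assumes the server was launched with a chat template for gme-qwen2-vl, since the assert above requires one):

# Illustrative request against the OpenAI-compatible embeddings route.
import requests

payload = {
    "model": "gme-qwen2-vl",  # placeholder model name
    "input": [
        {"text": "a photo of a cat", "image": "https://example.com/cat.jpg"},
        {"text": "a photo of a dog", "image": "https://example.com/dog.jpg"},
    ],
}
resp = requests.post("http://127.0.0.1:30000/v1/embeddings", json=payload)
embeddings = [item["embedding"] for item in resp.json()["data"]]
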
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
    usage: Optional[UsageInfo] = None


class MultimodalEmbeddingInput(BaseModel):
    text: Optional[str] = None
    image: Optional[str] = None


class EmbeddingRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings/create
    input: Union[List[int], List[List[int]], str, List[str]]
    input: Union[
        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
    ]
    model: str
    encoding_format: str = "float"
    dimensions: int = None
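A small sketch (illustrative values) of how pydantic resolves the widened input union into MultimodalEmbeddingInput items, which is what the adapter's isinstance check above relies on:

# Illustrative only: dict items in "input" parse into MultimodalEmbeddingInput.
from sglang.srt.openai_api.protocol import EmbeddingRequest, MultimodalEmbeddingInput

req = EmbeddingRequest(
    model="gme-qwen2-vl",
    input=[{"text": "a photo of a cat", "image": "https://example.com/cat.jpg"}],
)
assert isinstance(req.input[0], MultimodalEmbeddingInput)
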
@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Engine
@@ -135,6 +135,76 @@ class HFRunner:
            return True
        return False

    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py

    def _get_gme_qwen2_vl_embeddings(
        self, prompts, image_data: Optional[List[str]] = None
    ):
        from sglang.srt.utils import load_image

        images = None
        if image_data is not None:
            images = [load_image(image)[0] for image in image_data]

        inputs = self.processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=1800,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self._forward_gme_qwen2_vl(**inputs)
        return embeddings.tolist()

    def _forward_gme_qwen2_vl(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.model.visual.get_dtype())
                image_embeds = self.model.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).to(inputs_embeds.device)
                image_mask = input_ids == self.model.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )

        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            batch_size = outputs.last_hidden_state.shape[0]
            embeddings = outputs.last_hidden_state[
                torch.arange(batch_size, device=outputs.last_hidden_state.device),
                sequence_lengths,
            ]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()
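As a sanity check on the last-token pooling used above, a tiny self-contained example (made-up tensors) of the right-padding branch:

# Worked example of last-token pooling with right padding (illustrative shapes).
import torch

hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # [batch, seq, dim]
mask = torch.tensor([[1, 1, 1, 0],   # 3 real tokens, 1 pad
                     [1, 1, 1, 1]])  # 4 real tokens
seq_lens = mask.sum(dim=1) - 1                 # last real position per sequence
pooled = hidden[torch.arange(2), seq_lens]     # pick the hidden state at that position
pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
print(pooled.shape)  # torch.Size([2, 3])
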
@@ -148,9 +218,18 @@ class HFRunner:
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            self.model = _get_sentence_transformer_embedding_model(
                model_path, torch_dtype
            )
            if "gme-qwen2-vl" in model_path.lower():
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    trust_remote_code=False,
                    low_cpu_mem_usage=True,
                ).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype
                )
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification

@@ -169,7 +248,9 @@ class HFRunner:

        # Run forward
        while True:
            prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
                in_queue.get()
            )
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)
@@ -189,7 +270,10 @@ class HFRunner:
                )
            elif self.model_type == "embedding":
                assert not self.output_str_only
                logits = self.model.encode(prompts).tolist()
                if "gme-qwen2-vl" in model_path.lower():
                    logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
                else:
                    logits = self.model.encode(prompts).tolist()
                out_queue.put(ModelOutput(embed_logits=logits))

            elif self.model_type == "reward":
@@ -211,11 +295,14 @@ class HFRunner:
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        token_ids_logprob: Optional[int] = None,
    ):
        self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
        self.in_queue.put(
            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
        )
        return self.out_queue.get()

    def terminate(self):
@@ -396,6 +483,7 @@ class SRTRunner:
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,
@@ -413,17 +501,23 @@ class SRTRunner:
                token_ids_logprob=token_ids_logprob,
            )
        else:
            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                response = self.engine.encode(prompt=prompts, image_data=image_data)
                if isinstance(response, list):
                    logits = [x["embedding"] for x in response]
                else:
                    logits = [response["embedding"]]
                return ModelOutput(embed_logits=logits)
            # reward model
            else:
                response = self.engine.encode(prompts)
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens=8,
        lora_paths=None,
    ):
@@ -439,7 +533,7 @@ class SRTRunner:
                lora_paths=lora_paths,
            )
        else:
            response = self.engine.encode(prompts)
            response = self.engine.encode(prompts, image_data)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
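Finally, a sketch (model path, prompt, and image URL are placeholders) of how a test might drive both runners defined in this file through the new image_data plumbing and compare the resulting embeddings:

# Illustrative comparison of HF and SRT embeddings for one text+image input.
MODEL = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct"   # assumed checkpoint
PROMPT = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>a photo of a cat<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
IMAGE = "https://example.com/cat.jpg"            # placeholder URL

with HFRunner(MODEL, torch_dtype=torch.float16, model_type="embedding") as hf_runner:
    hf_out = hf_runner.forward(prompts=[PROMPT], image_data=[IMAGE])

with SRTRunner(MODEL, torch_dtype=torch.float16, model_type="embedding") as srt_runner:
    srt_out = srt_runner.forward(prompts=[PROMPT], image_data=[IMAGE])

similarity = torch.tensor(hf_out.embed_logits) @ torch.tensor(srt_out.embed_logits).T
print(similarity.item())  # should be close to 1.0 if the two implementations agree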