Add Support for Qwen2-VL Multi-modal Embedding Models (#3694)
This commit is contained in:
@@ -38,6 +38,8 @@
|
|||||||
- Mistral embedding models
|
- Mistral embedding models
|
||||||
- Qwen embedding models
|
- Qwen embedding models
|
||||||
- `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
|
- `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
|
||||||
|
- Multi-modal embedding models
|
||||||
|
- `python -m sglang.launch_server --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct --is-embedding --chat-template gme-qwen2-vl`
|
||||||
|
|
||||||
## Reward Models
|
## Reward Models
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class SeparatorStyle(IntEnum):
|
|||||||
CHATGLM3 = auto()
|
CHATGLM3 = auto()
|
||||||
DEEPSEEK_CHAT = auto()
|
DEEPSEEK_CHAT = auto()
|
||||||
METAMATH = auto()
|
METAMATH = auto()
|
||||||
|
QWEN2_VL_EMBED = auto()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -110,6 +111,15 @@ class Conversation:
|
|||||||
else:
|
else:
|
||||||
ret += role + "\n"
|
ret += role + "\n"
|
||||||
return ret
|
return ret
|
||||||
|
elif self.sep_style == SeparatorStyle.QWEN2_VL_EMBED:
|
||||||
|
ret = "" if system_prompt == "" else system_prompt + self.sep
|
||||||
|
for role, message in self.messages:
|
||||||
|
if message:
|
||||||
|
ret += role + "\n" + message + self.sep
|
||||||
|
else:
|
||||||
|
ret += role + "\n"
|
||||||
|
ret += self.stop_str
|
||||||
|
return ret
|
||||||
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||||
ret = system_prompt
|
ret = system_prompt
|
||||||
for role, message in self.messages:
|
for role, message in self.messages:
|
||||||
@@ -366,6 +376,46 @@ def chat_template_exists(template_name: str) -> bool:
|
|||||||
return template_name in chat_templates
|
return template_name in chat_templates
|
||||||
|
|
||||||
|
|
||||||
|
def generate_embedding_convs(
|
||||||
|
texts: List[str], images: List[str], template_name: str
|
||||||
|
) -> List[Conversation]:
|
||||||
|
conv_template = chat_templates[template_name].copy()
|
||||||
|
convs = []
|
||||||
|
for text, image in zip(texts, images):
|
||||||
|
conv = Conversation(
|
||||||
|
name=conv_template.name,
|
||||||
|
system_template=conv_template.system_template,
|
||||||
|
system_message=conv_template.system_message,
|
||||||
|
roles=conv_template.roles,
|
||||||
|
messages=list(conv_template.messages), # prevent in-place modification
|
||||||
|
offset=conv_template.offset,
|
||||||
|
sep_style=SeparatorStyle(conv_template.sep_style),
|
||||||
|
sep=conv_template.sep,
|
||||||
|
sep2=conv_template.sep2,
|
||||||
|
stop_str=conv_template.stop_str,
|
||||||
|
image_data=[],
|
||||||
|
modalities=[],
|
||||||
|
image_token=conv_template.image_token,
|
||||||
|
)
|
||||||
|
real_content = ""
|
||||||
|
|
||||||
|
if image is not None:
|
||||||
|
image_token = (
|
||||||
|
conv.image_token + "\n"
|
||||||
|
if conv.name != "gme-qwen2-vl"
|
||||||
|
else conv.image_token
|
||||||
|
)
|
||||||
|
real_content += image_token
|
||||||
|
if text is not None:
|
||||||
|
real_content += text
|
||||||
|
conv.append_message(conv.roles[0], real_content)
|
||||||
|
# Add a blank message for the assistant.
|
||||||
|
conv.append_message(conv.roles[1], None)
|
||||||
|
convs.append(conv)
|
||||||
|
|
||||||
|
return convs
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_conv(
|
def generate_chat_conv(
|
||||||
request: ChatCompletionRequest, template_name: str
|
request: ChatCompletionRequest, template_name: str
|
||||||
) -> Conversation:
|
) -> Conversation:
|
||||||
@@ -555,6 +605,20 @@ register_conv_template(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
|
||||||
|
register_conv_template(
|
||||||
|
Conversation(
|
||||||
|
name="gme-qwen2-vl",
|
||||||
|
system_message="You are a helpful assistant.",
|
||||||
|
system_template="<|im_start|>system\n{system_message}",
|
||||||
|
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||||
|
sep="<|im_end|>\n",
|
||||||
|
sep_style=SeparatorStyle.QWEN2_VL_EMBED,
|
||||||
|
stop_str="<|endoftext|>",
|
||||||
|
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
|
# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
|
||||||
register_conv_template(
|
register_conv_template(
|
||||||
Conversation(
|
Conversation(
|
||||||
|
|||||||
@@ -214,13 +214,13 @@ class Engine:
|
|||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
|
||||||
|
image_data: Optional[Union[List[str], str]] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
|
||||||
Please refer to `EmbeddingReqInput` for the documentation.
|
Please refer to `EmbeddingReqInput` for the documentation.
|
||||||
"""
|
"""
|
||||||
|
obj = EmbeddingReqInput(text=prompt, image_data=image_data)
|
||||||
obj = EmbeddingReqInput(text=prompt)
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
ret = loop.run_until_complete(generator.__anext__())
|
ret = loop.run_until_complete(generator.__anext__())
|
||||||
|
|||||||
@@ -293,6 +293,8 @@ class TokenizedGenerateReqInput:
|
|||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
text: Optional[Union[List[str], str]] = None
|
text: Optional[Union[List[str], str]] = None
|
||||||
|
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||||
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
# The token ids for text; one can either specify text or input_ids.
|
# The token ids for text; one can either specify text or input_ids.
|
||||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
@@ -303,28 +305,40 @@ class EmbeddingReqInput:
|
|||||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||||
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
||||||
log_metrics: bool = True
|
log_metrics: bool = True
|
||||||
|
# The modalities of the image data [image, multi-images, video]
|
||||||
|
modalities: Optional[List[str]] = None
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
if (self.text is None and self.input_ids is None) or (
|
# at least one of text, input_ids, or image should be provided
|
||||||
self.text is not None and self.input_ids is not None
|
if self.text is None and self.input_ids is None and self.image_data is None:
|
||||||
):
|
raise ValueError(
|
||||||
raise ValueError("Either text or input_ids should be provided.")
|
"At least one of text, input_ids, or image should be provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# text and input_ids cannot be provided at the same time
|
||||||
|
if self.text is not None and self.input_ids is not None:
|
||||||
|
raise ValueError("text and input_ids cannot be provided at the same time")
|
||||||
|
|
||||||
# Derive the batch size
|
# Derive the batch size
|
||||||
|
self.batch_size = 0
|
||||||
|
self.is_single = True
|
||||||
|
|
||||||
|
# check the batch size of text
|
||||||
if self.text is not None:
|
if self.text is not None:
|
||||||
if isinstance(self.text, str):
|
if isinstance(self.text, list):
|
||||||
self.is_single = True
|
self.batch_size += len(self.text)
|
||||||
self.batch_size = 1
|
|
||||||
else:
|
else:
|
||||||
|
self.batch_size += 1
|
||||||
|
|
||||||
|
# check the batch size of input_ids
|
||||||
|
if self.input_ids is not None:
|
||||||
|
if isinstance(self.input_ids[0], list):
|
||||||
|
self.batch_size += len(self.input_ids)
|
||||||
|
else:
|
||||||
|
self.batch_size += 1
|
||||||
|
|
||||||
|
if self.batch_size > 1:
|
||||||
self.is_single = False
|
self.is_single = False
|
||||||
self.batch_size = len(self.text)
|
|
||||||
else:
|
|
||||||
if isinstance(self.input_ids[0], int):
|
|
||||||
self.is_single = True
|
|
||||||
self.batch_size = 1
|
|
||||||
else:
|
|
||||||
self.is_single = False
|
|
||||||
self.batch_size = len(self.input_ids)
|
|
||||||
|
|
||||||
# Fill in default arguments
|
# Fill in default arguments
|
||||||
if self.is_single:
|
if self.is_single:
|
||||||
@@ -352,6 +366,7 @@ class EmbeddingReqInput:
|
|||||||
return EmbeddingReqInput(
|
return EmbeddingReqInput(
|
||||||
text=self.text[i] if self.text is not None else None,
|
text=self.text[i] if self.text is not None else None,
|
||||||
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
||||||
|
image_data=self.image_data[i] if self.image_data is not None else None,
|
||||||
sampling_params=self.sampling_params[i],
|
sampling_params=self.sampling_params[i],
|
||||||
rid=self.rid[i],
|
rid=self.rid[i],
|
||||||
)
|
)
|
||||||
@@ -365,6 +380,8 @@ class TokenizedEmbeddingReqInput:
|
|||||||
input_text: str
|
input_text: str
|
||||||
# The input token ids
|
# The input token ids
|
||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
# The image inputs
|
||||||
|
image_inputs: dict
|
||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
|
||||||
|
|||||||
@@ -767,6 +767,30 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
# Handle multimodal inputs
|
||||||
|
if recv_req.image_inputs is not None:
|
||||||
|
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
|
||||||
|
# Expand a single image token into multiple dummy tokens for receiving image embeddings
|
||||||
|
req.origin_input_ids = self.pad_input_ids_func(
|
||||||
|
req.origin_input_ids, image_inputs
|
||||||
|
)
|
||||||
|
req.extend_image_inputs(image_inputs)
|
||||||
|
|
||||||
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
|
error_msg = (
|
||||||
|
"Multimodal prompt is too long after expanding multimodal tokens. "
|
||||||
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
||||||
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
|
req.origin_input_ids = [0]
|
||||||
|
req.image_inputs = None
|
||||||
|
req.sampling_params.max_new_tokens = 0
|
||||||
|
req.finished_reason = FINISH_ABORT(
|
||||||
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
||||||
|
)
|
||||||
|
self.waiting_queue.append(req)
|
||||||
|
return
|
||||||
|
|
||||||
# Validate prompts length
|
# Validate prompts length
|
||||||
error_msg = validate_input_length(
|
error_msg = validate_input_length(
|
||||||
req,
|
req,
|
||||||
|
|||||||
@@ -372,13 +372,12 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
input_ids = self.tokenizer.encode(input_text)
|
input_ids = self.tokenizer.encode(input_text)
|
||||||
|
|
||||||
if self.is_generation:
|
|
||||||
# TODO: also support getting embeddings for multimodal models
|
|
||||||
image_inputs: Dict = await self.image_processor.process_images_async(
|
image_inputs: Dict = await self.image_processor.process_images_async(
|
||||||
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
||||||
)
|
)
|
||||||
if image_inputs and "input_ids" in image_inputs:
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
input_ids = image_inputs["input_ids"]
|
input_ids = image_inputs["input_ids"]
|
||||||
|
if self.is_generation:
|
||||||
return_logprob = obj.return_logprob
|
return_logprob = obj.return_logprob
|
||||||
logprob_start_len = obj.logprob_start_len
|
logprob_start_len = obj.logprob_start_len
|
||||||
top_logprobs_num = obj.top_logprobs_num
|
top_logprobs_num = obj.top_logprobs_num
|
||||||
@@ -438,6 +437,7 @@ class TokenizerManager:
|
|||||||
obj.rid,
|
obj.rid,
|
||||||
input_text,
|
input_text,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
image_inputs,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
|
|||||||
SeparatorStyle,
|
SeparatorStyle,
|
||||||
chat_template_exists,
|
chat_template_exists,
|
||||||
generate_chat_conv,
|
generate_chat_conv,
|
||||||
|
generate_embedding_convs,
|
||||||
register_conv_template,
|
register_conv_template,
|
||||||
)
|
)
|
||||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
||||||
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
FileResponse,
|
FileResponse,
|
||||||
FunctionResponse,
|
FunctionResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
|
MultimodalEmbeddingInput,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
TopLogprob,
|
TopLogprob,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
@@ -1556,11 +1558,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
|
|||||||
prompt = prompts[0]
|
prompt = prompts[0]
|
||||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
||||||
prompt_kwargs = {"text": prompt}
|
prompt_kwargs = {"text": prompt}
|
||||||
|
elif isinstance(prompt, list) and isinstance(
|
||||||
|
prompt[0], MultimodalEmbeddingInput
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
chat_template_name is not None
|
||||||
|
), "chat_template_name is required for multimodal inputs"
|
||||||
|
texts = []
|
||||||
|
images = []
|
||||||
|
for item in prompt:
|
||||||
|
texts.append(item.text if item.text is not None else None)
|
||||||
|
images.append(item.image if item.image is not None else None)
|
||||||
|
convs = generate_embedding_convs(texts, images, chat_template_name)
|
||||||
|
generate_prompts = []
|
||||||
|
for conv in convs:
|
||||||
|
generate_prompts.append(conv.get_prompt())
|
||||||
|
if len(generate_prompts) == 1:
|
||||||
|
prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"text": generate_prompts, "image_data": images}
|
||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": prompt}
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
else:
|
else:
|
||||||
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
|
||||||
prompt_kwargs = {"text": prompts}
|
prompt_kwargs = {"text": prompts}
|
||||||
|
elif isinstance(prompts[0], list) and isinstance(
|
||||||
|
prompts[0][0], MultimodalEmbeddingInput
|
||||||
|
):
|
||||||
|
# TODO: multiple requests
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Multiple requests with multimodal inputs are not supported yet"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": prompts}
|
prompt_kwargs = {"input_ids": prompts}
|
||||||
|
|
||||||
|
|||||||
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
|
|||||||
usage: Optional[UsageInfo] = None
|
usage: Optional[UsageInfo] = None
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalEmbeddingInput(BaseModel):
|
||||||
|
text: Optional[str] = None
|
||||||
|
image: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingRequest(BaseModel):
|
class EmbeddingRequest(BaseModel):
|
||||||
# Ordered by official OpenAI API documentation
|
# Ordered by official OpenAI API documentation
|
||||||
# https://platform.openai.com/docs/api-reference/embeddings/create
|
# https://platform.openai.com/docs/api-reference/embeddings/create
|
||||||
input: Union[List[int], List[List[int]], str, List[str]]
|
input: Union[
|
||||||
|
List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
|
||||||
|
]
|
||||||
model: str
|
model: str
|
||||||
encoding_format: str = "float"
|
encoding_format: str = "float"
|
||||||
dimensions: int = None
|
dimensions: int = None
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.server import Engine
|
from sglang.srt.server import Engine
|
||||||
@@ -135,6 +135,76 @@ class HFRunner:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
|
||||||
|
|
||||||
|
def _get_gme_qwen2_vl_embeddings(
|
||||||
|
self, prompts, image_data: Optional[List[str]] = None
|
||||||
|
):
|
||||||
|
from sglang.srt.utils import load_image
|
||||||
|
|
||||||
|
images = None
|
||||||
|
if image_data is not None:
|
||||||
|
images = [load_image(image)[0] for image in image_data]
|
||||||
|
|
||||||
|
inputs = self.processor(
|
||||||
|
text=prompts,
|
||||||
|
images=images,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=1800,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
|
||||||
|
with torch.no_grad():
|
||||||
|
embeddings = self._forward_gme_qwen2_vl(**inputs)
|
||||||
|
return embeddings.tolist()
|
||||||
|
|
||||||
|
def _forward_gme_qwen2_vl(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
pooling_mask: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.model.model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.type(self.model.visual.get_dtype())
|
||||||
|
image_embeds = self.model.visual(
|
||||||
|
pixel_values, grid_thw=image_grid_thw
|
||||||
|
).to(inputs_embeds.device)
|
||||||
|
image_mask = input_ids == self.model.config.image_token_id
|
||||||
|
inputs_embeds[image_mask] = image_embeds
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
outputs = self.model.model(
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooling_mask = attention_mask if pooling_mask is None else pooling_mask
|
||||||
|
left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0] # TODO
|
||||||
|
if left_padding:
|
||||||
|
embeddings = outputs.last_hidden_state[:, -1]
|
||||||
|
else:
|
||||||
|
sequence_lengths = pooling_mask.sum(dim=1) - 1
|
||||||
|
batch_size = outputs.last_hidden_state.shape[0]
|
||||||
|
embeddings = outputs.last_hidden_state[
|
||||||
|
torch.arange(batch_size, device=outputs.last_hidden_state.device),
|
||||||
|
sequence_lengths,
|
||||||
|
]
|
||||||
|
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
||||||
|
return embeddings.contiguous()
|
||||||
|
|
||||||
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
|
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
|
||||||
# Apply model-specific patches
|
# Apply model-specific patches
|
||||||
monkey_patch_gemma2_sdpa()
|
monkey_patch_gemma2_sdpa()
|
||||||
@@ -148,6 +218,15 @@ class HFRunner:
|
|||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
).cuda()
|
).cuda()
|
||||||
elif self.model_type == "embedding":
|
elif self.model_type == "embedding":
|
||||||
|
if "gme-qwen2-vl" in model_path.lower():
|
||||||
|
self.model = AutoModelForVision2Seq.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
trust_remote_code=False,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
).cuda()
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||||
|
else:
|
||||||
self.model = _get_sentence_transformer_embedding_model(
|
self.model = _get_sentence_transformer_embedding_model(
|
||||||
model_path, torch_dtype
|
model_path, torch_dtype
|
||||||
)
|
)
|
||||||
@@ -169,7 +248,9 @@ class HFRunner:
|
|||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
while True:
|
while True:
|
||||||
prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
|
prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
|
||||||
|
in_queue.get()
|
||||||
|
)
|
||||||
if lora_paths is not None:
|
if lora_paths is not None:
|
||||||
assert len(prompts) == len(lora_paths)
|
assert len(prompts) == len(lora_paths)
|
||||||
|
|
||||||
@@ -189,6 +270,9 @@ class HFRunner:
|
|||||||
)
|
)
|
||||||
elif self.model_type == "embedding":
|
elif self.model_type == "embedding":
|
||||||
assert not self.output_str_only
|
assert not self.output_str_only
|
||||||
|
if "gme-qwen2-vl" in model_path.lower():
|
||||||
|
logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
|
||||||
|
else:
|
||||||
logits = self.model.encode(prompts).tolist()
|
logits = self.model.encode(prompts).tolist()
|
||||||
out_queue.put(ModelOutput(embed_logits=logits))
|
out_queue.put(ModelOutput(embed_logits=logits))
|
||||||
|
|
||||||
@@ -211,11 +295,14 @@ class HFRunner:
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
|
image_data: Optional[List[str]] = None,
|
||||||
max_new_tokens: int = 8,
|
max_new_tokens: int = 8,
|
||||||
lora_paths: Optional[List[str]] = None,
|
lora_paths: Optional[List[str]] = None,
|
||||||
token_ids_logprob: Optional[int] = None,
|
token_ids_logprob: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
|
self.in_queue.put(
|
||||||
|
(prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
|
||||||
|
)
|
||||||
return self.out_queue.get()
|
return self.out_queue.get()
|
||||||
|
|
||||||
def terminate(self):
|
def terminate(self):
|
||||||
@@ -396,6 +483,7 @@ class SRTRunner:
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
|
image_data: Optional[List[str]] = None,
|
||||||
max_new_tokens: int = 8,
|
max_new_tokens: int = 8,
|
||||||
lora_paths: Optional[List[str]] = None,
|
lora_paths: Optional[List[str]] = None,
|
||||||
logprob_start_len: int = 0,
|
logprob_start_len: int = 0,
|
||||||
@@ -413,17 +501,23 @@ class SRTRunner:
|
|||||||
token_ids_logprob=token_ids_logprob,
|
token_ids_logprob=token_ids_logprob,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.engine.encode(prompts)
|
|
||||||
if self.model_type == "embedding":
|
if self.model_type == "embedding":
|
||||||
|
response = self.engine.encode(prompt=prompts, image_data=image_data)
|
||||||
|
if isinstance(response, list):
|
||||||
logits = [x["embedding"] for x in response]
|
logits = [x["embedding"] for x in response]
|
||||||
return ModelOutput(embed_logits=logits)
|
|
||||||
else:
|
else:
|
||||||
|
logits = [response["embedding"]]
|
||||||
|
return ModelOutput(embed_logits=logits)
|
||||||
|
# reward model
|
||||||
|
else:
|
||||||
|
response = self.engine.encode(prompts)
|
||||||
scores = [x["embedding"][0] for x in response]
|
scores = [x["embedding"][0] for x in response]
|
||||||
return ModelOutput(scores=scores)
|
return ModelOutput(scores=scores)
|
||||||
|
|
||||||
def batch_forward(
|
def batch_forward(
|
||||||
self,
|
self,
|
||||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||||
|
image_data: Optional[List[str]] = None,
|
||||||
max_new_tokens=8,
|
max_new_tokens=8,
|
||||||
lora_paths=None,
|
lora_paths=None,
|
||||||
):
|
):
|
||||||
@@ -439,7 +533,7 @@ class SRTRunner:
|
|||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = self.engine.encode(prompts)
|
response = self.engine.encode(prompts, image_data)
|
||||||
if self.model_type == "embedding":
|
if self.model_type == "embedding":
|
||||||
logits = [x["embedding"] for x in response]
|
logits = [x["embedding"] for x in response]
|
||||||
return ModelOutput(embed_logits=logits)
|
return ModelOutput(embed_logits=logits)
|
||||||
|
|||||||
85
test/srt/models/test_gme_qwen_models.py
Normal file
85
test/srt/models/test_gme_qwen_models.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
import multiprocessing as mp
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.test.runners import HFRunner, SRTRunner
|
||||||
|
from sglang.test.test_utils import get_similarities
|
||||||
|
|
||||||
|
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
|
||||||
|
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", 1e-3),
|
||||||
|
]
|
||||||
|
TORCH_DTYPES = [torch.float16]
|
||||||
|
|
||||||
|
|
||||||
|
class TestQmeQwenModels(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
|
||||||
|
|
||||||
|
prompts_no_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{TEXTS}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
|
||||||
|
prompts_with_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
|
||||||
|
with HFRunner(
|
||||||
|
model,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
model_type="embedding",
|
||||||
|
) as hf_runner:
|
||||||
|
hf_text_embeddings = hf_runner.forward(prompts=[prompts_no_image])
|
||||||
|
hf_image_embeddings = hf_runner.forward(
|
||||||
|
prompts=[prompts_with_image], image_data=[IMAGES]
|
||||||
|
)
|
||||||
|
with SRTRunner(
|
||||||
|
model,
|
||||||
|
tp_size=1,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
model_type="embedding",
|
||||||
|
) as srt_runner:
|
||||||
|
srt_text_embeddings = srt_runner.forward(prompts=prompts_no_image)
|
||||||
|
srt_image_embeddings = srt_runner.forward(
|
||||||
|
prompts=prompts_with_image, image_data=IMAGES
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity = get_similarities(
|
||||||
|
hf_text_embeddings.embed_logits[0], srt_text_embeddings.embed_logits[0]
|
||||||
|
)
|
||||||
|
print("texts similarity diff", abs(similarity - 1))
|
||||||
|
assert torch.all(
|
||||||
|
abs(similarity - 1) < prefill_tolerance
|
||||||
|
), "embeddings are not all close"
|
||||||
|
similarity = get_similarities(
|
||||||
|
hf_image_embeddings.embed_logits[0], srt_image_embeddings.embed_logits[0]
|
||||||
|
)
|
||||||
|
print("images similarity diff", abs(similarity - 1))
|
||||||
|
assert torch.all(
|
||||||
|
abs(similarity - 1) < prefill_tolerance
|
||||||
|
), "embeddings are not all close"
|
||||||
|
|
||||||
|
def test_accuracy(self):
|
||||||
|
for model, prefill_tolerance in MODELS:
|
||||||
|
for torch_dtype in TORCH_DTYPES:
|
||||||
|
self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -13,6 +13,7 @@ suites = {
|
|||||||
"models/test_qwen_models.py",
|
"models/test_qwen_models.py",
|
||||||
"models/test_reward_models.py",
|
"models/test_reward_models.py",
|
||||||
"test_gptqmodel_dynamic.py",
|
"test_gptqmodel_dynamic.py",
|
||||||
|
"models/test_gme_qwen_models.py",
|
||||||
"test_abort.py",
|
"test_abort.py",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
"test_custom_allreduce.py",
|
"test_custom_allreduce.py",
|
||||||
|
|||||||
Reference in New Issue
Block a user