Add Support for Qwen2-VL Multi-modal Embedding Models (#3694)
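This commit adds multi-modal embedding support to the test runners: HFRunner gains a Hugging Face reference path for gme-Qwen2-VL models (loading them with AutoModelForVision2Seq plus AutoProcessor and pooling embeddings GME-style), and HFRunner.forward, SRTRunner.forward, and SRTRunner.batch_forward all accept an optional image_data argument that is threaded through to Engine.encode.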
@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Engine
@@ -135,6 +135,76 @@ class HFRunner:
             return True
         return False
 
+    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
+    def _get_gme_qwen2_vl_embeddings(
+        self, prompts, image_data: Optional[List[str]] = None
+    ):
+        from sglang.srt.utils import load_image
+
+        images = None
+        if image_data is not None:
+            images = [load_image(image)[0] for image in image_data]
+
+        inputs = self.processor(
+            text=prompts,
+            images=images,
+            padding=True,
+            truncation=True,
+            max_length=1800,
+            return_tensors="pt",
+        )
+        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            embeddings = self._forward_gme_qwen2_vl(**inputs)
+        return embeddings.tolist()
+
+    def _forward_gme_qwen2_vl(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        pooling_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.model.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.model.visual.get_dtype())
+                image_embeds = self.model.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).to(inputs_embeds.device)
+                image_mask = input_ids == self.model.config.image_token_id
+                inputs_embeds[image_mask] = image_embeds
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(inputs_embeds.device)
+
+        outputs = self.model.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+        )
+
+        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
+        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
+        if left_padding:
+            embeddings = outputs.last_hidden_state[:, -1]
+        else:
+            sequence_lengths = pooling_mask.sum(dim=1) - 1
+            batch_size = outputs.last_hidden_state.shape[0]
+            embeddings = outputs.last_hidden_state[
+                torch.arange(batch_size, device=outputs.last_hidden_state.device),
+                sequence_lengths,
+            ]
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return embeddings.contiguous()
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         # Apply model-specific patches
         monkey_patch_gemma2_sdpa()
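A note on the pooling in _forward_gme_qwen2_vl above: it selects the hidden state of each sequence's last non-padding token (or simply the final position when the whole batch is left-padded) and L2-normalizes it. A minimal standalone sketch of that pooling step, with illustrative shapes and a dummy right-padded mask:

    import torch
    import torch.nn.functional as F

    hidden = torch.randn(2, 5, 8)  # (batch, seq_len, hidden_dim), stand-in for last_hidden_state
    mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]])  # right-padded attention mask
    lengths = mask.sum(dim=1) - 1  # index of each sequence's last real token
    pooled = hidden[torch.arange(2), lengths]  # gather one vector per sequence
    pooled = F.normalize(pooled, p=2, dim=1)  # unit-norm embeddings, as in the diff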
@@ -148,9 +218,18 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            self.model = _get_sentence_transformer_embedding_model(
-                model_path, torch_dtype
-            )
+            if "gme-qwen2-vl" in model_path.lower():
+                self.model = AutoModelForVision2Seq.from_pretrained(
+                    model_path,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=False,
+                    low_cpu_mem_usage=True,
+                ).cuda()
+                self.processor = AutoProcessor.from_pretrained(model_path)
+            else:
+                self.model = _get_sentence_transformer_embedding_model(
+                    model_path, torch_dtype
+                )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
 
@@ -169,7 +248,9 @@ class HFRunner:
 
         # Run forward
         while True:
-            prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
+            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
+                in_queue.get()
+            )
             if lora_paths is not None:
                 assert len(prompts) == len(lora_paths)
 
@@ -189,7 +270,10 @@ class HFRunner:
                     )
                 elif self.model_type == "embedding":
                     assert not self.output_str_only
-                    logits = self.model.encode(prompts).tolist()
+                    if "gme-qwen2-vl" in model_path.lower():
+                        logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
+                    else:
+                        logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
                 elif self.model_type == "reward":
@@ -211,11 +295,14 @@ class HFRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         token_ids_logprob: Optional[int] = None,
     ):
-        self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
+        self.in_queue.put(
+            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
+        )
         return self.out_queue.get()
 
     def terminate(self):
@@ -396,6 +483,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         logprob_start_len: int = 0,
@@ -413,17 +501,23 @@ class SRTRunner:
                 token_ids_logprob=token_ids_logprob,
             )
         else:
-            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
-                logits = [x["embedding"] for x in response]
+                response = self.engine.encode(prompt=prompts, image_data=image_data)
+                if isinstance(response, list):
+                    logits = [x["embedding"] for x in response]
+                else:
+                    logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
+            # reward model
             else:
+                response = self.engine.encode(prompts)
                 scores = [x["embedding"][0] for x in response]
                 return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens=8,
         lora_paths=None,
     ):
@@ -439,7 +533,7 @@ class SRTRunner:
                 lora_paths=lora_paths,
             )
         else:
-            response = self.engine.encode(prompts)
+            response = self.engine.encode(prompts, image_data)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
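For context, a minimal sketch of how the new image_data path might be exercised end to end. The runner constructor arguments below are assumptions for illustration (they are not part of this diff), and the image URL is a placeholder; the model path comes from the reference link above:

    import torch

    # Hypothetical test snippet; HFRunner/SRTRunner constructor signatures are assumed.
    MODEL_PATH = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct"
    prompts = ["a photo of a cat"]
    image_data = ["https://example.com/cat.jpg"]  # placeholder image URL

    hf_runner = HFRunner(MODEL_PATH, torch_dtype=torch.float16, model_type="embedding")
    hf_embeddings = hf_runner.forward(prompts, image_data=image_data).embed_logits

    srt_runner = SRTRunner(MODEL_PATH, torch_dtype=torch.float16, model_type="embedding")
    srt_embeddings = srt_runner.forward(prompts, image_data=image_data).embed_logits

    # The two embeddings can then be compared, e.g. by cosine similarity.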