Add Support for Qwen2-VL Multi-modal Embedding Models (#3694)

Pan Lyu
2025-03-07 08:46:20 +08:00
committed by GitHub
parent 13bc39c5d6
commit 361971b859
11 changed files with 356 additions and 34 deletions


@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Engine
@@ -135,6 +135,76 @@ class HFRunner:
                 return True
         return False
 
+    # Copied from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
+    def _get_gme_qwen2_vl_embeddings(
+        self, prompts, image_data: Optional[List[str]] = None
+    ):
+        from sglang.srt.utils import load_image
+
+        images = None
+        if image_data is not None:
+            images = [load_image(image)[0] for image in image_data]
+
+        inputs = self.processor(
+            text=prompts,
+            images=images,
+            padding=True,
+            truncation=True,
+            max_length=1800,
+            return_tensors="pt",
+        )
+        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            embeddings = self._forward_gme_qwen2_vl(**inputs)
+        return embeddings.tolist()
+
+    def _forward_gme_qwen2_vl(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        pooling_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            # Embed the text tokens, then splice vision features into the
+            # positions held by image placeholder tokens.
+            inputs_embeds = self.model.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.model.visual.get_dtype())
+                image_embeds = self.model.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).to(inputs_embeds.device)
+                image_mask = input_ids == self.model.config.image_token_id
+                inputs_embeds[image_mask] = image_embeds
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(inputs_embeds.device)
+
+        outputs = self.model.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+        )
+
+        # Pool the hidden state of each sequence's last non-padded token.
+        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
+        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
+        if left_padding:
+            embeddings = outputs.last_hidden_state[:, -1]
+        else:
+            sequence_lengths = pooling_mask.sum(dim=1) - 1
+            batch_size = outputs.last_hidden_state.shape[0]
+            embeddings = outputs.last_hidden_state[
+                torch.arange(batch_size, device=outputs.last_hidden_state.device),
+                sequence_lengths,
+            ]
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return embeddings.contiguous()
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         # Apply model-specific patches
         monkey_patch_gemma2_sdpa()
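
The helper _forward_gme_qwen2_vl above uses last-token pooling: each sequence's embedding is the decoder's hidden state at its final non-padded position, then L2-normalized. A minimal, self-contained sketch of that pooling logic with hypothetical tensors (not part of this commit):

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes: batch of 2 sequences, 4 positions, hidden size 3.
    hidden = torch.randn(2, 4, 3)        # stands in for outputs.last_hidden_state
    mask = torch.tensor([[1, 1, 1, 0],   # right padding: last real token at index 2
                         [1, 1, 1, 1]])  # full-length sequence

    if mask[:, -1].sum() == mask.shape[0]:
        # Left padding: every sequence ends at the final position.
        emb = hidden[:, -1]
    else:
        # Right padding: index each row at its last non-padded position.
        last = mask.sum(dim=1) - 1
        emb = hidden[torch.arange(hidden.shape[0]), last]

    emb = F.normalize(emb, p=2, dim=1)   # unit-length embeddings, shape (2, 3)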
@@ -148,9 +218,18 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            self.model = _get_sentence_transformer_embedding_model(
-                model_path, torch_dtype
-            )
+            if "gme-qwen2-vl" in model_path.lower():
+                self.model = AutoModelForVision2Seq.from_pretrained(
+                    model_path,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=False,
+                    low_cpu_mem_usage=True,
+                ).cuda()
+                self.processor = AutoProcessor.from_pretrained(model_path)
+            else:
+                self.model = _get_sentence_transformer_embedding_model(
+                    model_path, torch_dtype
+                )
 
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
@@ -169,7 +248,9 @@ class HFRunner:
 
         # Run forward
         while True:
-            prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
+            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
+                in_queue.get()
+            )
 
             if lora_paths is not None:
                 assert len(prompts) == len(lora_paths)
@@ -189,7 +270,10 @@ class HFRunner:
                 )
             elif self.model_type == "embedding":
                 assert not self.output_str_only
-                logits = self.model.encode(prompts).tolist()
+                if "gme-qwen2-vl" in model_path.lower():
+                    logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
+                else:
+                    logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
 
             elif self.model_type == "reward":
@@ -211,11 +295,14 @@
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         token_ids_logprob: Optional[int] = None,
     ):
-        self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
+        self.in_queue.put(
+            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
+        )
         return self.out_queue.get()
 
     def terminate(self):
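
HFRunner.forward is a queue-based RPC into a subprocess: it enqueues a tuple that the worker loop earlier in this diff unpacks, which is why image_data must occupy the same tuple position on both sides. A toy sketch of the pattern (not from this repo):

    import multiprocessing as mp

    def worker(in_q, out_q):
        while True:
            # Must unpack in exactly the order the caller packed.
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = in_q.get()
            out_q.put({"n": len(prompts), "has_images": image_data is not None})

    if __name__ == "__main__":
        in_q, out_q = mp.Queue(), mp.Queue()
        mp.Process(target=worker, args=(in_q, out_q), daemon=True).start()
        in_q.put((["hello"], None, 8, None, None))
        print(out_q.get())  # {'n': 1, 'has_images': False}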
@@ -396,6 +483,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         logprob_start_len: int = 0,
@@ -413,17 +501,23 @@
                 token_ids_logprob=token_ids_logprob,
             )
         else:
-            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
-                logits = [x["embedding"] for x in response]
+                response = self.engine.encode(prompt=prompts, image_data=image_data)
+                if isinstance(response, list):
+                    logits = [x["embedding"] for x in response]
+                else:
+                    logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
             # reward model
             else:
+                response = self.engine.encode(prompts)
                 scores = [x["embedding"][0] for x in response]
                 return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens=8,
         lora_paths=None,
     ):
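
The isinstance check above is needed because Engine.encode returns a single dict for one prompt and a list of dicts for a batch. A hedged usage sketch, assuming Engine forwards keyword arguments such as is_embedding to the server configuration; the image URL is a placeholder:

    from sglang.srt.server import Engine

    engine = Engine(
        model_path="Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
        is_embedding=True,  # assumption: enables the server's embedding mode
    )
    out = engine.encode(
        prompt=["a photo of a cat"],
        image_data=["https://example.com/cat.jpg"],  # placeholder URL
    )
    # Normalize the two possible return shapes into a list of vectors.
    vectors = [x["embedding"] for x in out] if isinstance(out, list) else [out["embedding"]]
    engine.shutdown()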
@@ -439,7 +533,7 @@
                 lora_paths=lora_paths,
             )
         else:
-            response = self.engine.encode(prompts)
+            response = self.engine.encode(prompts, image_data)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
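
Taken together, the image_data plumbing lets a test feed the same prompts and images through both runners and compare the resulting embeddings. A sketch under assumed constructor signatures (the runners are used as context managers elsewhere in this suite; the image URL is a placeholder):

    import torch
    import torch.nn.functional as F

    prompts = ["Describe this image."]
    images = ["https://example.com/cat.jpg"]  # placeholder URL
    model_path = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct"

    with HFRunner(model_path, torch_dtype=torch.float16, model_type="embedding") as hf:
        hf_out = hf.forward(prompts, image_data=images)

    with SRTRunner(model_path, torch_dtype=torch.float16, model_type="embedding") as srt:
        srt_out = srt.forward(prompts, image_data=images)

    # Both results are ModelOutput(embed_logits=...); compare by cosine similarity.
    sim = F.cosine_similarity(
        torch.tensor(hf_out.embed_logits[0]),
        torch.tensor(srt_out.embed_logits[0]),
        dim=0,
    )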