support clip embedding model (#4506)

This commit is contained in:
Pan Lyu
2025-03-27 15:18:15 +08:00
committed by GitHub
parent 1afe3d0798
commit c913ed4046
8 changed files with 746 additions and 9 deletions

View File

@@ -19,10 +19,16 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
)
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Engine
from sglang.srt.utils import load_image
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
DEFAULT_PROMPTS = [
@@ -140,7 +146,6 @@ class HFRunner:
def _get_gme_qwen2_vl_embeddings(
self, prompts, image_data: Optional[List[str]] = None
):
from sglang.srt.utils import load_image
images = None
if image_data is not None:
@@ -226,6 +231,9 @@ class HFRunner:
low_cpu_mem_usage=True,
).cuda()
self.processor = AutoProcessor.from_pretrained(model_path)
elif "clip" in model_path.lower():
self.model = AutoModel.from_pretrained(model_path).cuda()
self.processor = AutoProcessor.from_pretrained(model_path)
else:
self.model = _get_sentence_transformer_embedding_model(
model_path, torch_dtype
@@ -272,6 +280,23 @@ class HFRunner:
assert not self.output_str_only
if "gme-qwen2-vl" in model_path.lower():
logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
elif "clip" in model_path.lower():
if image_data is not None:
image = load_image(image_data)
inputs = self.processor(
images=image[0], return_tensors="pt"
)
logits = self.model.get_image_features(
pixel_values=inputs.data["pixel_values"].cuda(),
).tolist()
else:
inputs = self.tokenizer(
prompts, padding=True, return_tensors="pt"
)
logits = self.model.get_text_features(
input_ids=inputs.data["input_ids"].cuda(),
attention_mask=inputs.data["attention_mask"].cuda(),
).tolist()
else:
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))