support clip embedding model (#4506)
This commit is contained in:
@@ -19,10 +19,16 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
AutoProcessor,
|
||||
)
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.server import Engine
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
|
||||
|
||||
DEFAULT_PROMPTS = [
|
||||
@@ -140,7 +146,6 @@ class HFRunner:
|
||||
def _get_gme_qwen2_vl_embeddings(
|
||||
self, prompts, image_data: Optional[List[str]] = None
|
||||
):
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
images = None
|
||||
if image_data is not None:
|
||||
@@ -226,6 +231,9 @@ class HFRunner:
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||
elif "clip" in model_path.lower():
|
||||
self.model = AutoModel.from_pretrained(model_path).cuda()
|
||||
self.processor = AutoProcessor.from_pretrained(model_path)
|
||||
else:
|
||||
self.model = _get_sentence_transformer_embedding_model(
|
||||
model_path, torch_dtype
|
||||
@@ -272,6 +280,23 @@ class HFRunner:
|
||||
assert not self.output_str_only
|
||||
if "gme-qwen2-vl" in model_path.lower():
|
||||
logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
|
||||
elif "clip" in model_path.lower():
|
||||
if image_data is not None:
|
||||
image = load_image(image_data)
|
||||
inputs = self.processor(
|
||||
images=image[0], return_tensors="pt"
|
||||
)
|
||||
logits = self.model.get_image_features(
|
||||
pixel_values=inputs.data["pixel_values"].cuda(),
|
||||
).tolist()
|
||||
else:
|
||||
inputs = self.tokenizer(
|
||||
prompts, padding=True, return_tensors="pt"
|
||||
)
|
||||
logits = self.model.get_text_features(
|
||||
input_ids=inputs.data["input_ids"].cuda(),
|
||||
attention_mask=inputs.data["attention_mask"].cuda(),
|
||||
).tolist()
|
||||
else:
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
out_queue.put(ModelOutput(embed_logits=logits))
|
||||
|
||||
Reference in New Issue
Block a user