Add support for Qwen2-VL-based embedding models (#2055)

This commit is contained in:
James Xu
2024-11-21 17:24:25 -05:00
committed by GitHub
parent f35cb46cc3
commit f6f713797b
4 changed files with 39 additions and 12 deletions

View File

@@ -58,6 +58,28 @@ def get_top_logprobs(logits, k):
return logprobs
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
    """Build a ``SentenceTransformer`` for ``model_path`` and move it to CUDA.

    Two loading paths are supported:

    * ``model_path`` is recognized by ``is_sentence_transformer_model``: it is
      loaded directly, with ``torch_dtype`` forwarded through ``model_kwargs``.
    * Otherwise a two-module pipeline is assembled by hand: a ``Transformer``
      module cast to ``torch_dtype``, followed by a ``Pooling`` module using
      ``"lasttoken"`` pooling.

    Args:
        model_path: Model identifier or path handed to sentence-transformers.
        torch_dtype: dtype the model weights should use.

    Returns:
        The resulting ``SentenceTransformer`` after ``.cuda()`` has been
        called on it.
    """
    # Imported lazily so the dependency is only needed for embedding models.
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    # Fast path: the checkpoint already is a sentence-transformers model.
    if is_sentence_transformer_model(model_path):
        return SentenceTransformer(
            model_path,
            model_kwargs={"torch_dtype": torch_dtype},
        ).cuda()

    # Fallback: no pre-trained sentence-transformers model is available, so
    # wrap the transformer and attach a last-token pooling head ourselves.
    from sentence_transformers import models

    transformer_module = models.Transformer(model_path).to(dtype=torch_dtype)
    embedding_dim = transformer_module.get_word_embedding_dimension()
    pooling_module = models.Pooling(embedding_dim, pooling_mode="lasttoken")
    pipeline = SentenceTransformer(modules=[transformer_module, pooling_module])
    return pipeline.cuda()
@dataclass
class ModelOutput:
output_strs: List[str] = None
@@ -114,12 +136,9 @@ class HFRunner:
low_cpu_mem_usage=True,
).cuda()
elif self.model_type == "embedding":
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_path,
model_kwargs={"torch_dtype": torch_dtype},
).cuda()
self.model = _get_sentence_transformer_embedding_model(
model_path, torch_dtype
)
elif self.model_type == "reward":
from transformers import AutoModelForSequenceClassification