Add support for Qwen2-VL-based embedding models (#2055)
This commit is contained in:
@@ -58,6 +58,28 @@ def get_top_logprobs(logits, k):
|
||||
return logprobs
|
||||
|
||||
|
||||
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
    """Build a SentenceTransformer embedding model for `model_path` on CUDA.

    When `is_sentence_transformer_model(model_path)` is true, the checkpoint
    is loaded directly with `torch_dtype` forwarded through `model_kwargs`.
    Otherwise a pipeline is assembled by hand from a `models.Transformer`
    module cast to `torch_dtype`, followed by a `models.Pooling` module that
    uses "lasttoken" pooling.

    Args:
        model_path: Model identifier or path handed to sentence-transformers.
        torch_dtype: Dtype applied to the underlying transformer weights.

    Returns:
        The resulting `SentenceTransformer`, after calling `.cuda()` on it.
    """
    # Imported lazily so sentence-transformers is only needed when this
    # helper is actually invoked.
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.util import is_sentence_transformer_model

    if not is_sentence_transformer_model(model_path):
        # No pre-trained sentence-transformers configuration is available:
        # compose the transformer backbone with a pooling head manually.
        from sentence_transformers import models

        backbone = models.Transformer(model_path).to(dtype=torch_dtype)
        embedding_dim = backbone.get_word_embedding_dimension()
        pooler = models.Pooling(embedding_dim, pooling_mode="lasttoken")
        return SentenceTransformer(modules=[backbone, pooler]).cuda()

    pretrained = SentenceTransformer(
        model_path,
        model_kwargs={"torch_dtype": torch_dtype},
    )
    return pretrained.cuda()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelOutput:
|
||||
output_strs: List[str] = None
|
||||
@@ -114,12 +136,9 @@ class HFRunner:
|
||||
low_cpu_mem_usage=True,
|
||||
).cuda()
|
||||
elif self.model_type == "embedding":
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.model = SentenceTransformer(
|
||||
model_path,
|
||||
model_kwargs={"torch_dtype": torch_dtype},
|
||||
).cuda()
|
||||
self.model = _get_sentence_transformer_embedding_model(
|
||||
model_path, torch_dtype
|
||||
)
|
||||
elif self.model_type == "reward":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
Reference in New Issue
Block a user