Add e5-mistral embedding model - step 3/3 (#988)

This commit is contained in:
Ying Sheng
2024-08-08 16:31:19 -07:00
committed by GitHub
parent 9f662501a3
commit e040a2450b
14 changed files with 474 additions and 241 deletions

View File

@@ -23,6 +23,7 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
"The capital of France is",
@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
NUM_TOP_LOGPROBS = 5
def is_embedding_model(model_path):
    """Heuristically decide whether *model_path* names an embedding model.

    FIXME: the list of known embedding models is incomplete; currently only
    e5-mistral-7b-instruct is recognized (matched case-insensitively).
    """
    return "e5-mistral-7b-instruct" in model_path.lower()
def get_dtype_str(torch_dtype):
if torch_dtype is torch.float16:
return "float16"
@@ -60,7 +54,7 @@ class HFRunner:
self,
model_path,
torch_dtype=torch.float16,
is_embedding_model=None,
is_generation_model=None,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
@@ -72,13 +66,13 @@ class HFRunner:
self.out_queue,
model_path,
torch_dtype,
is_embedding_model,
is_generation_model,
),
)
self.model_proc.start()
def start_model_process(
self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
@@ -86,12 +80,12 @@ class HFRunner:
trust_remote_code=True,
)
self.is_embedding_model = (
is_embedding_model(model_path)
if is_embedding_model is None
else is_embedding_model
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
if not self.is_embedding_model:
if self.is_generation_model:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
@@ -103,13 +97,13 @@ class HFRunner:
self.model = SentenceTransformer(
model_path,
device="cpu",
).to(dtype=torch_dtype)
model_kwargs={"torch_dtype": torch_dtype},
)
while True:
prompts, max_new_tokens = in_queue.get()
if prompts is not None:
if not self.is_embedding_model:
if self.is_generation_model:
output_strs = []
prefill_logprobs = []
for p in prompts:
@@ -144,7 +138,6 @@ class HFRunner:
)
else:
assert isinstance(prompts, List[str])
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
@@ -175,16 +168,13 @@ class SRTRunner:
model_path,
tp_size=1,
torch_dtype=torch.float16,
is_embedding_model=None,
is_generation_model=None,
):
self.is_embedding_model = (
is_embedding_model(model_path)
if is_embedding_model is None
else is_embedding_model
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
if self.is_embedding_model:
raise NotImplementedError()
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
@@ -196,38 +186,45 @@ class SRTRunner:
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=64,
):
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
response = self.runtime.generate(
prompt,
sampling_params=sampling_params,
return_logprob=True,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
if self.is_generation_model:
# the return value contains logprobs from prefill
output_strs = []
top_input_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for prompt in prompts:
response = self.runtime.generate(
prompt,
sampling_params=sampling_params,
return_logprob=True,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
+ [
[
tup[0]
for tup in response["meta_info"]["output_top_logprobs"][0][
:NUM_TOP_LOGPROBS
]
]
]
]
)
# print(response["meta_info"]["output_top_logprobs"][0])
)
return ModelOutput(
output_strs=output_strs, top_input_logprobs=top_input_logprobs
)
return ModelOutput(
output_strs=output_strs, top_input_logprobs=top_input_logprobs
)
else:
logits = []
for prompt in prompts:
response = self.runtime.encode(prompt)
response = json.loads(response)
logits.append(response["embedding"])
return ModelOutput(embed_logits=logits)
def __enter__(self):
return self

View File

@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
import numpy as np
import requests
import torch
import torch.nn.functional as F
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
return 0 if success else -1
def get_similarities(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Both inputs are converted to torch tensors first, so plain Python lists
    (e.g. embeddings decoded from a JSON response) are accepted. The result
    is a 0-dim torch tensor.
    """
    t1 = torch.tensor(vec1)
    t2 = torch.tensor(vec2)
    return F.cosine_similarity(t1, t2, dim=0)