Add e5-mistral embedding model - step 3/3 (#988)
This commit is contained in:
@@ -23,6 +23,7 @@ import torch.nn.functional as F
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from sglang.srt.server import Runtime
|
||||
from sglang.srt.utils import is_generation_model
|
||||
|
||||
DEFAULT_PROMPTS = [
|
||||
"The capital of France is",
|
||||
@@ -33,13 +34,6 @@ DEFAULT_PROMPTS = [
|
||||
NUM_TOP_LOGPROBS = 5
|
||||
|
||||
|
||||
def is_embedding_model(model_path):
    """Return True when ``model_path`` names a known embedding model.

    The check is a case-insensitive substring match against the model
    path/identifier.

    Args:
        model_path: Model path or hub identifier as a string.

    Returns:
        bool: True if the path contains a known embedding-model marker,
        False otherwise.
    """
    # FIXME incomplete list
    return "e5-mistral-7b-instruct" in model_path.lower()
|
||||
|
||||
|
||||
def get_dtype_str(torch_dtype):
|
||||
if torch_dtype is torch.float16:
|
||||
return "float16"
|
||||
@@ -60,7 +54,7 @@ class HFRunner:
|
||||
self,
|
||||
model_path,
|
||||
torch_dtype=torch.float16,
|
||||
is_embedding_model=None,
|
||||
is_generation_model=None,
|
||||
):
|
||||
self.in_queue = multiprocessing.Queue()
|
||||
self.out_queue = multiprocessing.Queue()
|
||||
@@ -72,13 +66,13 @@ class HFRunner:
|
||||
self.out_queue,
|
||||
model_path,
|
||||
torch_dtype,
|
||||
is_embedding_model,
|
||||
is_generation_model,
|
||||
),
|
||||
)
|
||||
self.model_proc.start()
|
||||
|
||||
def start_model_process(
|
||||
self, in_queue, out_queue, model_path, torch_dtype, is_embedding_model
|
||||
self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
|
||||
):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
@@ -86,12 +80,12 @@ class HFRunner:
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
self.is_embedding_model = (
|
||||
is_embedding_model(model_path)
|
||||
if is_embedding_model is None
|
||||
else is_embedding_model
|
||||
self.is_generation_model = (
|
||||
is_generation_model(model_path)
|
||||
if is_generation_model is None
|
||||
else is_generation_model
|
||||
)
|
||||
if not self.is_embedding_model:
|
||||
if self.is_generation_model:
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
@@ -103,13 +97,13 @@ class HFRunner:
|
||||
|
||||
self.model = SentenceTransformer(
|
||||
model_path,
|
||||
device="cpu",
|
||||
).to(dtype=torch_dtype)
|
||||
model_kwargs={"torch_dtype": torch_dtype},
|
||||
)
|
||||
|
||||
while True:
|
||||
prompts, max_new_tokens = in_queue.get()
|
||||
if prompts is not None:
|
||||
if not self.is_embedding_model:
|
||||
if self.is_generation_model:
|
||||
output_strs = []
|
||||
prefill_logprobs = []
|
||||
for p in prompts:
|
||||
@@ -144,7 +138,6 @@ class HFRunner:
|
||||
)
|
||||
|
||||
else:
|
||||
assert isinstance(prompts, List[str])
|
||||
logits = self.model.encode(prompts).tolist()
|
||||
|
||||
out_queue.put(ModelOutput(embed_logits=logits))
|
||||
@@ -175,16 +168,13 @@ class SRTRunner:
|
||||
model_path,
|
||||
tp_size=1,
|
||||
torch_dtype=torch.float16,
|
||||
is_embedding_model=None,
|
||||
is_generation_model=None,
|
||||
):
|
||||
self.is_embedding_model = (
|
||||
is_embedding_model(model_path)
|
||||
if is_embedding_model is None
|
||||
else is_embedding_model
|
||||
self.is_generation_model = (
|
||||
is_generation_model(model_path)
|
||||
if is_generation_model is None
|
||||
else is_generation_model
|
||||
)
|
||||
if self.is_embedding_model:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.runtime = Runtime(
|
||||
model_path=model_path,
|
||||
tp_size=tp_size,
|
||||
@@ -196,38 +186,45 @@ class SRTRunner:
|
||||
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
|
||||
max_new_tokens=64,
|
||||
):
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for prompt in prompts:
|
||||
response = self.runtime.generate(
|
||||
prompt,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
response = json.loads(response)
|
||||
output_strs.append(response["text"])
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
+ [
|
||||
if self.is_generation_model:
|
||||
# the return value contains logprobs from prefill
|
||||
output_strs = []
|
||||
top_input_logprobs = []
|
||||
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
|
||||
for prompt in prompts:
|
||||
response = self.runtime.generate(
|
||||
prompt,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=True,
|
||||
top_logprobs_num=NUM_TOP_LOGPROBS,
|
||||
)
|
||||
response = json.loads(response)
|
||||
output_strs.append(response["text"])
|
||||
top_input_logprobs.append(
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||
:NUM_TOP_LOGPROBS
|
||||
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
|
||||
for x in response["meta_info"]["input_top_logprobs"][1:]
|
||||
]
|
||||
+ [
|
||||
[
|
||||
tup[0]
|
||||
for tup in response["meta_info"]["output_top_logprobs"][0][
|
||||
:NUM_TOP_LOGPROBS
|
||||
]
|
||||
]
|
||||
]
|
||||
]
|
||||
)
|
||||
# print(response["meta_info"]["output_top_logprobs"][0])
|
||||
)
|
||||
|
||||
return ModelOutput(
|
||||
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
||||
)
|
||||
return ModelOutput(
|
||||
output_strs=output_strs, top_input_logprobs=top_input_logprobs
|
||||
)
|
||||
else:
|
||||
logits = []
|
||||
for prompt in prompts:
|
||||
response = self.runtime.encode(prompt)
|
||||
response = json.loads(response)
|
||||
logits.append(response["embedding"])
|
||||
return ModelOutput(embed_logits=logits)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
@@ -12,6 +12,8 @@ from typing import Callable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.openai import OpenAI
|
||||
@@ -492,3 +494,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")
|
||||
|
||||
return 0 if success else -1
|
||||
|
||||
|
||||
def get_similarities(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Both inputs are converted with ``torch.tensor`` and compared along
    dimension 0, so for 1-D inputs the result is a 0-dim tensor.

    Args:
        vec1: First vector (anything ``torch.tensor`` accepts).
        vec2: Second vector, same shape as ``vec1``.

    Returns:
        torch.Tensor: Cosine similarity computed along ``dim=0``.
    """
    first = torch.tensor(vec1)
    second = torch.tensor(vec2)
    # Integer-valued inputs (e.g. [1, 2, 3]) produce integer tensors, for
    # which cosine similarity is not meaningful; promote them to the default
    # floating dtype. Floating inputs keep their original dtype.
    if not first.is_floating_point():
        first = first.to(torch.get_default_dtype())
    if not second.is_floating_point():
        second = second.to(torch.get_default_dtype())
    return F.cosine_similarity(first, second, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user