Add e5-mistral embedding model - step 3/3 (#988)

Ying Sheng
2024-08-08 16:31:19 -07:00
committed by GitHub
parent 9f662501a3
commit e040a2450b
14 changed files with 474 additions and 241 deletions


@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -97,6 +97,7 @@ async def health() -> Response:
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result
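
With this field, a client can tell whether the server hosts a generative or an embedding model before choosing an endpoint. A minimal sketch of such a check (the default port 30000 is an assumption, not part of this diff):

import requests

# Ask the server which kind of model it is serving; "is_generation" is the
# field added above.
info = requests.get("http://localhost:30000/get_model_info").json()
endpoint = "/generate" if info["is_generation"] else "/encode"
print(info["model_path"], "->", endpoint)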
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
 
 
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
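
The new /encode endpoint reuses TokenizerManager.generate_request and returns the single item yielded for an embedding request. A minimal client sketch (URL and payload shape are assumptions based on the handler above):

import requests

# Request an embedding for one prompt via the new /encode endpoint.
res = requests.post(
    "http://localhost:30000/encode",
    json={"text": "The capital city of France is"},
)
print(res.json())  # the embedding is expected in the JSON response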
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         except (AssertionError, requests.exceptions.RequestException) as e:
             last_traceback = get_exception_traceback()
             pass
+    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)
 
     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 0
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
+                url + request_name,
                 json={
                     "text": "The capital city of France is",
                     "sampling_params": {
                         "temperature": 0,
-                        "max_new_tokens": 8,
+                        "max_new_tokens": max_new_tokens,
                     },
                 },
                 headers=headers,
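
The warmup now branches on is_generation: an embedding server is warmed through /encode with max_new_tokens set to 0, since it decodes no tokens. Roughly, the payload an embedding server receives would be (a sketch, not part of the diff):

# Warmup payload when model_info["is_generation"] is False.
warmup_payload = {
    "text": "The capital city of France is",
    "sampling_params": {"temperature": 0, "max_new_tokens": 0},  # no decoding
}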
@@ -529,5 +548,18 @@ class Runtime:
         )
         return json.dumps(response.json())
 
+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
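
Runtime.encode mirrors the HTTP endpoint for offline use. A hypothetical usage sketch (the model path and top-level import are assumptions, not part of this diff):

from sglang import Runtime  # Runtime is defined in sglang.srt.server

# Launch an embedding model and fetch one embedding; the model path is illustrative.
runtime = Runtime(model_path="intfloat/e5-mistral-7b-instruct")
print(runtime.encode("The capital city of France is"))
runtime.shutdown()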