Add e5-mistral embedding model - step 3/3 (#988)
@@ -52,7 +52,7 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
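
With `EmbeddingReqInput` imported alongside `GenerateReqInput`, the server can type payloads for both modes. A minimal sketch of how the two request objects might be built, assuming keyword dataclass constructors whose fields are inferred from the handlers below (this diff does not show the actual signatures):

    # Sketch only: field names beyond `text` and `sampling_params` are
    # assumptions inferred from the handlers and warmup code below.
    from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput

    gen_req = GenerateReqInput(
        text="The capital city of France is",
        sampling_params={"temperature": 0, "max_new_tokens": 8},
    )
    emb_req = EmbeddingReqInput(text="The capital city of France is")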
@@ -97,6 +97,7 @@ async def health() -> Response:
 async def get_model_info():
     result = {
         "model_path": tokenizer_manager.model_path,
+        "is_generation": tokenizer_manager.is_generation,
     }
     return result
 
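
Clients can read the new `is_generation` flag to pick the right endpoint, which is exactly what the warmup logic later in this commit does. A hedged example, assuming the handler above is registered at `/get_model_info` (the route decorator sits outside this hunk) and that the server runs on sglang's default port:

    import requests

    # /get_model_info and port 30000 are assumptions; the route decorator
    # and server address are not shown in this hunk.
    info = requests.get("http://localhost:30000/get_model_info").json()
    endpoint = "/generate" if info["is_generation"] else "/encode"
    print(info["model_path"], "->", endpoint)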
@@ -148,6 +149,21 @@ app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
 
 
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
+app.post("/encode")(encode_request)
+app.put("/encode")(encode_request)
+
+
 @app.post("/v1/completions")
 async def openai_v1_completions(raw_request: Request):
     return await v1_completions(tokenizer_manager, raw_request)
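
`encode_request` reuses `tokenizer_manager.generate_request` and takes only the first yielded item, since an embedding request yields a single result rather than a token stream; errors come back as HTTP 400 JSON. Hitting the new route over HTTP might look like this (address and response shape are assumptions based on this diff):

    import requests

    # Assumes the default local server address; the response body is
    # whatever the tokenizer manager returns for an EmbeddingReqInput.
    resp = requests.post(
        "http://localhost:30000/encode",
        json={"text": "The capital city of France is"},
    )
    resp.raise_for_status()
    print(resp.json())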
@@ -380,6 +396,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         except (AssertionError, requests.exceptions.RequestException) as e:
             last_traceback = get_exception_traceback()
             pass
+    model_info = res.json()
 
     if not success:
         if pipe_finish_writer is not None:
@@ -388,15 +405,17 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         sys.exit(1)
 
     # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 0
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
-                url + "/generate",
+                url + request_name,
                 json={
                     "text": "The capital city of France is",
                     "sampling_params": {
                         "temperature": 0,
-                        "max_new_tokens": 8,
+                        "max_new_tokens": max_new_tokens,
                     },
                 },
                 headers=headers,
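
An embedding model decodes no tokens, so the warmup now targets `/encode` with `max_new_tokens` forced to 0 instead of `/generate` with 8. Written out, the two payloads the branch above produces are:

    # The two warmup payloads, spelled out from the code above.
    generation_payload = {
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
    }
    embedding_payload = {
        "text": "The capital city of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 0},
    }
    # POSTed to url + "/generate" and url + "/encode" respectively.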
@@ -529,5 +548,18 @@ class Runtime:
         )
         return json.dumps(response.json())
 
+    def encode(
+        self,
+        prompt: str,
+    ):
+        json_data = {
+            "text": prompt,
+        }
+        response = requests.post(
+            self.url + "/encode",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
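
`Runtime.encode` mirrors the existing generation helpers: it POSTs the prompt to the local `/encode` route and returns the JSON response serialized as a string. A usage sketch; the model path is illustrative and the response schema is not shown in this diff:

    import json

    import sglang as sgl

    # Model path is illustrative; any embedding-capable checkpoint works.
    runtime = sgl.Runtime(model_path="intfloat/e5-mistral-7b-instruct")
    result = json.loads(runtime.encode("The capital city of France is"))
    print(result)  # embedding payload plus request metadata
    runtime.shutdown()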