Add openai embedding API (#997)

Author: Ying Sheng (committed by GitHub)
Date:   2024-08-09 11:19:18 -07:00
parent 05c50a82b8
commit b16e856f11
8 changed files with 135 additions and 19 deletions
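
Once a server is up, the new route can be exercised end to end. A minimal
sketch, assuming a local server on port 30000 and a model name chosen purely
for illustration; the request field "input" and the response fields "data",
"embedding", and "usage" all come from the models added in this commit:

import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "intfloat/e5-mistral-7b-instruct", "input": "Hello world"},
)
result = resp.json()
print(len(result["data"][0]["embedding"]))  # embedding dimension
print(result["usage"]["prompt_tokens"])     # prompt tokens counted server-side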

View File

@@ -194,7 +194,8 @@ class EmbeddingReqInput:
         if is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            self.sampling_params = {"max_new_tokens": 0}
+            if self.sampling_params is None:
+                self.sampling_params = {"max_new_tokens": 1}
         else:
             # support select operation
             self.batch_size = (
@@ -205,9 +206,10 @@ class EmbeddingReqInput:
             else:
                 if not isinstance(self.rid, list):
                     raise ValueError("The rid should be a list.")
-            self.sampling_params = [
-                {"max_new_tokens": 0} for _ in range(self.batch_size)
-            ]
+            if self.sampling_params is None:
+                self.sampling_params = [
+                    {"max_new_tokens": 1} for _ in range(self.batch_size)
+                ]
 
 
 @dataclass
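
Changing the default from max_new_tokens=0 to 1 pairs with the dummy output
token appended in ModelTpServer below: an embedding request now "generates"
exactly one placeholder token and therefore finishes through the same length
check as ordinary generation requests. The new "if self.sampling_params is
None" guard also keeps post-init from overwriting sampling parameters a caller
has already supplied.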

View File

@@ -262,6 +262,7 @@ class TokenizerManager:
             ):
                 yield response
         else:
+            assert self.is_generation
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids

View File

@@ -499,6 +499,8 @@ class ModelTpServer:
                 req.embedding = embeddings[i]
                 if req is not self.current_inflight_req:
                     # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
                     req.check_finished()
 
                 if req.finished():
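
Embedding models emit no real tokens, so once the embedding is computed a
placeholder token id (0) is appended to output_ids; check_finished() then sees
one output token against the max_new_tokens=1 default set in EmbeddingReqInput
above and marks the request finished.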

View File

@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
@@ -1016,10 +1017,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 def v1_embedding_request(all_requests, tokenizer_manager):
     prompts = []
     sampling_params_list = []
-    first_prompt_type = type(all_requests[0].prompt)
+    first_prompt_type = type(all_requests[0].input)
     for request in all_requests:
-        prompt = request.prompt
+        prompt = request.input
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     return adapted_request, all_requests
 
 
-def v1_embedding_response(request, ret, to_file=False):
-    response = []
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
     for idx, ret_item in enumerate(ret):
-        response.append(
-            EmbeddingResponse(
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
                 index=idx,
-                embedding=ret[idx],
                 object="embedding",
             )
         )
-    return response
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
 
 
 async def v1_embeddings(tokenizer_manager, raw_request: Request):
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_embedding_response(request, ret)
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
     return response

View File

@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None
 
 
-class EmbeddingResponse(BaseModel):
-    index: str
-    embedding: List[float] = None
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
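
These two models serialize to the same shape the OpenAI embeddings API
returns. A self-contained sketch of the wire format (assuming Pydantic v2;
UsageInfo here is a minimal stand-in for sglang's real class, reduced to the
two fields this commit populates):

from typing import List, Optional

from pydantic import BaseModel


class UsageInfo(BaseModel):  # stand-in for sglang's UsageInfo
    prompt_tokens: int = 0
    total_tokens: int = 0


class EmbeddingObject(BaseModel):
    embedding: List[float]
    index: int
    object: str = "embedding"


class EmbeddingResponse(BaseModel):
    data: List[EmbeddingObject]
    model: str
    object: str = "list"
    usage: Optional[UsageInfo] = None


resp = EmbeddingResponse(
    data=[EmbeddingObject(embedding=[0.1, 0.2], index=0)],
    model="my-embedding-model",
    usage=UsageInfo(prompt_tokens=3, total_tokens=3),
)
print(resp.model_dump_json())
# {"data":[{"embedding":[0.1,0.2],"index":0,"object":"embedding"}],
#  "model":"my-embedding-model","object":"list",
#  "usage":{"prompt_tokens":3,"total_tokens":3}}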

View File

@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
+    v1_embeddings,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -176,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
+@app.post("/v1/embeddings")
+async def openai_v1_embeddings(raw_request: Request):
+    response = await v1_embeddings(tokenizer_manager, raw_request)
+    return response
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -412,7 +419,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
-    max_new_tokens = 8 if model_info["is_generation"] else 0
+    max_new_tokens = 8 if model_info["is_generation"] else 1
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(