Add openai embedding API (#997)
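
This commit wires up an OpenAI-compatible embeddings path: it imports EmbeddingReqInput, renames the request field from prompt to input to match the OpenAI API, rewrites v1_embedding_response to return a single list-style response with token usage, and replaces the flat EmbeddingResponse model with an EmbeddingObject item model plus a list-style EmbeddingResponse wrapper.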
sglang/srt/openai_api/adapter.py:

@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
-from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
     BatchResponse,
@@ -52,6 +52,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     DeltaMessage,
+    EmbeddingObject,
     EmbeddingRequest,
     EmbeddingResponse,
     ErrorResponse,
@@ -1016,10 +1017,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
 def v1_embedding_request(all_requests, tokenizer_manager):
     prompts = []
     sampling_params_list = []
-    first_prompt_type = type(all_requests[0].prompt)
+    first_prompt_type = type(all_requests[0].input)
 
     for request in all_requests:
-        prompt = request.prompt
+        prompt = request.input
         assert (
             type(prompt) == first_prompt_type
         ), "All prompts must be of the same type in file input settings"
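
The switch from prompt to input follows the OpenAI embeddings API, whose request body names the field "input" and allows either a single string or a list of strings. As an illustration, both of the following bodies should now parse as an EmbeddingRequest (the model name is a placeholder, not taken from this diff):

single = {"model": "my-embedding-model", "input": "The quick brown fox"}
batch = {"model": "my-embedding-model", "input": ["first text", "second text"]}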
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     return adapted_request, all_requests
 
 
-def v1_embedding_response(request, ret, to_file=False):
-    response = []
+def v1_embedding_response(ret, model_path, to_file=False):
+    embedding_objects = []
+    prompt_tokens = 0
     for idx, ret_item in enumerate(ret):
-        response.append(
-            EmbeddingResponse(
+        embedding_objects.append(
+            EmbeddingObject(
+                embedding=ret[idx]["embedding"],
                 index=idx,
-                embedding=ret[idx],
-                object="embedding",
             )
         )
-    return response
+        prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+    return EmbeddingResponse(
+        data=embedding_objects,
+        model=model_path,
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            total_tokens=prompt_tokens,
+        ),
+    )
 
 
 async def v1_embeddings(tokenizer_manager, raw_request: Request):
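
The rewritten v1_embedding_response no longer emits one EmbeddingResponse per item; it wraps all embeddings in a single response and sums prompt tokens into a usage block. A minimal plain-dict sketch of that aggregation, assuming each item in ret carries an "embedding" vector and a "meta_info" dict with a "prompt_tokens" count (the keys the diff reads); build_embedding_response is a hypothetical name for this sketch:

def build_embedding_response(ret, model_path):
    # Plain-dict sketch of the aggregation above; field names mirror the
    # OpenAI-style wire format produced by the new pydantic models.
    prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
    return {
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": item["embedding"], "index": idx}
            for idx, item in enumerate(ret)
        ],
        "model": model_path,
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }

# Example with fabricated vectors:
ret = [
    {"embedding": [0.1, -0.2, 0.3], "meta_info": {"prompt_tokens": 5}},
    {"embedding": [0.0, 0.4, -0.1], "meta_info": {"prompt_tokens": 7}},
]
print(build_embedding_response(ret, "my-embedding-model")["usage"])
# -> {'prompt_tokens': 12, 'total_tokens': 12}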
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_embedding_response(request, ret)
+    response = v1_embedding_response(ret, tokenizer_manager.model_path)
 
     return response
 
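
With the handler in place, an OpenAI-style client can call the endpoint over HTTP. A hedged sketch, assuming the handler is mounted at /v1/embeddings on sglang's default port 30000 and that an embedding model is being served (the route, port, and model name below are assumptions, not shown in this diff):

import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",  # assumed route and port
    json={"model": "my-embedding-model", "input": "Hello world"},
)
body = resp.json()
print(body["usage"])                     # token accounting from UsageInfo
print(body["data"][0]["embedding"][:8])  # first few dimensions of the vector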
sglang/srt/openai_api/protocol.py:

@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None
 
 
-class EmbeddingResponse(BaseModel):
-    index: str
-    embedding: List[float] = None
+class EmbeddingObject(BaseModel):
+    embedding: List[float]
+    index: int
+    object: str = "embedding"
+
+
+class EmbeddingResponse(BaseModel):
+    data: List[EmbeddingObject]
+    model: str
+    object: str = "list"
+    usage: Optional[UsageInfo] = None
 
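
For reference, the new models serialize to the OpenAI embeddings wire format. A self-contained sketch: the UsageInfo stand-in below is an assumed minimal shape (the real class lives elsewhere in sglang.srt.openai_api.protocol), and the last line uses the pydantic v2 API:

from typing import List, Optional
from pydantic import BaseModel

class UsageInfo(BaseModel):  # assumed minimal stand-in for sglang's UsageInfo
    prompt_tokens: int = 0
    total_tokens: int = 0

class EmbeddingObject(BaseModel):
    embedding: List[float]
    index: int
    object: str = "embedding"

class EmbeddingResponse(BaseModel):
    data: List[EmbeddingObject]
    model: str
    object: str = "list"
    usage: Optional[UsageInfo] = None

resp = EmbeddingResponse(
    data=[EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0)],
    model="my-embedding-model",
    usage=UsageInfo(prompt_tokens=3, total_tokens=3),
)
print(resp.model_dump_json(indent=2))  # pydantic v2; use resp.json() on v1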