Add openai embedding API (#997)
This commit is contained in:
@@ -194,7 +194,8 @@ class EmbeddingReqInput:
|
|||||||
if is_single:
|
if is_single:
|
||||||
if self.rid is None:
|
if self.rid is None:
|
||||||
self.rid = uuid.uuid4().hex
|
self.rid = uuid.uuid4().hex
|
||||||
self.sampling_params = {"max_new_tokens": 0}
|
if self.sampling_params is None:
|
||||||
|
self.sampling_params = {"max_new_tokens": 1}
|
||||||
else:
|
else:
|
||||||
# support select operation
|
# support select operation
|
||||||
self.batch_size = (
|
self.batch_size = (
|
||||||
@@ -205,9 +206,10 @@ class EmbeddingReqInput:
|
|||||||
else:
|
else:
|
||||||
if not isinstance(self.rid, list):
|
if not isinstance(self.rid, list):
|
||||||
raise ValueError("The rid should be a list.")
|
raise ValueError("The rid should be a list.")
|
||||||
self.sampling_params = [
|
if self.sampling_params is None:
|
||||||
{"max_new_tokens": 0} for _ in range(self.batch_size)
|
self.sampling_params = [
|
||||||
]
|
{"max_new_tokens": 1} for _ in range(self.batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -262,6 +262,7 @@ class TokenizerManager:
|
|||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
|
assert self.is_generation
|
||||||
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
|
||||||
yield input_ids
|
yield input_ids
|
||||||
|
|
||||||
|
|||||||
@@ -499,6 +499,8 @@ class ModelTpServer:
|
|||||||
req.embedding = embeddings[i]
|
req.embedding = embeddings[i]
|
||||||
if req is not self.current_inflight_req:
|
if req is not self.current_inflight_req:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
|
# dummy output token for embedding models
|
||||||
|
req.output_ids.append(0)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
|
|||||||
generate_chat_conv,
|
generate_chat_conv,
|
||||||
register_conv_template,
|
register_conv_template,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||||
from sglang.srt.openai_api.protocol import (
|
from sglang.srt.openai_api.protocol import (
|
||||||
BatchRequest,
|
BatchRequest,
|
||||||
BatchResponse,
|
BatchResponse,
|
||||||
@@ -52,6 +52,7 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
CompletionResponseStreamChoice,
|
CompletionResponseStreamChoice,
|
||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
|
EmbeddingObject,
|
||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
@@ -1016,10 +1017,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
def v1_embedding_request(all_requests, tokenizer_manager):
|
def v1_embedding_request(all_requests, tokenizer_manager):
|
||||||
prompts = []
|
prompts = []
|
||||||
sampling_params_list = []
|
sampling_params_list = []
|
||||||
first_prompt_type = type(all_requests[0].prompt)
|
first_prompt_type = type(all_requests[0].input)
|
||||||
|
|
||||||
for request in all_requests:
|
for request in all_requests:
|
||||||
prompt = request.prompt
|
prompt = request.input
|
||||||
assert (
|
assert (
|
||||||
type(prompt) == first_prompt_type
|
type(prompt) == first_prompt_type
|
||||||
), "All prompts must be of the same type in file input settings"
|
), "All prompts must be of the same type in file input settings"
|
||||||
@@ -1046,17 +1047,26 @@ def v1_embedding_request(all_requests, tokenizer_manager):
|
|||||||
return adapted_request, all_requests
|
return adapted_request, all_requests
|
||||||
|
|
||||||
|
|
||||||
def v1_embedding_response(request, ret, to_file=False):
|
def v1_embedding_response(ret, model_path, to_file=False):
|
||||||
response = []
|
embedding_objects = []
|
||||||
|
prompt_tokens = 0
|
||||||
for idx, ret_item in enumerate(ret):
|
for idx, ret_item in enumerate(ret):
|
||||||
response.append(
|
embedding_objects.append(
|
||||||
EmbeddingResponse(
|
EmbeddingObject(
|
||||||
|
embedding=ret[idx]["embedding"],
|
||||||
index=idx,
|
index=idx,
|
||||||
embedding=ret[idx],
|
|
||||||
object="embedding",
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return response
|
prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
|
||||||
|
|
||||||
|
return EmbeddingResponse(
|
||||||
|
data=embedding_objects,
|
||||||
|
model=model_path,
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
total_tokens=prompt_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def v1_embeddings(tokenizer_manager, raw_request: Request):
|
async def v1_embeddings(tokenizer_manager, raw_request: Request):
|
||||||
@@ -1074,7 +1084,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request):
|
|||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
|
|
||||||
response = v1_embedding_response(request, ret)
|
response = v1_embedding_response(ret, tokenizer_manager.model_path)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -319,8 +319,14 @@ class EmbeddingRequest(BaseModel):
|
|||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingResponse(BaseModel):
|
class EmbeddingObject(BaseModel):
|
||||||
index: str
|
embedding: List[float]
|
||||||
embedding: List[float] = None
|
index: int
|
||||||
object: str = "embedding"
|
object: str = "embedding"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingResponse(BaseModel):
|
||||||
|
data: List[EmbeddingObject]
|
||||||
|
model: str
|
||||||
|
object: str = "list"
|
||||||
usage: Optional[UsageInfo] = None
|
usage: Optional[UsageInfo] = None
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ from sglang.srt.openai_api.adapter import (
|
|||||||
v1_chat_completions,
|
v1_chat_completions,
|
||||||
v1_completions,
|
v1_completions,
|
||||||
v1_delete_file,
|
v1_delete_file,
|
||||||
|
v1_embeddings,
|
||||||
v1_files_create,
|
v1_files_create,
|
||||||
v1_retrieve_batch,
|
v1_retrieve_batch,
|
||||||
v1_retrieve_file,
|
v1_retrieve_file,
|
||||||
@@ -176,6 +177,12 @@ async def openai_v1_chat_completions(raw_request: Request):
|
|||||||
return await v1_chat_completions(tokenizer_manager, raw_request)
|
return await v1_chat_completions(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
    """Serve POST /v1/embeddings by delegating to ``v1_embeddings``.

    The raw request is forwarded together with the module-level
    ``tokenizer_manager``; whatever ``v1_embeddings`` returns is returned
    unchanged.
    """
    return await v1_embeddings(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
def available_models():
|
def available_models():
|
||||||
"""Show available models."""
|
"""Show available models."""
|
||||||
@@ -412,7 +419,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
|||||||
|
|
||||||
# Send a warmup request
|
# Send a warmup request
|
||||||
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||||
max_new_tokens = 8 if model_info["is_generation"] else 0
|
max_new_tokens = 8 if model_info["is_generation"] else 1
|
||||||
try:
|
try:
|
||||||
for _ in range(server_args.dp_size):
|
for _ in range(server_args.dp_size):
|
||||||
res = requests.post(
|
res = requests.post(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from sglang.test.test_utils import run_unittest_files
|
|||||||
suites = {
|
suites = {
|
||||||
"minimal": [
|
"minimal": [
|
||||||
"test_eval_accuracy.py",
|
"test_eval_accuracy.py",
|
||||||
|
"test_embedding_openai_server.py",
|
||||||
"test_openai_server.py",
|
"test_openai_server.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
"test_chunked_prefill.py",
|
"test_chunked_prefill.py",
|
||||||
|
|||||||
87
test/srt/test_embedding_openai_server.py
Normal file
87
test/srt/test_embedding_openai_server.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.openai_api.protocol import EmbeddingObject
|
||||||
|
from sglang.srt.utils import kill_child_process
|
||||||
|
from sglang.test.test_utils import popen_launch_server
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIServer(unittest.TestCase):
    """Checks of the OpenAI-compatible embeddings endpoint of a launched server.

    One server process is started for the whole class and queried through
    the ``openai`` client.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer, then launch the server shared by all tests."""
        cls.model = "intfloat/e5-mistral-7b-instruct"
        cls.api_key = "sk-123456"
        server_url = "http://127.0.0.1:8157"

        # Load the tokenizer before launching the server: unittest does not
        # call tearDownClass when setUpClass raises, so a failure after the
        # launch would leave the server process running.
        cls.tokenizer = get_tokenizer(cls.model)
        cls.process = popen_launch_server(
            cls.model, server_url, timeout=300, api_key=cls.api_key
        )
        # The client is pointed at the "/v1" prefix of the launched address.
        cls.base_url = server_url + "/v1"

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_child_process(cls.process.pid)

    def run_embedding(self, use_list_input, token_input):
        """Request embeddings for a fixed prompt and validate the response.

        Args:
            use_list_input: When true, send the prompt twice as a list
                instead of once as a single input.
            token_input: When true, send token ids instead of the raw string.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        prompt = "The capital of France is"
        token_ids = self.tokenizer.encode(prompt)
        num_prompt_tokens = len(token_ids)
        prompt_input = token_ids if token_input else prompt

        if use_list_input:
            prompt_arg = [prompt_input, prompt_input]
            num_prompts = len(prompt_arg)
        else:
            prompt_arg = prompt_input
            num_prompts = 1

        # The reported usage is accumulated over every returned item, so the
        # expectation has to scale with the number of prompts sent.
        expected_tokens = num_prompt_tokens * num_prompts

        response = client.embeddings.create(
            input=prompt_arg,
            model=self.model,
        )

        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run under ``python -O``.
        self.assertIsInstance(response.data, list)
        self.assertEqual(len(response.data), num_prompts)
        first = response.data[0]
        self.assertTrue(first.embedding)
        self.assertIsNotNone(first.index)
        self.assertEqual(first.object, "embedding")
        self.assertEqual(response.model, self.model)
        self.assertEqual(response.object, "list")
        self.assertEqual(response.usage.prompt_tokens, expected_tokens)
        self.assertEqual(response.usage.total_tokens, expected_tokens)

    def run_batch(self):
        """Placeholder for the batch (file input) embedding scenario."""
        # FIXME not implemented
        pass

    def test_embedding(self):
        """Run the embedding check for each supported input combination."""
        # TODO the fields of encoding_format, dimensions, user are skipped
        # TODO support use_list_input
        for use_list_input in [False]:
            for token_input in [False, True]:
                # subTest reports each combination separately on failure.
                with self.subTest(
                    use_list_input=use_list_input, token_input=token_input
                ):
                    self.run_embedding(use_list_input, token_input)

    def test_batch(self):
        """Run the (currently empty) batch scenario."""
        self.run_batch()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the suite through the unittest runner with warnings ignored.
    unittest.main(warnings="ignore")

    # Manual alternative that drives the test case without the runner:
    #   t = TestOpenAIServer()
    #   t.setUpClass()
    #   t.test_embedding()
    #   t.tearDownClass()
|
||||||
Reference in New Issue
Block a user