Feat/refactor embedding server (#7322)
@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    embedding_server: Optional[OpenAIServingEmbedding] = None


 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )

     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):

 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret


 @app.post("/v1/score")
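Note: the new route is a thin shim — it validates the body into an EmbeddingRequest and delegates everything else to the serving object on app.state. A minimal client sketch against the endpoint; the port and model name are illustrative assumptions, not part of this PR:

import requests

# Plain HTTP for clarity; any OpenAI-compatible client should work the same way.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",  # assumed local sglang server
    json={"model": "my-embedding-model", "input": "Hello, how are you?"},
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # dimensionality of the returned vector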
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):

         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            [request], [self._generate_request_id_base(request)]
+            request, self._generate_request_id_base(request)
         )

         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
         GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
     ]:
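Note: after this change the base class drives exactly one OpenAI request per call instead of a singleton list. A condensed sketch of the resulting flow, paraphrased from the context lines above (the exact signature of _handle_non_streaming_request is an assumption, not quoted from the diff):

async def handle_request(self, request, raw_request):
    # Convert to internal format: one request in, one adapted request out.
    adapted_request, processed_request = self._convert_to_internal_request(
        request, self._generate_request_id_base(request)
    )
    # raw_request is only used for detecting the client connection.
    return await self._handle_non_streaming_request(
        adapted_request, processed_request, raw_request
    )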
@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):

     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
-                prompt_kwargs = {"text": prompt}
-            elif isinstance(prompt, list):
-                if len(prompt) > 0 and isinstance(prompt[0], str):
-                    # List of strings
-                    prompt_kwargs = {"text": prompt}
-                elif len(prompt) > 0 and isinstance(
-                    prompt[0], MultimodalEmbeddingInput
-                ):
-                    # Handle multimodal embedding inputs
-                    texts = []
-                    images = []
-                    for item in prompt:
-                        # Use padding for text if None - this could be improved
-                        texts.append(item.text if item.text is not None else "padding")
-                        images.append(item.image if item.image is not None else None)
-
-                    generate_prompts = []
-                    # Check if we have a chat template for multimodal embeddings
-                    # This would need to be passed in from the server configuration
-                    chat_template_name = getattr(
-                        self.tokenizer_manager, "chat_template_name", None
-                    )
-                    if chat_template_name is not None:
-                        convs = generate_embedding_convs(
-                            texts, images, chat_template_name
-                        )
-                        for conv in convs:
-                            generate_prompts.append(conv.get_prompt())
-                    else:
-                        generate_prompts = texts
-
-                    if len(generate_prompts) == 1:
-                        prompt_kwargs = {
-                            "text": generate_prompts[0],
-                            "image_data": images[0],
-                        }
-                    else:
-                        prompt_kwargs = {
-                            "text": generate_prompts,
-                            "image_data": images,
-                        }
-                else:
-                    # List of integers (token IDs) or empty list
-                    prompt_kwargs = {"input_ids": prompt}
-            else:
-                # Other types (should not happen but handle gracefully)
-                prompt_kwargs = {"input_ids": prompt}
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
-        else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
+        prompt = request.input
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                # List of strings
+                prompt_kwargs = {"text": prompt}
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
+
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                # This would need to be passed in from the server configuration
+                chat_template_name = getattr(
+                    self.tokenizer_manager, "chat_template_name", None
+                )
+                if chat_template_name is not None:
+                    convs = generate_embedding_convs(texts, images, chat_template_name)
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
+                else:
+                    generate_prompts = texts
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
+                else:
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
+            else:
+                # List of integers (token IDs) or empty list
+                prompt_kwargs = {"input_ids": prompt}
+        else:
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}

         adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
             **prompt_kwargs,
         )

-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        return adapted_request, request

     async def _handle_non_streaming_request(
         self,
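Note: the rewrite drops the whole batch branch (the per-batch type validation, the NotImplementedError for multimodal batches, and the final_request_id plumbing) and keeps a single dispatch on the shape of request.input. The text-only mapping can be restated as a standalone sketch (not the actual method; the multimodal branch is omitted for brevity):

from typing import Any, Dict, List, Union

def prompt_kwargs_for(prompt: Union[str, List[Any]]) -> Dict[str, Any]:
    """Mirror of the dispatch above: strings and lists of strings become
    "text"; anything else (token IDs, empty lists) becomes "input_ids"."""
    if isinstance(prompt, str):
        return {"text": prompt}
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
        return {"text": prompt}
    return {"input_ids": prompt}

assert prompt_kwargs_for("Hello") == {"text": "Hello"}
assert prompt_kwargs_for(["a", "b"]) == {"text": ["a", "b"]}
assert prompt_kwargs_for([1, 2, 3]) == {"input_ids": [1, 2, 3]}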
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         if not isinstance(ret, list):
             ret = [ret]

-        response = self._build_embedding_response(
-            ret, self.tokenizer_manager.model_path
-        )
+        response = self._build_embedding_response(ret)
         return response

-    def _build_embedding_response(
-        self, ret: List[Dict[str, Any]], model_path: str
-    ) -> EmbeddingResponse:
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):

         return EmbeddingResponse(
             data=embedding_objects,
-            model=model_path,
+            model=self.tokenizer_manager.model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
@@ -95,20 +95,20 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)

     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )

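Note: the rid assertions flip from "test-id" to None because the rewritten converter no longer assigns its request_id argument to EmbeddingReqInput.rid (the rid=final_request_id line was dropped), so the field keeps its default.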
@@ -116,27 +116,27 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)

     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )

         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)

     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )

@@ -147,7 +147,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)

     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             }
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             },
         ]

-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)

         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)