diff --git a/python/sglang/srt/entrypoints/openai/api_server.py b/python/sglang/srt/entrypoints/openai/api_server.py
index 490e4ac13..b575275ae 100644
--- a/python/sglang/srt/entrypoints/openai/api_server.py
+++ b/python/sglang/srt/entrypoints/openai/api_server.py
@@ -40,9 +40,10 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
+from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.openai_api.protocol import EmbeddingRequest, ModelCard, ModelList
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     add_prometheus_middleware,
@@ -64,6 +65,7 @@ class AppState:
     server_args: Optional[ServerArgs] = None
     tokenizer_manager: Optional[TokenizerManager] = None
     scheduler_info: Optional[Dict] = None
+    serving_embedding: Optional[OpenAIServingEmbedding] = None
 
 
 @asynccontextmanager
@@ -78,6 +80,9 @@ async def lifespan(app: FastAPI):
     tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
     app.state.tokenizer_manager = tokenizer_manager
     app.state.scheduler_info = scheduler_info
+    app.state.serving_embedding = OpenAIServingEmbedding(
+        tokenizer_manager=tokenizer_manager
+    )
 
     if server_args.enable_metrics:
         add_prometheus_middleware(app)
@@ -169,7 +174,16 @@ async def openai_v1_chat_completions(raw_request: Request):
 
 @app.post("/v1/embeddings")
 async def openai_v1_embeddings(raw_request: Request):
-    pass
+    try:
+        request_json = await raw_request.json()
+        request = EmbeddingRequest(**request_json)
+    except Exception as e:
+        return app.state.serving_embedding.create_error_response(
+            f"Invalid request body, error: {str(e)}"
+        )
+
+    ret = await app.state.serving_embedding.handle_request(request, raw_request)
+    return ret
 
 
 @app.post("/v1/score")
diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py
index 718378f5e..d441f7a20 100644
--- a/python/sglang/srt/entrypoints/openai/serving_base.py
+++ b/python/sglang/srt/entrypoints/openai/serving_base.py
@@ -37,7 +37,7 @@ class OpenAIServingBase(ABC):
 
         # Convert to internal format
         adapted_request, processed_request = self._convert_to_internal_request(
-            [request], [self._generate_request_id_base(request)]
+            request, self._generate_request_id_base(request)
         )
 
         # Note(Xinyuan): raw_request below is only used for detecting the connection of the client
@@ -73,8 +73,8 @@ class OpenAIServingBase(ABC):
     @abstractmethod
     def _convert_to_internal_request(
         self,
-        all_requests: List[OpenAIServingRequest],
-        request_ids: List[str],
+        request: OpenAIServingRequest,
+        request_id: str,
     ) -> tuple[
         GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
     ]:
diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py
index 33d1e5918..79333df6b 100644
--- a/python/sglang/srt/entrypoints/openai/serving_embedding.py
+++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py
@@ -71,111 +71,61 @@ class OpenAIServingEmbedding(OpenAIServingBase):
 
     def _convert_to_internal_request(
         self,
-        all_requests: List[EmbeddingRequest],
-        request_ids: List[str],
+        request: EmbeddingRequest,
+        request_id: str,
     ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]:
         """Convert OpenAI embedding request to internal format"""
-        prompts = [request.input for request in all_requests]
-
-        # Handle single vs multiple requests
-        if len(all_requests) == 1:
-            prompt = prompts[0]
-            if isinstance(prompt, str):
-                # Single string input
+        prompt = request.input
+        if isinstance(prompt, str):
+            # Single string input
+            prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list):
+            if len(prompt) > 0 and isinstance(prompt[0], str):
+                # List of strings
                 prompt_kwargs = {"text": prompt}
-            elif isinstance(prompt, list):
-                if len(prompt) > 0 and isinstance(prompt[0], str):
-                    # List of strings
-                    prompt_kwargs = {"text": prompt}
-                elif len(prompt) > 0 and isinstance(
-                    prompt[0], MultimodalEmbeddingInput
-                ):
-                    # Handle multimodal embedding inputs
-                    texts = []
-                    images = []
-                    for item in prompt:
-                        # Use padding for text if None - this could be improved
-                        texts.append(item.text if item.text is not None else "padding")
-                        images.append(item.image if item.image is not None else None)
+            elif len(prompt) > 0 and isinstance(prompt[0], MultimodalEmbeddingInput):
+                # Handle multimodal embedding inputs
+                texts = []
+                images = []
+                for item in prompt:
+                    # Use padding for text if None - this could be improved
+                    texts.append(item.text if item.text is not None else "padding")
+                    images.append(item.image if item.image is not None else None)
 
-                    generate_prompts = []
-                    # Check if we have a chat template for multimodal embeddings
-                    # This would need to be passed in from the server configuration
-                    chat_template_name = getattr(
-                        self.tokenizer_manager, "chat_template_name", None
-                    )
-                    if chat_template_name is not None:
-                        convs = generate_embedding_convs(
-                            texts, images, chat_template_name
-                        )
-                        for conv in convs:
-                            generate_prompts.append(conv.get_prompt())
-                    else:
-                        generate_prompts = texts
-
-                    if len(generate_prompts) == 1:
-                        prompt_kwargs = {
-                            "text": generate_prompts[0],
-                            "image_data": images[0],
-                        }
-                    else:
-                        prompt_kwargs = {
-                            "text": generate_prompts,
-                            "image_data": images,
-                        }
+                generate_prompts = []
+                # Check if we have a chat template for multimodal embeddings
+                # This would need to be passed in from the server configuration
+                chat_template_name = getattr(
+                    self.tokenizer_manager, "chat_template_name", None
+                )
+                if chat_template_name is not None:
+                    convs = generate_embedding_convs(texts, images, chat_template_name)
+                    for conv in convs:
+                        generate_prompts.append(conv.get_prompt())
                 else:
-                    # List of integers (token IDs) or empty list
-                    prompt_kwargs = {"input_ids": prompt}
+                    generate_prompts = texts
+
+                if len(generate_prompts) == 1:
+                    prompt_kwargs = {
+                        "text": generate_prompts[0],
+                        "image_data": images[0],
+                    }
+                else:
+                    prompt_kwargs = {
+                        "text": generate_prompts,
+                        "image_data": images,
+                    }
             else:
-                # Other types (should not happen but handle gracefully)
+                # List of integers (token IDs) or empty list
                 prompt_kwargs = {"input_ids": prompt}
-            # Use the passed request_ids for single request
-            final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids
         else:
-            # Handle batch requests
-            if len(prompts) > 0:
-                # Validate that all prompts have the same type
-                first_prompt = prompts[0]
-                first_type = type(first_prompt)
-                for i, prompt in enumerate(prompts[1:], 1):
-                    if type(prompt) != first_type:
-                        raise AssertionError(
-                            f"All prompts in batch must have the same type, but prompt at index {i} has different type"
-                        )
-
-                if isinstance(first_prompt, str):
-                    # Batch of strings
-                    prompt_kwargs = {"text": prompts}
-                elif isinstance(first_prompt, list):
-                    if len(first_prompt) > 0 and isinstance(first_prompt[0], str):
-                        # Batch of lists of strings
-                        prompt_kwargs = {"text": prompts}
-                    elif len(first_prompt) > 0 and isinstance(
-                        first_prompt[0], MultimodalEmbeddingInput
-                    ):
-                        # Handle multimodal batch requests
-                        raise NotImplementedError(
-                            "Multiple requests with multimodal inputs are not supported yet"
-                        )
-                    else:
-                        # Batch of token ID lists
-                        prompt_kwargs = {"input_ids": prompts}
-                else:
-                    # Other types
-                    prompt_kwargs = {"input_ids": prompts}
-            else:
-                prompt_kwargs = {"input_ids": prompts}
-            # Use the passed request_ids for batch requests
-            final_request_id = request_ids
-
+            # Other types (should not happen but handle gracefully)
+            prompt_kwargs = {"input_ids": prompt}
         adapted_request = EmbeddingReqInput(
-            rid=final_request_id,
             **prompt_kwargs,
         )
-        return adapted_request, (
-            all_requests[0] if len(all_requests) == 1 else all_requests
-        )
+        return adapted_request, request
 
     async def _handle_non_streaming_request(
         self,
@@ -194,14 +144,10 @@ class OpenAIServingEmbedding(OpenAIServingBase):
         if not isinstance(ret, list):
             ret = [ret]
 
-        response = self._build_embedding_response(
-            ret, self.tokenizer_manager.model_path
-        )
+        response = self._build_embedding_response(ret)
         return response
 
-    def _build_embedding_response(
-        self, ret: List[Dict[str, Any]], model_path: str
-    ) -> EmbeddingResponse:
+    def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse:
         """Build the embedding response"""
         embedding_objects = []
         prompt_tokens = 0
@@ -219,7 +165,7 @@ class OpenAIServingEmbedding(OpenAIServingBase):
 
         return EmbeddingResponse(
             data=embedding_objects,
-            model=model_path,
+            model=self.tokenizer_manager.model_path,
             usage=UsageInfo(
                 prompt_tokens=prompt_tokens,
                 total_tokens=prompt_tokens,
diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py
index 137db8fe5..b927be4fe 100644
--- a/test/srt/openai/test_serving_embedding.py
+++ b/test/srt/openai/test_serving_embedding.py
@@ -95,20 +95,20 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         """Test converting single string request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.basic_req], ["test-id"]
+                self.basic_req, "test-id"
             )
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.text, "Hello, how are you?")
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.basic_req)
 
     def test_convert_list_string_request(self):
         """Test converting list of strings request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.list_req], ["test-id"]
+                self.list_req, "test-id"
             )
         )
 
@@ -116,27 +116,27 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(
             adapted_request.text, ["Hello, how are you?", "I am fine, thank you!"]
         )
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.list_req)
 
     def test_convert_token_ids_request(self):
         """Test converting token IDs request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.token_ids_req], ["test-id"]
+                self.token_ids_req, "test-id"
             )
         )
 
         self.assertIsInstance(adapted_request, EmbeddingReqInput)
         self.assertEqual(adapted_request.input_ids, [1, 2, 3, 4, 5])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
         self.assertEqual(processed_request, self.token_ids_req)
 
     def test_convert_multimodal_request(self):
         """Test converting multimodal request to internal format."""
         adapted_request, processed_request = (
             self.serving_embedding._convert_to_internal_request(
-                [self.multimodal_req], ["test-id"]
+                self.multimodal_req, "test-id"
             )
         )
@@ -147,7 +147,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         self.assertIn("World", adapted_request.text)
         self.assertEqual(adapted_request.image_data[0], "base64_image_data")
         self.assertIsNone(adapted_request.image_data[1])
-        self.assertEqual(adapted_request.rid, "test-id")
+        self.assertEqual(adapted_request.rid, None)
 
     def test_build_single_embedding_response(self):
         """Test building response for single embedding."""
@@ -158,9 +158,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
         ret_data = [
             {
                 "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
                 "prompt_tokens": 5,
             }
         ]
 
-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)
 
         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(response.model, "test-model")
@@ -185,9 +183,7 @@ class ServingEmbeddingTestCase(unittest.TestCase):
             },
         ]
 
-        response = self.serving_embedding._build_embedding_response(
-            ret_data, "test-model"
-        )
+        response = self.serving_embedding._build_embedding_response(ret_data)
 
         self.assertIsInstance(response, EmbeddingResponse)
         self.assertEqual(len(response.data), 2)
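
Usage sketch (not part of the patch): with a server built from this branch serving an embedding model, the /v1/embeddings route wired up in api_server.py accepts the standard OpenAI request shape. The host and port below assume a local server on SGLang's default port 30000, and the model name is a placeholder.

    import requests

    resp = requests.post(
        "http://localhost:30000/v1/embeddings",  # assumed local server, default port
        json={
            "model": "my-embedding-model",  # placeholder; use the served model's name
            "input": "Hello, how are you?",  # str, list[str], or token IDs per EmbeddingRequest
        },
    )
    resp.raise_for_status()
    body = resp.json()
    print(len(body["data"][0]["embedding"]))  # embedding vector dimension
    print(body["usage"]["prompt_tokens"])  # filled in by _build_embedding_response

A body that fails EmbeddingRequest validation (for example, a missing "input" field) takes the except branch in openai_v1_embeddings and returns the create_error_response payload instead of an embedding.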
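The multimodal branch of _convert_to_internal_request expects "input" to be a list of text/image items. A sketch of that payload under the same assumptions as above; the field names follow MultimodalEmbeddingInput as used in the patch, and the image value is a stand-in:

    import requests

    payload = {
        "model": "my-embedding-model",  # placeholder model name
        "input": [
            {"text": "Hello", "image": "<base64-encoded image>"},  # stand-in image data
            {"text": "World", "image": None},  # an item may omit its image
        ],
    }
    resp = requests.post("http://localhost:30000/v1/embeddings", json=payload)

If the tokenizer manager carries a chat_template_name, each text/image pair is first rendered through generate_embedding_convs; otherwise the texts are packed into EmbeddingReqInput unchanged, as the converter above shows.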