From 7599badeaf5aeab8c9f72659ceb55bcaf9472e56 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Sat, 10 Aug 2024 08:39:05 -0700
Subject: [PATCH] Support embedding input as a list (#1014)

---
 .../sglang/srt/managers/tokenizer_manager.py | 100 ++++++++++--------
 python/sglang/test/runners.py                |   8 +-
 test/srt/test_embedding_openai_server.py     |   5 +-
 3 files changed, 64 insertions(+), 49 deletions(-)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index e2c825973..e1bfbc7e6 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -153,9 +153,7 @@ class TokenizerManager:
             async for response in self._handle_single_request(obj, request):
                 yield response
         else:
-            if isinstance(obj, EmbeddingReqInput):
-                raise NotImplementedError("Please send only one prompt in each request")
-            if obj.stream:
+            if hasattr(obj, "stream") and obj.stream:
                 raise ValueError("Do not support stream for batch mode.")
 
             async for response in self._handle_batch_request(obj, request):
@@ -283,24 +281,29 @@ class TokenizerManager:
             await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
             yield input_ids
 
-    async def _handle_batch_request(self, obj: GenerateReqInput, request):
+    async def _handle_batch_request(
+        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+    ):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.parallel_sample_num
+        if self.is_generation:
+            parallel_sample_num = obj.parallel_sample_num
 
-        if parallel_sample_num != 1:
-            # Send prefill requests to cache the common input
-            parallel_sample_num += 1
-            input_id_result = [] if obj.input_ids is None else None
-            for i in range(batch_size):
-                async for input_id in self._handle_single_request(
-                    obj, request, index=i, is_cache_for_prefill=True
-                ):
-                    if input_id_result is not None:
-                        input_id_result.append(input_id)
-            if input_id_result is not None and len(input_id_result) > 1:
-                obj.input_ids = input_id_result
-            elif input_id_result is not None:
-                obj.input_ids = input_id_result[0]
+            if parallel_sample_num != 1:
+                # Send prefill requests to cache the common input
+                parallel_sample_num += 1
+                input_id_result = [] if obj.input_ids is None else None
+                for i in range(batch_size):
+                    async for input_id in self._handle_single_request(
+                        obj, request, index=i, is_cache_for_prefill=True
+                    ):
+                        if input_id_result is not None:
+                            input_id_result.append(input_id)
+                if input_id_result is not None and len(input_id_result) > 1:
+                    obj.input_ids = input_id_result
+                elif input_id_result is not None:
+                    obj.input_ids = input_id_result[0]
+        else:
+            parallel_sample_num = 1
 
         # First send out all requests
         for i in range(batch_size):
@@ -329,28 +332,38 @@
                         input_text = None
                         input_ids = obj.input_ids[i]
                 sampling_params = self._get_sampling_params(obj.sampling_params[index])
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
-                    obj.image_data[index]
-                )
-                tokenized_obj = TokenizedGenerateReqInput(
-                    rid,
-                    input_text,
-                    input_ids,
-                    pixel_values,
-                    image_hash,
-                    image_size,
-                    sampling_params,
-                    obj.return_logprob[index],
-                    obj.logprob_start_len[index],
-                    obj.top_logprobs_num[index],
-                    obj.stream,
-                )
+                if self.is_generation:
+                    pixel_values, image_hash, image_size = await self._get_pixel_values(
+                        obj.image_data[index]
+                    )
+
+                    tokenized_obj = TokenizedGenerateReqInput(
+                        rid,
+                        input_text,
+                        input_ids,
+                        pixel_values,
+                        image_hash,
+                        image_size,
+                        sampling_params,
+                        obj.return_logprob[index],
+                        obj.logprob_start_len[index],
+                        obj.top_logprobs_num[index],
+                        obj.stream,
+                    )
+                else:
+                    tokenized_obj = TokenizedEmbeddingReqInput(
+                        rid,
+                        input_text,
+                        input_ids,
+                        sampling_params,
+                    )
                 self.send_to_router.send_pyobj(tokenized_obj)
 
                 event = asyncio.Event()
                 state = ReqState([], False, event)
                 self.rid_to_state[rid] = state
+
 
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -373,14 +386,17 @@
                                 self.abort_request(rid)
                             raise ValueError(f"Abort request {rid}")
                         continue
-                output_list.append(
-                    self.convert_logprob_style(
-                        state.out_list[-1],
-                        obj.return_logprob[index],
-                        obj.top_logprobs_num[index],
-                        obj.return_text_in_logprobs,
+                if self.is_generation:
+                    output_list.append(
+                        self.convert_logprob_style(
+                            state.out_list[-1],
+                            obj.return_logprob[index],
+                            obj.top_logprobs_num[index],
+                            obj.return_text_in_logprobs,
+                        )
                     )
-                )
+                else:
+                    output_list.append(state.out_list[-1])
                 assert state.finished
                 del self.rid_to_state[rid]
         yield output_list
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index 87277ca69..e619d58ca 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -219,11 +219,9 @@ class SRTRunner:
                 output_strs=output_strs, top_input_logprobs=top_input_logprobs
             )
         else:
-            logits = []
-            for prompt in prompts:
-                response = self.runtime.encode(prompt)
-                response = json.loads(response)
-                logits.append(response["embedding"])
+            response = self.runtime.encode(prompts)
+            response = json.loads(response)
+            logits = [x["embedding"] for x in response]
             return ModelOutput(embed_logits=logits)
 
     def __enter__(self):
diff --git a/test/srt/test_embedding_openai_server.py b/test/srt/test_embedding_openai_server.py
index d60ae5068..72dc7a009 100644
--- a/test/srt/test_embedding_openai_server.py
+++ b/test/srt/test_embedding_openai_server.py
@@ -38,8 +38,9 @@ class TestOpenAIServer(unittest.TestCase):
         num_prompt_tokens = len(self.tokenizer.encode(prompt))
 
         if use_list_input:
-            prompt_arg = [prompt_input, prompt_input]
+            prompt_arg = [prompt_input] * 2
             num_prompts = len(prompt_arg)
+            num_prompt_tokens *= num_prompts
         else:
             prompt_arg = prompt_input
             num_prompts = 1
@@ -70,7 +71,7 @@
     def test_embedding(self):
         # TODO the fields of encoding_format, dimensions, user are skipped
         # TODO support use_list_input
-        for use_list_input in [False]:
+        for use_list_input in [False, True]:
             for token_input in [False, True]:
                 self.run_embedding(use_list_input, token_input)
 
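
Usage note (not part of the patch): with this change the OpenAI-compatible /v1/embeddings route and Runtime.encode() accept a list of inputs and return one embedding per element, which is what the updated test_embedding_openai_server.py exercises. Below is a minimal client-side sketch; it assumes an SGLang server is already running locally with an embedding model loaded, and the base URL, port, API key, and model name are illustrative placeholders rather than values taken from this patch.

    import openai

    # Placeholder endpoint; point this at the locally launched SGLang server.
    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

    # `input` may now be a list of strings (or a list of token-id lists);
    # the server returns one embedding per list element.
    response = client.embeddings.create(
        model="intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
        input=["Hello world", "SGLang now supports batched embedding input"],
    )

    embeddings = [item.embedding for item in response.data]
    print(len(embeddings), len(embeddings[0]))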