From f290bd4332ce4ff4be97d59e82daa013f99c66ca Mon Sep 17 00:00:00 2001 From: Chang Su Date: Fri, 10 Jan 2025 13:14:51 -0800 Subject: [PATCH] [Bugfix] Fix embedding model hangs with `--enable-metrics` (#2822) --- python/sglang/srt/configs/model_config.py | 2 +- .../sglang/srt/managers/tokenizer_manager.py | 8 +++- .../sglang/srt/model_executor/model_runner.py | 2 +- test/srt/test_openai_server.py | 41 +++++++++++++++++++ 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a2f9b8284..072c88b04 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -128,7 +128,7 @@ class ModelConfig: self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size - # Veirfy quantization + # Verify quantization self._verify_quantization() # Cache attributes diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 08dbd02c5..00ef8458a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -688,7 +688,7 @@ class TokenizerManager: if self.enable_metrics: completion_tokens = ( recv_obj.completion_tokens[i] - if recv_obj.completion_tokens + if getattr(recv_obj, "completion_tokens", None) else 0 ) @@ -716,7 +716,11 @@ class TokenizerManager: time.time() - state.created_time ) # Compute time_per_output_token for the non-streaming case - if not state.obj.stream and completion_tokens >= 1: + if ( + hasattr(state.obj, "stream") + and not state.obj.stream + and completion_tokens >= 1 + ): self.metrics_collector.observe_time_per_output_token( (time.time() - state.created_time) / completion_tokens diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 719db19cd..efba8c25b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -724,7 +724,7 @@ class ModelRunner: elif forward_batch.forward_mode.is_idle(): return self.forward_idle(forward_batch) else: - raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}") + raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") def sample( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 379e57f35..4bedf7439 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -14,6 +14,7 @@ import openai from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( + DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -675,5 +676,45 @@ class TestOpenAIServerEBNF(unittest.TestCase): ), "Function name should be add for the above response" +class TestOpenAIEmbedding(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Configure embedding-specific args + other_args = ["--is-embedding", "--enable-metrics"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=other_args, + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_embedding_single(self): + """Test single embedding request""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create(model=self.model, input="Hello world") + self.assertEqual(len(response.data), 1) + self.assertTrue(len(response.data[0].embedding) > 0) + + def test_embedding_batch(self): + """Test batch embedding request""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create( + model=self.model, input=["Hello world", "Test text"] + ) + self.assertEqual(len(response.data), 2) + self.assertTrue(len(response.data[0].embedding) > 0) + self.assertTrue(len(response.data[1].embedding) > 0) + + if __name__ == "__main__": unittest.main()