diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 63a4d3ac4..535935654 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -285,6 +285,21 @@ class Engine(EngineBase): ret = loop.run_until_complete(generator.__anext__()) return ret + async def async_encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + image_data: Optional[Union[List[str], str]] = None, + ) -> Dict: + """ + Asynchronous version of encode method. + + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. + Please refer to `EmbeddingReqInput` for the documentation. + """ + obj = EmbeddingReqInput(text=prompt, image_data=image_data) + generator = self.tokenizer_manager.generate_request(obj, None) + return await generator.__anext__() + def shutdown(self): """Shutdown the engine""" kill_process_tree(os.getpid(), include_parent=False) diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 672344c63..d6f5ac685 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -185,6 +185,35 @@ class TestSRTEngine(CustomTestCase): result = throughput_test(server_args=server_args, bench_args=bench_args) self.assertGreater(result["total_throughput"], 3000) + def test_8_engine_async_encode_consistency(self): + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + + engine = sgl.Engine( + model_path=model_path, + is_embedding=True, + random_seed=42, + disable_radix_cache=True, + ) + + # Get sync and async embeddings + out1 = torch.tensor(engine.encode(prompt)["embedding"]) + loop = asyncio.get_event_loop() + out2 = torch.tensor( + loop.run_until_complete(engine.async_encode(prompt))["embedding"] + ) + + engine.shutdown() + + print("\n==== Shapes ====") + print(f"sync shape: {out1.shape}") + print(f"async shape: {out2.shape}") + + self.assertTrue( + torch.allclose(out1, out2, atol=1e-5, rtol=1e-3), + "Sync and async embeddings are not equal within tolerance", + ) + if __name__ == "__main__": unittest.main()