diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8420f20dd..8f6700575 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -437,13 +437,13 @@ class TokenizerManager: is_stream = hasattr(obj, "stream") and obj.stream tasks = [asyncio.create_task(gen.__anext__()) for gen in generators] - output_list = [] + output_list = [None] * len(tasks) while tasks: done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for task in done: - gen_index = tasks.index(task) + cur_index = tasks.index(task) try: result = task.result() @@ -451,14 +451,14 @@ class TokenizerManager: if is_stream: yield result else: - output_list.append(result) + output_list[result["index"]] = result - tasks[gen_index] = asyncio.create_task( - generators[gen_index].__anext__() + tasks[cur_index] = asyncio.create_task( + generators[cur_index].__anext__() ) except StopAsyncIteration: - del generators[gen_index] - del tasks[gen_index] + del generators[cur_index] + del tasks[cur_index] if not is_stream: yield output_list diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 3ec5cd633..241fabf6d 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -591,7 +591,7 @@ class Runtime: def generate( self, - prompt: str, + prompt: Union[str, List[str]], sampling_params: Optional[Dict] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, @@ -612,7 +612,7 @@ class Runtime: def encode( self, - prompt: str, + prompt: Union[str, List[str]], ): json_data = { "text": prompt, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 4fc1f0f25..9f18a91f7 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -28,10 +28,10 @@ from sglang.srt.server import Runtime DEFAULT_PROMPTS = [ # the output of gemma-2-2b from SRT is unstable on the commented prompt # "The capital of France is", + "Apple is red. Banana is Yellow. " * 800 + "Apple is", "The capital of the United Kindom is", "Today is a sunny day and I like", "AI is a field of computer science focused on", - "Apple is red. Banana is Yellow. " * 800 + "Apple is", ] dirpath = os.path.dirname(__file__) diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index 44fed2ad0..cc830f625 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -20,7 +20,7 @@ import torch from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner from sglang.test.test_utils import get_similarities -MODELS = [("intfloat/e5-mistral-7b-instruct", 1)] +MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)] TORCH_DTYPES = [torch.float16] @@ -32,6 +32,7 @@ class TestEmbeddingModels(unittest.TestCase): model_path, tp_size, torch_dtype, + long_context_tolerance, ) -> None: with HFRunner( model_path, torch_dtype=torch_dtype, is_generation_model=False @@ -52,20 +53,22 @@ class TestEmbeddingModels(unittest.TestCase): hf_logits = torch.Tensor(hf_outputs.embed_logits[i]) srt_logits = torch.Tensor(srt_outputs.embed_logits[i]) - similarities = torch.tensor(get_similarities(hf_logits, srt_logits)) - print("max similarity diff", torch.max(abs(similarities - 1))) + similarity = torch.tensor(get_similarities(hf_logits, srt_logits)) + print("similarity diff", abs(similarity - 1)) - if hf_logits.shape[0] <= 100: - tolerance = 1e-2 - assert torch.all( - abs(similarities - 1) < tolerance - ), "embeddings are not all close" + if len(prompts[i]) <= 1000: + tolerance = 1e-5 + else: + tolerance = long_context_tolerance + assert torch.all( + abs(similarity - 1) < tolerance + ), "embeddings are not all close" def test_prefill_logits(self): - for model, tp_size in MODELS: + for model, tp_size, long_context_tolerance in MODELS: for torch_dtype in TORCH_DTYPES: self.assert_close_prefill_logits( - DEFAULT_PROMPTS, model, tp_size, torch_dtype + DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance )