From b68c4c073ba730f3ced08830fd804132269bdfc9 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sat, 10 Aug 2024 13:46:42 -0700 Subject: [PATCH] fix: force max new tokens to be 1 for embedding request (#1019) --- python/sglang/srt/managers/io_struct.py | 9 +++++---- test/srt/models/test_embedding_models.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 75208874f..2d12505ae 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -195,7 +195,8 @@ class EmbeddingReqInput: if self.rid is None: self.rid = uuid.uuid4().hex if self.sampling_params is None: - self.sampling_params = {"max_new_tokens": 1} + self.sampling_params = {} + self.sampling_params["max_new_tokens"] = 1 else: # support select operation self.batch_size = ( @@ -207,9 +208,9 @@ class EmbeddingReqInput: if not isinstance(self.rid, list): raise ValueError("The rid should be a list.") if self.sampling_params is None: - self.sampling_params = [ - {"max_new_tokens": 1} for _ in range(self.batch_size) - ] + self.sampling_params = [{} for _ in range(self.batch_size)] + for i in range(self.batch_size): + self.sampling_params[i]["max_new_tokens"] = 1 @dataclass diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index c29c33188..520e811a8 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -44,7 +44,9 @@ class TestEmbeddingModels(unittest.TestCase): torch_dtype=torch_dtype, is_generation_model=False, ) as srt_runner: - srt_outputs = srt_runner.forward(prompts) + srt_outputs = srt_runner.forward( + prompts, + ) for i in range(len(prompts)): hf_logits = torch.Tensor(hf_outputs.embed_logits[i])