fix: force max new tokens to 1 for embedding requests (#1019)

This commit is contained in:
Ying Sheng
2024-08-10 13:46:42 -07:00
committed by GitHub
parent e712837d38
commit b68c4c073b
2 changed files with 8 additions and 5 deletions

View File

@@ -195,7 +195,8 @@ class EmbeddingReqInput:
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.sampling_params is None:
self.sampling_params = {"max_new_tokens": 1}
self.sampling_params = {}
self.sampling_params["max_new_tokens"] = 1
else:
# support select operation
self.batch_size = (
@@ -207,9 +208,9 @@ class EmbeddingReqInput:
if not isinstance(self.rid, list):
raise ValueError("The rid should be a list.")
if self.sampling_params is None:
self.sampling_params = [
{"max_new_tokens": 1} for _ in range(self.batch_size)
]
self.sampling_params = [{}] * self.batch_size
for i in range(self.batch_size):
self.sampling_params[i]["max_new_tokens"] = 1
@dataclass