fix: force max new tokens to be 1 for embedding request (#1019)
This commit is contained in:
@@ -195,7 +195,8 @@ class EmbeddingReqInput:
 if self.rid is None:
 self.rid = uuid.uuid4().hex
 if self.sampling_params is None:
-self.sampling_params = {"max_new_tokens": 1}
+self.sampling_params = {}
+self.sampling_params["max_new_tokens"] = 1
 else:
 # support select operation
 self.batch_size = (
@@ -207,9 +208,9 @@ class EmbeddingReqInput:
 if not isinstance(self.rid, list):
 raise ValueError("The rid should be a list.")
 if self.sampling_params is None:
-self.sampling_params = [
-{"max_new_tokens": 1} for _ in range(self.batch_size)
-]
+self.sampling_params = [{}] * self.batch_size
+for i in range(self.batch_size):
+self.sampling_params[i]["max_new_tokens"] = 1
 
 
 @dataclass
Reference in New Issue
Block a user