fix: force max new tokens to be 1 for embedding request (#1019)
This commit is contained in:
@@ -195,7 +195,8 @@ class EmbeddingReqInput:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
-                self.sampling_params = {"max_new_tokens": 1}
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
         else:
             # support select operation
             self.batch_size = (
@@ -207,9 +208,9 @@ class EmbeddingReqInput:
             if not isinstance(self.rid, list):
                 raise ValueError("The rid should be a list.")
             if self.sampling_params is None:
-                self.sampling_params = [
-                    {"max_new_tokens": 1} for _ in range(self.batch_size)
-                ]
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1


 @dataclass
@@ -44,7 +44,9 @@ class TestEmbeddingModels(unittest.TestCase):
             torch_dtype=torch_dtype,
             is_generation_model=False,
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(prompts)
+            srt_outputs = srt_runner.forward(
+                prompts,
+            )

         for i in range(len(prompts)):
             hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
|||||||
Reference in New Issue
Block a user