Fix the perf regression due to additional_stop_token_ids (#1773)

Author: Lianmin Zheng (committed by GitHub)
Date: 2024-10-23 16:45:21 -07:00
Parent: 05b3bf5e8e
Commit: 8f8f96a621
5 changed files with 20 additions and 16 deletions
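The regression traced back to per-request set work in `SamplingParams`: every request eagerly built a `set()` for `stop_token_ids` (even when the user passed none) and then merged `tokenizer.additional_stop_token_ids` into it. This commit stores `None` when there are no user-supplied stop tokens and defers the merge to the penalizer that consumes the ids. A minimal sketch of the before/after pattern, with illustrative variable names (not the repo's):

    # Illustrative values; names are hypothetical, not from the repo.
    user_stop_ids = None                  # common case: request sets no stops
    additional_stop_token_ids = {100265}  # tokenizer-level extra stop ids

    # Before: allocate a set for every request, then mutate it per request.
    before = set(user_stop_ids) if user_stop_ids is not None else set()
    before.update(additional_stop_token_ids)

    # After: keep None for the empty case; the merge happens inside the
    # penalizer that actually needs the union (first hunk below).
    after = set(user_stop_ids) if user_stop_ids else None

    print(before, after)  # {100265} None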


@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
             sequences=[
                 torch.tensor(
-                    data=list(
-                        req.sampling_params.stop_token_ids
-                        | {req.tokenizer.eos_token_id}
+                    data=(
+                        list(
+                            (req.sampling_params.stop_token_ids or set())
+                            | (req.tokenizer.additional_stop_token_ids or set())
+                            | {req.tokenizer.eos_token_id}
+                        )
                     ),
                     dtype=torch.int64,
                     device=self.orchestrator.device,
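Because `stop_token_ids` (and `additional_stop_token_ids`) may now be `None`, each source is normalized with `or set()` before the union. The idiom in isolation, as a standalone sketch:

    # `x or set()` maps None (and any falsy value) to an empty set,
    # so the union below is well-defined for every combination.
    stop_token_ids = None
    additional_stop_token_ids = None
    eos_token_id = 2

    merged = (
        (stop_token_ids or set())
        | (additional_stop_token_ids or set())
        | {eos_token_id}
    )
    print(merged)  # {2}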


@@ -50,10 +50,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
-        if stop_token_ids is None:
-            self.stop_token_ids = set()
-        else:
-            self.stop_token_ids = set(stop_token_ids)
+        if stop_token_ids:
+            self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
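Note the inversion: previously an absent `stop_token_ids` became an empty set; now an empty or absent one becomes `None`, so consumers can gate stop-token handling on a single truthiness check. A hypothetical guard illustrating the intent (not a function from the repo):

    # Hypothetical consumer: with None as the "no stop tokens" sentinel,
    # the common path is one falsy check and no set is ever allocated.
    def hit_stop_token(last_token_id, stop_token_ids):
        if not stop_token_ids:  # None (and empty set) fall through cheaply
            return False
        return last_token_id in stop_token_ids

    print(hit_stop_token(2, None))          # False
    print(hit_stop_token(2, {2, 100265}))   # True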
@@ -134,10 +134,6 @@ class SamplingParams:
             stop_str_max_len = max(stop_str_max_len, len(stop_str))
         self.stop_str_max_len = stop_str_max_len
 
-        # Process stop token ids
-        if tokenizer and tokenizer.additional_stop_token_ids:
-            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
-
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,