Fix the perf regression due to additional_stop_token_ids (#1773)
This commit is contained in:
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
             sequences=[
                 torch.tensor(
-                    data=list(
-                        req.sampling_params.stop_token_ids
-                        | {req.tokenizer.eos_token_id}
+                    data=(
+                        list(
+                            (req.sampling_params.stop_token_ids or set())
+                            | (req.tokenizer.additional_stop_token_ids or set())
+                            | {req.tokenizer.eos_token_id}
+                        )
                     ),
                     dtype=torch.int64,
                     device=self.orchestrator.device,
@@ -50,10 +50,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
-        if stop_token_ids is None:
-            self.stop_token_ids = set()
-        else:
+        if stop_token_ids:
             self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
@@ -134,10 +134,6 @@ class SamplingParams:
             stop_str_max_len = max(stop_str_max_len, len(stop_str))
         self.stop_str_max_len = stop_str_max_len

-        # Process stop token ids
-        if tokenizer and tokenizer.additional_stop_token_ids:
-            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
-
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
Reference in New Issue
Block a user