Fix the perf regression due to additional_stop_token_ids (#1773)
This commit is contained in:
@@ -164,7 +164,7 @@ def get_tokenizer(
|
|||||||
"slowdown. Consider using a fast tokenizer instead."
|
"slowdown. Consider using a fast tokenizer instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
handle_additional_stop_token_ids(tokenizer)
|
attach_additional_stop_token_ids(tokenizer)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -184,11 +184,11 @@ def get_processor(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
handle_additional_stop_token_ids(processor.tokenizer)
|
attach_additional_stop_token_ids(processor.tokenizer)
|
||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
def handle_additional_stop_token_ids(tokenizer):
|
def attach_additional_stop_token_ids(tokenizer):
|
||||||
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
|
# Special handling for stop token <|eom_id|> generated by llama 3 tool use.
|
||||||
if "<|eom_id|>" in tokenizer.get_added_vocab():
|
if "<|eom_id|>" in tokenizer.get_added_vocab():
|
||||||
tokenizer.additional_stop_token_ids = set(
|
tokenizer.additional_stop_token_ids = set(
|
||||||
|
|||||||
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
|
|||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
|
|
||||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||||
exit(1) if crash_on_warning else None
|
|
||||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||||
logits = torch.where(
|
logits = torch.where(
|
||||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||||
)
|
)
|
||||||
|
exit(1) if crash_on_warning else None
|
||||||
|
|
||||||
if sampling_info.is_all_greedy:
|
if sampling_info.is_all_greedy:
|
||||||
# Use torch.argmax if all requests use greedy sampling
|
# Use torch.argmax if all requests use greedy sampling
|
||||||
|
|||||||
@@ -334,15 +334,20 @@ class Req:
|
|||||||
|
|
||||||
last_token_id = self.output_ids[-1]
|
last_token_id = self.output_ids[-1]
|
||||||
|
|
||||||
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
matched_eos = False
|
||||||
|
|
||||||
|
# Check stop token ids
|
||||||
|
if self.sampling_params.stop_token_ids:
|
||||||
|
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
||||||
if self.tokenizer is not None:
|
if self.tokenizer is not None:
|
||||||
matched_eos |= last_token_id == self.tokenizer.eos_token_id
|
matched_eos |= last_token_id == self.tokenizer.eos_token_id
|
||||||
|
if self.tokenizer.additional_stop_token_ids:
|
||||||
|
matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
|
||||||
if matched_eos and not self.sampling_params.ignore_eos:
|
if matched_eos and not self.sampling_params.ignore_eos:
|
||||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check stop strings
|
||||||
if len(self.sampling_params.stop_strs) > 0:
|
if len(self.sampling_params.stop_strs) > 0:
|
||||||
tail_str = self.tokenizer.decode(
|
tail_str = self.tokenizer.decode(
|
||||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
||||||
|
|||||||
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
|
|||||||
padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
|
padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
sequences=[
|
sequences=[
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
data=list(
|
data=(
|
||||||
req.sampling_params.stop_token_ids
|
list(
|
||||||
|
(req.sampling_params.stop_token_ids or set())
|
||||||
|
| (req.tokenizer.additional_stop_token_ids or set())
|
||||||
| {req.tokenizer.eos_token_id}
|
| {req.tokenizer.eos_token_id}
|
||||||
|
)
|
||||||
),
|
),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.orchestrator.device,
|
device=self.orchestrator.device,
|
||||||
|
|||||||
@@ -50,10 +50,10 @@ class SamplingParams:
|
|||||||
self.presence_penalty = presence_penalty
|
self.presence_penalty = presence_penalty
|
||||||
self.repetition_penalty = repetition_penalty
|
self.repetition_penalty = repetition_penalty
|
||||||
self.stop_strs = stop
|
self.stop_strs = stop
|
||||||
if stop_token_ids is None:
|
if stop_token_ids:
|
||||||
self.stop_token_ids = set()
|
|
||||||
else:
|
|
||||||
self.stop_token_ids = set(stop_token_ids)
|
self.stop_token_ids = set(stop_token_ids)
|
||||||
|
else:
|
||||||
|
self.stop_token_ids = None
|
||||||
self.max_new_tokens = max_new_tokens
|
self.max_new_tokens = max_new_tokens
|
||||||
self.min_new_tokens = min_new_tokens
|
self.min_new_tokens = min_new_tokens
|
||||||
self.ignore_eos = ignore_eos
|
self.ignore_eos = ignore_eos
|
||||||
@@ -134,10 +134,6 @@ class SamplingParams:
|
|||||||
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
||||||
self.stop_str_max_len = stop_str_max_len
|
self.stop_str_max_len = stop_str_max_len
|
||||||
|
|
||||||
# Process stop token ids
|
|
||||||
if tokenizer and tokenizer.additional_stop_token_ids:
|
|
||||||
self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
|
|
||||||
|
|
||||||
def to_srt_kwargs(self):
|
def to_srt_kwargs(self):
|
||||||
return {
|
return {
|
||||||
"max_new_tokens": self.max_new_tokens,
|
"max_new_tokens": self.max_new_tokens,
|
||||||
|
|||||||
Reference in New Issue
Block a user