Fix stop condition for <|eom_id|> (#1766)
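Llama 3 tool use emits <|eom_id|> (end of message) instead of <|eot_id|> when the model expects a follow-up tool response, so generation must also stop on that token. This change records the id of <|eom_id|>, when present in the tokenizer's added vocabulary, as tokenizer.additional_stop_token_ids in get_tokenizer(), makes SamplingParams.stop_token_ids always a set, and merges the tokenizer's additional stop ids into that set during normalization.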
@@ -163,6 +163,15 @@ def get_tokenizer(
             "Using a slow tokenizer. This might cause a significant "
             "slowdown. Consider using a fast tokenizer instead."
         )
 
+    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+    if "<|eom_id|>" in tokenizer.get_added_vocab():
+        tokenizer.additional_stop_token_ids = set(
+            [tokenizer.get_added_vocab()["<|eom_id|>"]]
+        )
+    else:
+        tokenizer.additional_stop_token_ids = None
+
     return tokenizer
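For illustration, a minimal sketch of the vocabulary check with a stub in place of a real Llama 3 tokenizer (the StubTokenizer class and the token ids 128008/128009 are ours for illustration; real ids come from the model's tokenizer):

# Stub standing in for a HF tokenizer; get_added_vocab() returns
# {token_string: token_id} for tokens added on top of the base vocab.
class StubTokenizer:
    def get_added_vocab(self):
        return {"<|eot_id|>": 128009, "<|eom_id|>": 128008}  # illustrative ids

tokenizer = StubTokenizer()

# Same logic as the patched get_tokenizer():
if "<|eom_id|>" in tokenizer.get_added_vocab():
    tokenizer.additional_stop_token_ids = set(
        [tokenizer.get_added_vocab()["<|eom_id|>"]]
    )
else:
    tokenizer.additional_stop_token_ids = None

print(tokenizer.additional_stop_token_ids)  # {128008}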
@@ -51,8 +51,9 @@ class SamplingParams:
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
         if stop_token_ids is None:
-            stop_token_ids = []
-        self.stop_token_ids = {*stop_token_ids}
+            self.stop_token_ids = set()
+        else:
+            self.stop_token_ids = set(stop_token_ids)
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
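The constructor now guarantees stop_token_ids is a set whether or not the caller passed any ids, which is what lets normalize() call .update() on it unconditionally later. A standalone sketch of the new branch (the helper name is ours):

def init_stop_token_ids(stop_token_ids):
    # Mirrors the patched constructor: the attribute is always a set.
    if stop_token_ids is None:
        return set()
    return set(stop_token_ids)

assert init_stop_token_ids(None) == set()
assert init_stop_token_ids([128008]) == {128008}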
@@ -119,10 +120,7 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            if self.stop_token_ids is None:
-                self.stop_str_max_len = 0
-            else:
-                self.stop_str_max_len = 1
+            self.stop_str_max_len = 0
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
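Since token-id stops now live in their own set, stop_str_max_len again measures only stop strings; the deleted branch had presumably pinned it at 1 to keep the stop check running when only token ids were supplied, a workaround the merge in the next hunk makes unnecessary.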
@@ -136,6 +134,10 @@ class SamplingParams:
                 stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
 
+        # Process stop token ids
+        if tokenizer.additional_stop_token_ids:
+            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
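A sketch of the resulting behavior, assuming the normalization step receives a tokenizer prepared by the patched get_tokenizer() (the Tok stub and all ids here are illustrative, not taken from the real class):

# Stand-in for a tokenizer returned by the patched get_tokenizer().
class Tok:
    additional_stop_token_ids = {128008}  # illustrative id for <|eom_id|>

tokenizer = Tok()
stop_token_ids = {128001}  # e.g. a user-supplied stop id (illustrative)

# Mirrors the new "Process stop token ids" step:
if tokenizer.additional_stop_token_ids:
    stop_token_ids.update(tokenizer.additional_stop_token_ids)

print(stop_token_ids)  # {128001, 128008}: generation also stops on <|eom_id|>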