diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 0f9f94dca..0bb4872fe 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -163,6 +163,15 @@ def get_tokenizer(
             "Using a slow tokenizer. This might cause a significant "
             "slowdown. Consider using a fast tokenizer instead."
         )
+
+    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
+    # Hoist get_added_vocab() into a local: it may build a dict on every call.
+    added_vocab = tokenizer.get_added_vocab()
+    if "<|eom_id|>" in added_vocab:
+        tokenizer.additional_stop_token_ids = {added_vocab["<|eom_id|>"]}
+    else:
+        tokenizer.additional_stop_token_ids = None
+
     return tokenizer
 
 
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index 55ec00fc1..e5151a1c2 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -51,8 +51,9 @@ class SamplingParams:
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
         if stop_token_ids is None:
-            stop_token_ids = []
-        self.stop_token_ids = {*stop_token_ids}
+            self.stop_token_ids = set()
+        else:
+            self.stop_token_ids = set(stop_token_ids)
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
@@ -119,10 +120,7 @@
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            if self.stop_token_ids is None:
-                self.stop_str_max_len = 0
-            else:
-                self.stop_str_max_len = 1
+            self.stop_str_max_len = 0
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
@@ -136,6 +134,10 @@
                 stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
 
+        # Process stop token ids (getattr: tokenizer may lack the attribute)
+        if getattr(tokenizer, "additional_stop_token_ids", None):
+            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
+
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,