diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index 56d7c8a1f..6a2582e42 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -164,7 +164,7 @@ def get_tokenizer(
             "slowdown. Consider using a fast tokenizer instead."
         )
 
-    handle_additional_stop_token_ids(tokenizer)
+    attach_additional_stop_token_ids(tokenizer)
     return tokenizer
 
 
@@ -184,11 +184,11 @@ def get_processor(
         **kwargs,
     )
 
-    handle_additional_stop_token_ids(processor.tokenizer)
+    attach_additional_stop_token_ids(processor.tokenizer)
     return processor
 
 
-def handle_additional_stop_token_ids(tokenizer):
+def attach_additional_stop_token_ids(tokenizer):
     # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
     if "<|eom_id|>" in tokenizer.get_added_vocab():
         tokenizer.additional_stop_token_ids = set(
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 9ae5801cc..a5afcab51 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
         logits = logits.contiguous()
 
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
-            exit(1) if crash_on_warning else None
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
                 torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )
+            exit(1) if crash_on_warning else None
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index fac008d3f..fcd06d8cc 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -334,15 +334,20 @@ class Req:
 
         last_token_id = self.output_ids[-1]
 
-        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        matched_eos = False
 
+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
 
+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
index c9e0f078e..cc97a2eac 100644
--- a/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
+++ b/python/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
             sequences=[
                 torch.tensor(
-                    data=list(
-                        req.sampling_params.stop_token_ids
-                        | {req.tokenizer.eos_token_id}
+                    data=(
+                        list(
+                            (req.sampling_params.stop_token_ids or set())
+                            | (req.tokenizer.additional_stop_token_ids or set())
+                            | {req.tokenizer.eos_token_id}
+                        )
                     ),
                     dtype=torch.int64,
                     device=self.orchestrator.device,
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index b0863b557..fbe90ba0f 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -50,10 +50,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
-        if stop_token_ids is None:
-            self.stop_token_ids = set()
-        else:
+        if stop_token_ids:
             self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
@@ -134,10 +134,6 @@
             stop_str_max_len = max(stop_str_max_len, len(stop_str))
         self.stop_str_max_len = stop_str_max_len
 
-        # Process stop token ids
-        if tokenizer and tokenizer.additional_stop_token_ids:
-            self.stop_token_ids.update(tokenizer.additional_stop_token_ids)
-
    def to_srt_kwargs(self):
        return {
            "max_new_tokens": self.max_new_tokens,