diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 69a617371..574a031ad 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -15,7 +15,8 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional, Union
+from functools import lru_cache
+from typing import List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -271,6 +272,14 @@ class ModelConfig:
             self.quantization,
         )
 
+    @lru_cache()
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+        eos_ids = getattr(self.hf_config, "eos_token_id", None)
+        if eos_ids:
+            # it can be either int or list of int
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        return eos_ids
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 7081f0d0c..ee2884df8 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -236,6 +237,7 @@ class Req:
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -395,18 +397,23 @@ class Req:
 
         last_token_id = self.output_ids[-1]
 
-        matched_eos = False
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
 
-        # Check stop token ids
-        if self.sampling_params.stop_token_ids:
-            matched_eos = last_token_id in self.sampling_params.stop_token_ids
-        if self.tokenizer is not None:
-            matched_eos |= last_token_id == self.tokenizer.eos_token_id
-            if self.tokenizer.additional_stop_token_ids:
-                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
-        if matched_eos and not self.sampling_params.ignore_eos:
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-            return
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return
 
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 39c0b6af8..8fe10eb99 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -517,6 +517,7 @@ class Scheduler:
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            eos_token_ids=self.model_config.get_hf_eos_token_id(),
         )
         req.tokenizer = self.tokenizer
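
Not part of the patch: a minimal standalone sketch of the eos_token_id normalization that get_hf_eos_token_id performs, assuming hf_config.eos_token_id holds either a single int, a list of ints, or is absent. The helper name normalize_eos is illustrative only.

from typing import List, Optional, Set, Union


def normalize_eos(eos_ids: Union[int, List[int], None]) -> Optional[Set[int]]:
    # HF configs store eos_token_id as an int or a list of ints; folding both
    # into a set gives O(1) membership checks in the stop-token test above.
    # Falsy values (None, [], 0) pass through unchanged, mirroring the patch.
    if eos_ids:
        eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
    return eos_ids


assert normalize_eos(2) == {2}
assert normalize_eos([1, 2, 3]) == {1, 2, 3}
assert normalize_eos(None) is None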