diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6aa7e39e1..3b3fef5c8 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -32,6 +32,7 @@ from sglang.srt.hf_transformers_utils import ( from sglang.srt.layers.quantization import QUANTIZATION_METHODS from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var, is_hip +from sglang.utils import is_in_ci logger = logging.getLogger(__name__) @@ -166,19 +167,20 @@ class ModelConfig: derived_context_len = get_context_length(self.hf_text_config) if context_length is not None: if context_length > derived_context_len: - if get_bool_env_var( - "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True" + reason = "Target model's" if is_draft_model else "User-specified" + msg = ( + f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " + f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config." + ) + if ( + get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + or is_in_ci() # FIXME: fix this special case ): - logger.warning( - f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " - f"This may lead to incorrect model outputs or CUDA errors." - ) + logger.warning(msg) self.context_len = context_length else: raise ValueError( - f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " - f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. " - f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" + f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" ) else: self.context_len = context_length diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index adfdd0541..f4bda8688 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -576,7 +576,7 @@ class TokenizerManager: f"model's context length ({self.context_len} tokens). " "Truncating the input." ) - input_ids = input_ids[:_max_req_len] + del input_ids[_max_req_len:] input_token_num = len(input_ids) else: raise ValueError( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6665458b8..a30fb897f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1236,6 +1236,11 @@ class ModelRunner: # Initialize req_to_token_pool if self.req_to_token_pool is None: + # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding + extra_max_context_len = 4 + if self.server_args.speculative_num_draft_tokens is not None: + extra_max_context_len += self.server_args.speculative_num_draft_tokens + if self.server_args.disaggregation_mode == "decode": from sglang.srt.disaggregation.decode import DecodeReqToTokenPool @@ -1244,7 +1249,8 @@ class ModelRunner: pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 self.req_to_token_pool = DecodeReqToTokenPool( size=max_num_reqs, - max_context_len=self.model_config.context_len + 4, + max_context_len=self.model_config.context_len + + extra_max_context_len, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, pre_alloc_size=pre_alloc_size, @@ -1252,7 +1258,8 @@ class ModelRunner: else: self.req_to_token_pool = ReqToTokenPool( size=max_num_reqs, - max_context_len=self.model_config.context_len + 4, + max_context_len=self.model_config.context_len + + extra_max_context_len, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, ) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index e824fb1ae..3ee3b1c54 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -41,6 +41,7 @@ class EAGLEDraftCudaGraphRunner: # Parse args self.eagle_worker = eagle_worker self.model_runner = model_runner = eagle_worker.model_runner + self.model_runner: EAGLEWorker self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 972d7182d..9dc7438c9 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download from sglang.srt.distributed import ( GroupCoordinator, - get_tensor_model_parallel_world_size, get_tp_group, patch_tensor_parallel_group, ) @@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker): ) self.padded_static_len = -1 - # Override context length with target model's context length + # Override the context length of the draft model to be the same as the target model. server_args.context_length = target_worker.model_runner.model_config.context_len # Do not capture cuda graph in `super().__init__()`