[bug] fix errors related to context length in SD (#9388)
This commit is contained in:
@@ -32,6 +32,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
|
from sglang.utils import is_in_ci
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -166,19 +167,20 @@ class ModelConfig:
|
|||||||
derived_context_len = get_context_length(self.hf_text_config)
|
derived_context_len = get_context_length(self.hf_text_config)
|
||||||
if context_length is not None:
|
if context_length is not None:
|
||||||
if context_length > derived_context_len:
|
if context_length > derived_context_len:
|
||||||
if get_bool_env_var(
|
reason = "Target model's" if is_draft_model else "User-specified"
|
||||||
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
|
msg = (
|
||||||
|
f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
||||||
|
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
|
||||||
|
or is_in_ci() # FIXME: fix this special case
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(msg)
|
||||||
f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
|
||||||
f"This may lead to incorrect model outputs or CUDA errors."
|
|
||||||
)
|
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
|
f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
||||||
f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
|
|
||||||
f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.context_len = context_length
|
self.context_len = context_length
|
||||||
|
|||||||
@@ -576,7 +576,7 @@ class TokenizerManager:
|
|||||||
f"model's context length ({self.context_len} tokens). "
|
f"model's context length ({self.context_len} tokens). "
|
||||||
"Truncating the input."
|
"Truncating the input."
|
||||||
)
|
)
|
||||||
input_ids = input_ids[:_max_req_len]
|
del input_ids[_max_req_len:]
|
||||||
input_token_num = len(input_ids)
|
input_token_num = len(input_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -1236,6 +1236,11 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Initialize req_to_token_pool
|
# Initialize req_to_token_pool
|
||||||
if self.req_to_token_pool is None:
|
if self.req_to_token_pool is None:
|
||||||
|
# FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
|
||||||
|
extra_max_context_len = 4
|
||||||
|
if self.server_args.speculative_num_draft_tokens is not None:
|
||||||
|
extra_max_context_len += self.server_args.speculative_num_draft_tokens
|
||||||
|
|
||||||
if self.server_args.disaggregation_mode == "decode":
|
if self.server_args.disaggregation_mode == "decode":
|
||||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||||
|
|
||||||
@@ -1244,7 +1249,8 @@ class ModelRunner:
|
|||||||
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||||
size=max_num_reqs,
|
size=max_num_reqs,
|
||||||
max_context_len=self.model_config.context_len + 4,
|
max_context_len=self.model_config.context_len
|
||||||
|
+ extra_max_context_len,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
pre_alloc_size=pre_alloc_size,
|
pre_alloc_size=pre_alloc_size,
|
||||||
@@ -1252,7 +1258,8 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
size=max_num_reqs,
|
size=max_num_reqs,
|
||||||
max_context_len=self.model_config.context_len + 4,
|
max_context_len=self.model_config.context_len
|
||||||
|
+ extra_max_context_len,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ class EAGLEDraftCudaGraphRunner:
|
|||||||
# Parse args
|
# Parse args
|
||||||
self.eagle_worker = eagle_worker
|
self.eagle_worker = eagle_worker
|
||||||
self.model_runner = model_runner = eagle_worker.model_runner
|
self.model_runner = model_runner = eagle_worker.model_runner
|
||||||
|
self.model_runner: EAGLEWorker
|
||||||
self.graphs = {}
|
self.graphs = {}
|
||||||
self.output_buffers = {}
|
self.output_buffers = {}
|
||||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download
|
|||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
GroupCoordinator,
|
GroupCoordinator,
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
get_tp_group,
|
get_tp_group,
|
||||||
patch_tensor_parallel_group,
|
patch_tensor_parallel_group,
|
||||||
)
|
)
|
||||||
@@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
)
|
)
|
||||||
self.padded_static_len = -1
|
self.padded_static_len = -1
|
||||||
|
|
||||||
# Override context length with target model's context length
|
# Override the context length of the draft model to be the same as the target model.
|
||||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||||
|
|
||||||
# Do not capture cuda graph in `super().__init__()`
|
# Do not capture cuda graph in `super().__init__()`
|
||||||
|
|||||||
Reference in New Issue
Block a user