[bug] fix errors related to context length in speculative decoding (SD) (#9388)
This commit is contained in:
@@ -1236,6 +1236,11 @@ class ModelRunner:
|
||||
|
||||
# Initialize req_to_token_pool
|
||||
if self.req_to_token_pool is None:
|
||||
# FIXME(lsyin): this is a temporary fix for the context-length issue when using speculative decoding
|
||||
extra_max_context_len = 4
|
||||
if self.server_args.speculative_num_draft_tokens is not None:
|
||||
extra_max_context_len += self.server_args.speculative_num_draft_tokens
|
||||
|
||||
if self.server_args.disaggregation_mode == "decode":
|
||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||
|
||||
@@ -1244,7 +1249,8 @@ class ModelRunner:
|
||||
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
@@ -1252,7 +1258,8 @@ class ModelRunner:
|
||||
else:
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user