Add decode req pool (#6980)

This commit is contained in:
Byron Hsu
2025-06-09 21:23:36 -07:00
committed by GitHub
parent f6ebba537a
commit c2b16795b5
2 changed files with 83 additions and 7 deletions

View File

@@ -916,12 +916,26 @@ class ModelRunner:
)
if self.req_to_token_pool is None:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
# subscribe memory for pre-allocated requests
# if max_num_reqs <= 32, we pre-allocate 2x requests
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
self.req_to_token_pool = DecodeReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
else:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker