Add decode req pool (#6980)
This commit is contained in:
@@ -916,12 +916,26 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
if self.req_to_token_pool is None:
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
if self.server_args.disaggregation_mode == "decode":
|
||||
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
|
||||
|
||||
# subscribe memory for pre-allocated requests
|
||||
# if max_num_reqs <= 32, we pre-allocate 2x requests
|
||||
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
|
||||
self.req_to_token_pool = DecodeReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
else:
|
||||
self.req_to_token_pool = ReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len + 4,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
)
|
||||
else:
|
||||
# Draft worker shares req_to_token_pool with the target worker.
|
||||
assert self.is_draft_worker
|
||||
|
||||
Reference in New Issue
Block a user