[Hack] Add pd-disaggregation decode polling interval (#10411)
This commit is contained in:
@@ -886,9 +886,18 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
# if there are still retracted requests, we do not allocate new requests
|
# if there are still retracted requests, we do not allocate new requests
|
||||||
return
|
return
|
||||||
|
|
||||||
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
if not hasattr(self, "polling_count"):
|
||||||
self.disagg_decode_transfer_queue.extend(req_conns)
|
self.polling_count = 0
|
||||||
alloc_reqs = (
|
self.polling_interval = (
|
||||||
self.disagg_decode_transfer_queue.pop_transferred()
|
self.server_args.disaggregation_decode_polling_interval
|
||||||
) # the requests which kv has arrived
|
)
|
||||||
self.waiting_queue.extend(alloc_reqs)
|
|
||||||
|
self.polling_count = (self.polling_count + 1) % self.polling_interval
|
||||||
|
|
||||||
|
if self.polling_count % self.polling_interval == 0:
|
||||||
|
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
||||||
|
self.disagg_decode_transfer_queue.extend(req_conns)
|
||||||
|
alloc_reqs = (
|
||||||
|
self.disagg_decode_transfer_queue.pop_transferred()
|
||||||
|
) # the requests which kv has arrived
|
||||||
|
self.waiting_queue.extend(alloc_reqs)
|
||||||
|
|||||||
@@ -394,6 +394,9 @@ class ServerArgs:
|
|||||||
disaggregation_ib_device: Optional[str] = None
|
disaggregation_ib_device: Optional[str] = None
|
||||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||||
|
|
||||||
|
# FIXME: hack to reduce inter-token latency (ITL) when the decode batch size is small
|
||||||
|
disaggregation_decode_polling_interval: int = 1
|
||||||
|
|
||||||
# For model weight update
|
# For model weight update
|
||||||
custom_weight_loader: Optional[List[str]] = None
|
custom_weight_loader: Optional[List[str]] = None
|
||||||
weight_loader_disable_mmap: bool = False
|
weight_loader_disable_mmap: bool = False
|
||||||
@@ -2245,6 +2248,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.num_reserved_decode_tokens,
|
default=ServerArgs.num_reserved_decode_tokens,
|
||||||
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-decode-polling-interval",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.disaggregation_decode_polling_interval,
|
||||||
|
help="The interval to poll requests in decode server. Can be set to >1 to reduce the overhead of this.",
|
||||||
|
)
|
||||||
|
|
||||||
# Custom weight loader
|
# Custom weight loader
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user