feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)

This commit is contained in:
harrisonlimh
2025-09-16 17:10:10 -07:00
committed by GitHub
parent f949ad5794
commit 14fdd52740
16 changed files with 822 additions and 71 deletions

View File

@@ -453,6 +453,7 @@ class Req:
bootstrap_room: Optional[int] = None,
data_parallel_rank: Optional[int] = None,
vocab_size: Optional[int] = None,
priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None,
):
# Input and output info
@@ -504,6 +505,7 @@ class Req:
self.stream = stream
self.eos_token_ids = eos_token_ids
self.vocab_size = vocab_size
self.priority = priority
# For incremental decoding
# ----- | --------- read_ids -------|
@@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
self.req_to_token_pool, self.token_to_kv_pool_allocator
)
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = (
len(req.prefix_indices) // server_args.page_size
) * server_args.page_size
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
if self.is_hybrid:
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
else:
self.tree_cache.dec_lock_ref(req.last_node)
req.reset_for_retract()
self.release_req(idx, len(sorted_indices), server_args)
if len(retracted_reqs) == 0:
# Corner case: only one request left
@@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return retracted_reqs, new_estimate_ratio
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
    """Free all KV-cache resources held by the request at ``self.reqs[idx]``.

    Used during retraction: returns the request's KV token slots to the
    allocator, frees its slot in the request-to-token pool, releases the
    tree-cache lock on its last node, evicts enough tree-cache memory to
    cover the upcoming decode steps of the requests that remain, and
    finally resets the request so it can be re-prefilled later.

    Args:
        idx: Index of the request to release within ``self.reqs``.
        remaing_req_count: Number of requests still scheduled after this
            release; used to size the eviction budget.
            NOTE(review): the name has a typo ("remaing") — renaming it
            would break keyword callers, so it is kept as-is.
        server_args: Server configuration (disaggregation mode, page size).
    """
    req = self.reqs[idx]
    # Fetch only the single sequence length we need instead of copying the
    # whole seq_lens tensor to the host (`.cpu().numpy()`) on every call.
    seq_len = int(self.seq_lens[idx].item())

    if server_args.disaggregation_mode == "decode":
        # In disaggregated decode mode, the KV cache must be offloaded
        # before its device slots are freed.
        req.offload_kv_cache(
            self.req_to_token_pool, self.token_to_kv_pool_allocator
        )
    if isinstance(self.tree_cache, ChunkCache):
        # ChunkCache does not have eviction (or prefix locking): free the
        # request's entire token range directly.
        token_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, :seq_len
        ]
        self.token_to_kv_pool_allocator.free(token_indices)
        self.req_to_token_pool.free(req.req_pool_idx)
    else:
        # TODO: apply more fine-grained retraction
        # Keep the page-aligned cached prefix in the tree cache; free only
        # the uncached tail of the sequence.
        last_uncached_pos = (
            len(req.prefix_indices) // server_args.page_size
        ) * server_args.page_size
        token_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, last_uncached_pos:seq_len
        ]
        self.token_to_kv_pool_allocator.free(token_indices)
        self.req_to_token_pool.free(req.req_pool_idx)

        # Release the lock held on the last node of the radix tree so the
        # prefix becomes evictable.
        if self.is_hybrid:
            self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
        else:
            self.tree_cache.dec_lock_ref(req.last_node)

        # NOTE(lsyin): we should use the newly evictable memory instantly.
        num_tokens = remaing_req_count * global_config.retract_decode_steps
        self._evict_tree_cache_if_needed(num_tokens)

    req.reset_for_retract()
def prepare_encoder_info_decode(self):
    """Reset per-request encoder bookkeeping before running a decode batch."""
    # Reset the encoder cached status
    self.encoder_cached = [True] * len(self.reqs)