feat: add priority-based scheduling with priority-based request acceptance and preemption (#8746)
This commit is contained in:
@@ -453,6 +453,7 @@ class Req:
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
vocab_size: Optional[int] = None,
|
||||
priority: Optional[int] = None,
|
||||
metrics_collector: Optional[SchedulerMetricsCollector] = None,
|
||||
):
|
||||
# Input and output info
|
||||
@@ -504,6 +505,7 @@ class Req:
|
||||
self.stream = stream
|
||||
self.eos_token_ids = eos_token_ids
|
||||
self.vocab_size = vocab_size
|
||||
self.priority = priority
|
||||
|
||||
# For incremental decoding
|
||||
# ----- | --------- read_ids -------|
|
||||
@@ -1517,37 +1519,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
idx = sorted_indices.pop()
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
if server_args.disaggregation_mode == "decode":
|
||||
req.offload_kv_cache(
|
||||
self.req_to_token_pool, self.token_to_kv_pool_allocator
|
||||
)
|
||||
|
||||
if isinstance(self.tree_cache, ChunkCache):
|
||||
# ChunkCache does not have eviction
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = (
|
||||
len(req.prefix_indices) // server_args.page_size
|
||||
) * server_args.page_size
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
# release the last node
|
||||
if self.is_hybrid:
|
||||
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
|
||||
else:
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
req.reset_for_retract()
|
||||
self.release_req(idx, len(sorted_indices), server_args)
|
||||
|
||||
if len(retracted_reqs) == 0:
|
||||
# Corner case: only one request left
|
||||
@@ -1568,6 +1540,44 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
return retracted_reqs, new_estimate_ratio
|
||||
|
||||
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
    """Free the KV cache and request-pool resources held by ``self.reqs[idx]``.

    Called once per retracted request. After freeing, the tree cache is
    evicted (if needed) so the reclaimed memory is usable immediately,
    and the request is reset so it can be re-prefilled later.

    Args:
        idx: Index of the request in ``self.reqs``.
        remaing_req_count: Number of requests still running after this
            release; used to size the immediate tree-cache eviction.
            (Spelling kept as-is for interface compatibility; "remaining"
            is intended.)
        server_args: Server configuration (disaggregation mode, page size).
    """
    req = self.reqs[idx]
    # NOTE(review): this triggers a device sync per call; callers that
    # release many requests in a loop pay that cost per request.
    seq_lens_cpu = self.seq_lens.cpu().numpy()

    if server_args.disaggregation_mode == "decode":
        req.offload_kv_cache(
            self.req_to_token_pool, self.token_to_kv_pool_allocator
        )

    if isinstance(self.tree_cache, ChunkCache):
        # ChunkCache does not have eviction: free the whole sequence.
        first_uncached_pos = 0
    else:
        # TODO: apply more fine-grained retraction
        # Keep the page-aligned cached prefix; free only the tail.
        first_uncached_pos = (
            len(req.prefix_indices) // server_args.page_size
        ) * server_args.page_size

    # Common free path for both branches (previously duplicated).
    token_indices = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, first_uncached_pos : seq_lens_cpu[idx]
    ]
    self.token_to_kv_pool_allocator.free(token_indices)
    self.req_to_token_pool.free(req.req_pool_idx)

    # Release the lock this request holds on its last radix-tree node.
    if self.is_hybrid:
        self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
    else:
        self.tree_cache.dec_lock_ref(req.last_node)

    # NOTE(lsyin): we should use the newly evictable memory instantly.
    num_tokens = remaing_req_count * global_config.retract_decode_steps
    self._evict_tree_cache_if_needed(num_tokens)

    req.reset_for_retract()
|
||||
|
||||
def prepare_encoder_info_decode(self):
    # Decode steps reuse the encoder output, so mark every request's
    # encoder result as already cached.
    self.encoder_cached = [True for _ in self.reqs]
|
||||
|
||||
Reference in New Issue
Block a user