[Minor] Improve style (#1666)
This commit is contained in:
@@ -203,6 +203,7 @@ class Req:
|
|||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.extend_input_len = 0
|
self.extend_input_len = 0
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
self.is_inflight_req = 0
|
||||||
|
|
||||||
# Logprobs (arguments)
|
# Logprobs (arguments)
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class SchedulePolicy:
|
|||||||
def calc_priority(self, waiting_queue: List[Req]):
|
def calc_priority(self, waiting_queue: List[Req]):
|
||||||
# Compute matched prefix length
|
# Compute matched prefix length
|
||||||
prefix_computed = False
|
prefix_computed = False
|
||||||
if self.policy in ["lpm", "dfs-weight"]:
|
if self.policy == "lpm" or self.policy == "dfs-weight":
|
||||||
for r in waiting_queue:
|
for r in waiting_queue:
|
||||||
# NOTE: the prefix_indices must always be aligned with last_node
|
# NOTE: the prefix_indices must always be aligned with last_node
|
||||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ class Scheduler:
|
|||||||
|
|
||||||
# Init running status
|
# Init running status
|
||||||
self.waiting_queue: List[Req] = []
|
self.waiting_queue: List[Req] = []
|
||||||
self.running_batch: ScheduleBatch = None
|
self.running_batch: Optional[ScheduleBatch] = None
|
||||||
self.decode_forward_ct = 0
|
self.decode_forward_ct = 0
|
||||||
self.stream_interval = server_args.stream_interval
|
self.stream_interval = server_args.stream_interval
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
@@ -273,6 +273,9 @@ class Scheduler:
|
|||||||
break
|
break
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.process_batch_result(batch, result)
|
self.process_batch_result(batch, result)
|
||||||
|
else:
|
||||||
|
self.check_memory()
|
||||||
|
self.new_token_ratio = global_config.init_new_token_ratio
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
|
|
||||||
@@ -468,8 +471,6 @@ class Scheduler:
|
|||||||
|
|
||||||
# Check memory
|
# Check memory
|
||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
self.check_memory()
|
|
||||||
self.new_token_ratio = global_config.init_new_token_ratio
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Run decode
|
# Run decode
|
||||||
@@ -489,9 +490,7 @@ class Scheduler:
|
|||||||
) and self.current_inflight_req is None:
|
) and self.current_inflight_req is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
running_bs = (
|
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||||
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
|
||||||
)
|
|
||||||
if running_bs >= self.max_running_requests:
|
if running_bs >= self.max_running_requests:
|
||||||
self.batch_is_full = True
|
self.batch_is_full = True
|
||||||
return None
|
return None
|
||||||
@@ -512,7 +511,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
has_inflight = self.current_inflight_req is not None
|
has_inflight = self.current_inflight_req is not None
|
||||||
if self.current_inflight_req is not None:
|
if has_inflight:
|
||||||
self.current_inflight_req.init_next_round_input(
|
self.current_inflight_req.init_next_round_input(
|
||||||
None if prefix_computed else self.tree_cache
|
None if prefix_computed else self.tree_cache
|
||||||
)
|
)
|
||||||
@@ -520,7 +519,7 @@ class Scheduler:
|
|||||||
self.current_inflight_req
|
self.current_inflight_req
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.lora_paths is not None:
|
if self.lora_paths:
|
||||||
lora_set = (
|
lora_set = (
|
||||||
set([req.lora_path for req in self.running_batch.reqs])
|
set([req.lora_path for req in self.running_batch.reqs])
|
||||||
if self.running_batch is not None
|
if self.running_batch is not None
|
||||||
@@ -529,7 +528,7 @@ class Scheduler:
|
|||||||
|
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
if (
|
if (
|
||||||
self.lora_paths is not None
|
self.lora_paths
|
||||||
and len(
|
and len(
|
||||||
lora_set
|
lora_set
|
||||||
| set([req.lora_path for req in adder.can_run_list])
|
| set([req.lora_path for req in adder.can_run_list])
|
||||||
@@ -551,16 +550,20 @@ class Scheduler:
|
|||||||
self.batch_is_full = True
|
self.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Update waiting queue
|
||||||
can_run_list = adder.can_run_list
|
can_run_list = adder.can_run_list
|
||||||
|
if len(can_run_list) == 0:
|
||||||
|
return None
|
||||||
|
self.waiting_queue = [
|
||||||
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
|
]
|
||||||
|
|
||||||
if adder.new_inflight_req is not None:
|
if adder.new_inflight_req is not None:
|
||||||
assert self.current_inflight_req is None
|
assert self.current_inflight_req is None
|
||||||
self.current_inflight_req = adder.new_inflight_req
|
self.current_inflight_req = adder.new_inflight_req
|
||||||
|
|
||||||
if len(can_run_list) == 0:
|
if self.current_inflight_req:
|
||||||
return None
|
self.current_inflight_req.is_inflight_req += 1
|
||||||
|
|
||||||
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
|
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
@@ -613,13 +616,13 @@ class Scheduler:
|
|||||||
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
||||||
|
|
||||||
# Mixed-style chunked prefill
|
# Mixed-style chunked prefill
|
||||||
decoding_reqs = []
|
|
||||||
if self.is_mixed_chunk and self.running_batch is not None:
|
if self.is_mixed_chunk and self.running_batch is not None:
|
||||||
self.running_batch.prepare_for_decode()
|
self.running_batch.prepare_for_decode()
|
||||||
new_batch.mix_with_running(self.running_batch)
|
new_batch.mix_with_running(self.running_batch)
|
||||||
decoding_reqs = self.running_batch.reqs
|
new_batch.decoding_reqs = self.running_batch.reqs
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
new_batch.decoding_reqs = decoding_reqs
|
else:
|
||||||
|
new_batch.decoding_reqs = None
|
||||||
|
|
||||||
return new_batch
|
return new_batch
|
||||||
|
|
||||||
@@ -738,12 +741,12 @@ class Scheduler:
|
|||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
elif req not in batch.decoding_reqs:
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||||
# To reduce overhead, only cache prefill reqs
|
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req is self.current_inflight_req:
|
if req.is_inflight_req > 0:
|
||||||
# Inflight request would get a new req idx
|
# Inflight request would get a new req idx
|
||||||
|
req.is_inflight_req -= 1
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
@@ -768,8 +771,9 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req is self.current_inflight_req:
|
if req.is_inflight_req > 0:
|
||||||
# Inflight request would get a new req idx
|
# Inflight request would get a new req idx
|
||||||
|
req.is_inflight_req -= 1
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
self.stream_output(batch)
|
self.stream_output(batch)
|
||||||
@@ -906,13 +910,11 @@ class Scheduler:
|
|||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
output_embeddings = []
|
output_embeddings = []
|
||||||
|
|
||||||
|
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
|
||||||
|
|
||||||
for req in batch.reqs:
|
for req in batch.reqs:
|
||||||
if req.finished() or (
|
if req.finished() or (
|
||||||
req.stream
|
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||||
and (
|
|
||||||
self.decode_forward_ct % self.stream_interval == 0
|
|
||||||
or len(req.output_ids) == 1
|
|
||||||
)
|
|
||||||
):
|
):
|
||||||
output_rids.append(req.rid)
|
output_rids.append(req.rid)
|
||||||
output_finished_reason.append(req.finished_reason)
|
output_finished_reason.append(req.finished_reason)
|
||||||
|
|||||||
Reference in New Issue
Block a user