[Minor] Improve style (#1666)
This commit is contained in:
@@ -203,6 +203,7 @@ class Req:
|
||||
self.prefix_indices = []
|
||||
self.extend_input_len = 0
|
||||
self.last_node = None
|
||||
self.is_inflight_req = 0
|
||||
|
||||
# Logprobs (arguments)
|
||||
self.return_logprob = False
|
||||
|
||||
@@ -45,7 +45,7 @@ class SchedulePolicy:
|
||||
def calc_priority(self, waiting_queue: List[Req]):
|
||||
# Compute matched prefix length
|
||||
prefix_computed = False
|
||||
if self.policy in ["lpm", "dfs-weight"]:
|
||||
if self.policy == "lpm" or self.policy == "dfs-weight":
|
||||
for r in waiting_queue:
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
|
||||
@@ -194,7 +194,7 @@ class Scheduler:
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
self.running_batch: ScheduleBatch = None
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
self.decode_forward_ct = 0
|
||||
self.stream_interval = server_args.stream_interval
|
||||
self.num_generated_tokens = 0
|
||||
@@ -273,6 +273,9 @@ class Scheduler:
|
||||
break
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
else:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
@@ -468,8 +471,6 @@ class Scheduler:
|
||||
|
||||
# Check memory
|
||||
if self.running_batch is None:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = global_config.init_new_token_ratio
|
||||
return
|
||||
|
||||
# Run decode
|
||||
@@ -489,9 +490,7 @@ class Scheduler:
|
||||
) and self.current_inflight_req is None:
|
||||
return None
|
||||
|
||||
running_bs = (
|
||||
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||
)
|
||||
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
if running_bs >= self.max_running_requests:
|
||||
self.batch_is_full = True
|
||||
return None
|
||||
@@ -512,7 +511,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
has_inflight = self.current_inflight_req is not None
|
||||
if self.current_inflight_req is not None:
|
||||
if has_inflight:
|
||||
self.current_inflight_req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache
|
||||
)
|
||||
@@ -520,7 +519,7 @@ class Scheduler:
|
||||
self.current_inflight_req
|
||||
)
|
||||
|
||||
if self.lora_paths is not None:
|
||||
if self.lora_paths:
|
||||
lora_set = (
|
||||
set([req.lora_path for req in self.running_batch.reqs])
|
||||
if self.running_batch is not None
|
||||
@@ -529,7 +528,7 @@ class Scheduler:
|
||||
|
||||
for req in self.waiting_queue:
|
||||
if (
|
||||
self.lora_paths is not None
|
||||
self.lora_paths
|
||||
and len(
|
||||
lora_set
|
||||
| set([req.lora_path for req in adder.can_run_list])
|
||||
@@ -551,16 +550,20 @@ class Scheduler:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
|
||||
# Update waiting queue
|
||||
can_run_list = adder.can_run_list
|
||||
if len(can_run_list) == 0:
|
||||
return None
|
||||
self.waiting_queue = [
|
||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||
]
|
||||
|
||||
if adder.new_inflight_req is not None:
|
||||
assert self.current_inflight_req is None
|
||||
self.current_inflight_req = adder.new_inflight_req
|
||||
|
||||
if len(can_run_list) == 0:
|
||||
return None
|
||||
|
||||
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
|
||||
if self.current_inflight_req:
|
||||
self.current_inflight_req.is_inflight_req += 1
|
||||
|
||||
# Print stats
|
||||
if self.tp_rank == 0:
|
||||
@@ -613,13 +616,13 @@ class Scheduler:
|
||||
new_batch.prepare_for_extend(self.model_config.vocab_size)
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
decoding_reqs = []
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
self.running_batch.prepare_for_decode()
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
decoding_reqs = self.running_batch.reqs
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch = None
|
||||
new_batch.decoding_reqs = decoding_reqs
|
||||
else:
|
||||
new_batch.decoding_reqs = None
|
||||
|
||||
return new_batch
|
||||
|
||||
@@ -738,12 +741,12 @@ class Scheduler:
|
||||
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
elif req not in batch.decoding_reqs:
|
||||
# To reduce overhead, only cache prefill reqs
|
||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req is self.current_inflight_req:
|
||||
if req.is_inflight_req > 0:
|
||||
# Inflight request would get a new req idx
|
||||
req.is_inflight_req -= 1
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
if req.return_logprob:
|
||||
@@ -768,8 +771,9 @@ class Scheduler:
|
||||
else:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req is self.current_inflight_req:
|
||||
if req.is_inflight_req > 0:
|
||||
# Inflight request would get a new req idx
|
||||
req.is_inflight_req -= 1
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
self.stream_output(batch)
|
||||
@@ -906,13 +910,11 @@ class Scheduler:
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
|
||||
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
|
||||
|
||||
for req in batch.reqs:
|
||||
if req.finished() or (
|
||||
req.stream
|
||||
and (
|
||||
self.decode_forward_ct % self.stream_interval == 0
|
||||
or len(req.output_ids) == 1
|
||||
)
|
||||
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
output_finished_reason.append(req.finished_reason)
|
||||
|
||||
Reference in New Issue
Block a user