Fix some online scheduling delay (#1345)
This commit is contained in:
@@ -119,6 +119,7 @@ class PrefillAdder:
|
||||
self.running_batch = running_batch
|
||||
self.new_token_ratio = new_token_ratio
|
||||
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
||||
self.rem_total_tokens_ = self.rem_total_tokens
|
||||
self.total_tokens = rem_total_tokens
|
||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||
self.rem_chunk_tokens = rem_chunk_tokens
|
||||
@@ -153,11 +154,18 @@ class PrefillAdder:
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
self.rem_total_tokens_ -= sum(
|
||||
[
|
||||
r.sampling_params.max_new_tokens - len(r.output_ids)
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
|
||||
def _prefill_one_req(
|
||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||
):
|
||||
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
||||
self.rem_total_tokens_ -= extend_input_len + max_new_tokens
|
||||
self.rem_input_tokens -= extend_input_len
|
||||
if self.rem_chunk_tokens is not None:
|
||||
self.rem_chunk_tokens -= extend_input_len
|
||||
@@ -231,43 +239,52 @@ class PrefillAdder:
|
||||
|
||||
return None
|
||||
|
||||
if self.req_states is None:
|
||||
self.req_states = []
|
||||
if self.running_batch is not None:
|
||||
for r in self.running_batch.reqs:
|
||||
# Quick Check
|
||||
can_run = False
|
||||
if (
|
||||
req.extend_input_len + req.sampling_params.max_new_tokens
|
||||
<= self.rem_total_tokens
|
||||
):
|
||||
can_run = True
|
||||
|
||||
if not can_run:
|
||||
if self.req_states is None:
|
||||
self.req_states = []
|
||||
if self.running_batch is not None:
|
||||
for r in self.running_batch.reqs:
|
||||
state = get_req_state(r)
|
||||
if state is not None:
|
||||
self.req_states.append(state)
|
||||
for r in self.can_run_list:
|
||||
state = get_req_state(r)
|
||||
if state is not None:
|
||||
self.req_states.append(state)
|
||||
for r in self.can_run_list:
|
||||
state = get_req_state(r)
|
||||
state = get_req_state(req)
|
||||
if state is not None:
|
||||
self.req_states.append(state)
|
||||
state = get_req_state(req)
|
||||
if state is not None:
|
||||
self.req_states.append(state)
|
||||
|
||||
self.req_states.sort(key=lambda x: x[0])
|
||||
else:
|
||||
state = get_req_state(req)
|
||||
if state is not None:
|
||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||
if tokens_left >= state[0]:
|
||||
self.req_states.insert(i, state)
|
||||
break
|
||||
else:
|
||||
self.req_states.append(state)
|
||||
self.req_states.sort(key=lambda x: x[0])
|
||||
else:
|
||||
state = get_req_state(req)
|
||||
if state is not None:
|
||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||
if tokens_left >= state[0]:
|
||||
self.req_states.insert(i, state)
|
||||
break
|
||||
else:
|
||||
self.req_states.append(state)
|
||||
|
||||
tokens_freed = 0
|
||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||
decode_steps = (
|
||||
self.req_states[i + 1][0]
|
||||
if i + 1 < len(self.req_states)
|
||||
else tokens_left
|
||||
)
|
||||
bs = len(self.req_states) - i
|
||||
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
|
||||
return False
|
||||
tokens_freed += tokens_occupied
|
||||
tokens_freed = 0
|
||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||
decode_steps = (
|
||||
self.req_states[i + 1][0]
|
||||
if i + 1 < len(self.req_states)
|
||||
else tokens_left
|
||||
)
|
||||
bs = len(self.req_states) - i
|
||||
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
|
||||
return False
|
||||
tokens_freed += tokens_occupied
|
||||
|
||||
if req.extend_input_len <= self.rem_chunk_tokens:
|
||||
self.can_run_list.append(req)
|
||||
|
||||
@@ -231,6 +231,7 @@ class ModelTpServer:
|
||||
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
||||
):
|
||||
self.handle_generate_request(recv_req)
|
||||
self.do_not_get_new_batch = False
|
||||
elif isinstance(recv_req, FlushCacheReq):
|
||||
self.flush_cache()
|
||||
elif isinstance(recv_req, AbortReq):
|
||||
@@ -254,12 +255,10 @@ class ModelTpServer:
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward_step(self):
|
||||
if self.current_inflight_req is not None:
|
||||
self.do_not_get_new_batch = False
|
||||
|
||||
new_batch = (
|
||||
self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
|
||||
)
|
||||
if self.do_not_get_new_batch and self.current_inflight_req is None:
|
||||
new_batch = None
|
||||
else:
|
||||
new_batch = self.get_new_prefill_batch()
|
||||
self.do_not_get_new_batch = False
|
||||
|
||||
if new_batch is not None:
|
||||
|
||||
Reference in New Issue
Block a user