Fix some online scheduling delay (#1345)
This commit is contained in:
@@ -119,6 +119,7 @@ class PrefillAdder:
|
|||||||
self.running_batch = running_batch
|
self.running_batch = running_batch
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
||||||
|
self.rem_total_tokens_ = self.rem_total_tokens
|
||||||
self.total_tokens = rem_total_tokens
|
self.total_tokens = rem_total_tokens
|
||||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||||
self.rem_chunk_tokens = rem_chunk_tokens
|
self.rem_chunk_tokens = rem_chunk_tokens
|
||||||
@@ -153,11 +154,18 @@ class PrefillAdder:
|
|||||||
for r in running_batch.reqs
|
for r in running_batch.reqs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
self.rem_total_tokens_ -= sum(
|
||||||
|
[
|
||||||
|
r.sampling_params.max_new_tokens - len(r.output_ids)
|
||||||
|
for r in running_batch.reqs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _prefill_one_req(
|
def _prefill_one_req(
|
||||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||||
):
|
):
|
||||||
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
||||||
|
self.rem_total_tokens_ -= extend_input_len + max_new_tokens
|
||||||
self.rem_input_tokens -= extend_input_len
|
self.rem_input_tokens -= extend_input_len
|
||||||
if self.rem_chunk_tokens is not None:
|
if self.rem_chunk_tokens is not None:
|
||||||
self.rem_chunk_tokens -= extend_input_len
|
self.rem_chunk_tokens -= extend_input_len
|
||||||
@@ -231,43 +239,52 @@ class PrefillAdder:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.req_states is None:
|
# Quick Check
|
||||||
self.req_states = []
|
can_run = False
|
||||||
if self.running_batch is not None:
|
if (
|
||||||
for r in self.running_batch.reqs:
|
req.extend_input_len + req.sampling_params.max_new_tokens
|
||||||
|
<= self.rem_total_tokens
|
||||||
|
):
|
||||||
|
can_run = True
|
||||||
|
|
||||||
|
if not can_run:
|
||||||
|
if self.req_states is None:
|
||||||
|
self.req_states = []
|
||||||
|
if self.running_batch is not None:
|
||||||
|
for r in self.running_batch.reqs:
|
||||||
|
state = get_req_state(r)
|
||||||
|
if state is not None:
|
||||||
|
self.req_states.append(state)
|
||||||
|
for r in self.can_run_list:
|
||||||
state = get_req_state(r)
|
state = get_req_state(r)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
self.req_states.append(state)
|
self.req_states.append(state)
|
||||||
for r in self.can_run_list:
|
state = get_req_state(req)
|
||||||
state = get_req_state(r)
|
|
||||||
if state is not None:
|
if state is not None:
|
||||||
self.req_states.append(state)
|
self.req_states.append(state)
|
||||||
state = get_req_state(req)
|
|
||||||
if state is not None:
|
|
||||||
self.req_states.append(state)
|
|
||||||
|
|
||||||
self.req_states.sort(key=lambda x: x[0])
|
self.req_states.sort(key=lambda x: x[0])
|
||||||
else:
|
else:
|
||||||
state = get_req_state(req)
|
state = get_req_state(req)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||||
if tokens_left >= state[0]:
|
if tokens_left >= state[0]:
|
||||||
self.req_states.insert(i, state)
|
self.req_states.insert(i, state)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.req_states.append(state)
|
self.req_states.append(state)
|
||||||
|
|
||||||
tokens_freed = 0
|
tokens_freed = 0
|
||||||
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||||
decode_steps = (
|
decode_steps = (
|
||||||
self.req_states[i + 1][0]
|
self.req_states[i + 1][0]
|
||||||
if i + 1 < len(self.req_states)
|
if i + 1 < len(self.req_states)
|
||||||
else tokens_left
|
else tokens_left
|
||||||
)
|
)
|
||||||
bs = len(self.req_states) - i
|
bs = len(self.req_states) - i
|
||||||
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
|
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
|
||||||
return False
|
return False
|
||||||
tokens_freed += tokens_occupied
|
tokens_freed += tokens_occupied
|
||||||
|
|
||||||
if req.extend_input_len <= self.rem_chunk_tokens:
|
if req.extend_input_len <= self.rem_chunk_tokens:
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class ModelTpServer:
|
|||||||
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
||||||
):
|
):
|
||||||
self.handle_generate_request(recv_req)
|
self.handle_generate_request(recv_req)
|
||||||
|
self.do_not_get_new_batch = False
|
||||||
elif isinstance(recv_req, FlushCacheReq):
|
elif isinstance(recv_req, FlushCacheReq):
|
||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
elif isinstance(recv_req, AbortReq):
|
elif isinstance(recv_req, AbortReq):
|
||||||
@@ -254,12 +255,10 @@ class ModelTpServer:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_step(self):
|
def forward_step(self):
|
||||||
if self.current_inflight_req is not None:
|
if self.do_not_get_new_batch and self.current_inflight_req is None:
|
||||||
self.do_not_get_new_batch = False
|
new_batch = None
|
||||||
|
else:
|
||||||
new_batch = (
|
new_batch = self.get_new_prefill_batch()
|
||||||
self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
|
|
||||||
)
|
|
||||||
self.do_not_get_new_batch = False
|
self.do_not_get_new_batch = False
|
||||||
|
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user