Fix some online scheduling delay (#1345)
This commit is contained in:
@@ -119,6 +119,7 @@ class PrefillAdder:
|
|||||||
self.running_batch = running_batch
|
self.running_batch = running_batch
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
||||||
|
self.rem_total_tokens_ = self.rem_total_tokens
|
||||||
self.total_tokens = rem_total_tokens
|
self.total_tokens = rem_total_tokens
|
||||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||||
self.rem_chunk_tokens = rem_chunk_tokens
|
self.rem_chunk_tokens = rem_chunk_tokens
|
||||||
@@ -153,11 +154,18 @@ class PrefillAdder:
|
|||||||
for r in running_batch.reqs
|
for r in running_batch.reqs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
self.rem_total_tokens_ -= sum(
|
||||||
|
[
|
||||||
|
r.sampling_params.max_new_tokens - len(r.output_ids)
|
||||||
|
for r in running_batch.reqs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def _prefill_one_req(
|
def _prefill_one_req(
|
||||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||||
):
|
):
|
||||||
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
||||||
|
self.rem_total_tokens_ -= extend_input_len + max_new_tokens
|
||||||
self.rem_input_tokens -= extend_input_len
|
self.rem_input_tokens -= extend_input_len
|
||||||
if self.rem_chunk_tokens is not None:
|
if self.rem_chunk_tokens is not None:
|
||||||
self.rem_chunk_tokens -= extend_input_len
|
self.rem_chunk_tokens -= extend_input_len
|
||||||
@@ -231,6 +239,15 @@ class PrefillAdder:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Quick Check
|
||||||
|
can_run = False
|
||||||
|
if (
|
||||||
|
req.extend_input_len + req.sampling_params.max_new_tokens
|
||||||
|
<= self.rem_total_tokens
|
||||||
|
):
|
||||||
|
can_run = True
|
||||||
|
|
||||||
|
if not can_run:
|
||||||
if self.req_states is None:
|
if self.req_states is None:
|
||||||
self.req_states = []
|
self.req_states = []
|
||||||
if self.running_batch is not None:
|
if self.running_batch is not None:
|
||||||
|
|||||||
@@ -231,6 +231,7 @@ class ModelTpServer:
|
|||||||
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
||||||
):
|
):
|
||||||
self.handle_generate_request(recv_req)
|
self.handle_generate_request(recv_req)
|
||||||
|
self.do_not_get_new_batch = False
|
||||||
elif isinstance(recv_req, FlushCacheReq):
|
elif isinstance(recv_req, FlushCacheReq):
|
||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
elif isinstance(recv_req, AbortReq):
|
elif isinstance(recv_req, AbortReq):
|
||||||
@@ -254,12 +255,10 @@ class ModelTpServer:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_step(self):
|
def forward_step(self):
|
||||||
if self.current_inflight_req is not None:
|
if self.do_not_get_new_batch and self.current_inflight_req is None:
|
||||||
self.do_not_get_new_batch = False
|
new_batch = None
|
||||||
|
else:
|
||||||
new_batch = (
|
new_batch = self.get_new_prefill_batch()
|
||||||
self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
|
|
||||||
)
|
|
||||||
self.do_not_get_new_batch = False
|
self.do_not_get_new_batch = False
|
||||||
|
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user