fix schedule bug (#1450)
@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
+        self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
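The constructor hunk above folds the running-batch token reservation into `__init__`; the standalone `remove_running_tokens()` it replaces is deleted in the next hunk. Below is a minimal sketch of the reservation math with a simplified request record; `RunningReq`, the `0.7` ratio, and the `CLIP_MAX_NEW_TOKENS` value are illustrative stand-ins, not taken from this diff.

    from dataclasses import dataclass

    CLIP_MAX_NEW_TOKENS = 512  # illustrative clip value

    @dataclass
    class RunningReq:
        max_new_tokens: int  # decode budget for this request
        num_output_ids: int  # tokens generated so far

    def reserved_tokens(reqs, new_token_ratio):
        # Each running request may still emit (budget - produced) tokens,
        # clipped so one long request cannot reserve the whole pool; the
        # ratio discounts for requests expected to stop early at EOS.
        return sum(
            min(r.max_new_tokens - r.num_output_ids, CLIP_MAX_NEW_TOKENS)
            * new_token_ratio
            for r in reqs
        )

    # Two running requests with 100 and 2000 tokens left to generate:
    print(reserved_tokens([RunningReq(128, 28), RunningReq(4096, 2096)], 0.7))
    # (100 + 512) * 0.7 = 428.4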
@@ -141,31 +154,14 @@ class PrefillAdder:
                 if self.rem_chunk_tokens is not None
                 else False
             )
+            or self.cur_rem_tokens <= 0
         )
 
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
-        )
-
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
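The `_prefill_one_req` change above is the heart of the fix: `rem_total_tokens` is a pessimistic budget that charges each admitted request its full `max_new_tokens` up front, while the new `cur_rem_tokens` charges only the prompt tokens actually written to the KV pool at prefill time. The feasibility check further down simulates decode from `cur_rem_tokens`. A toy illustration of how the two counters diverge (a sketch, not the class itself; the numbers are made up):

    # Toy illustration of the two budgets tracked per admitted prefill.
    rem_total_tokens = 10_000  # pessimistic: prompt + full decode budget
    cur_rem_tokens = 10_000    # actual KV usage now: prompt tokens only

    def admit(extend_input_len, max_new_tokens):
        global rem_total_tokens, cur_rem_tokens
        rem_total_tokens -= extend_input_len + max_new_tokens
        cur_rem_tokens -= extend_input_len

    admit(extend_input_len=1_000, max_new_tokens=4_000)
    print(rem_total_tokens, cur_rem_tokens)  # 5000 9000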
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
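With chunked prefill, an in-flight request is fed to the device in slices of at most `rem_chunk_tokens`; the hunk above collapses the `ignore_eos` special case into the single `add_inflight_req` path. A self-contained sketch of the clamping logic, assuming a hypothetical request record (`ChunkReq` and the constant are stand-ins for the real `Req` fields):

    from dataclasses import dataclass

    CLIP_MAX_NEW_TOKENS = 512  # illustrative

    @dataclass
    class ChunkReq:
        fill_ids: list         # prefix + new input token ids
        prefix_indices: list   # token ids already cached as prefix
        extend_input_len: int  # input tokens still to prefill
        max_new_tokens: int

    def clamp_chunk(req, rem_chunk_tokens):
        truncated = req.extend_input_len > rem_chunk_tokens
        req.extend_input_len = min(req.extend_input_len, rem_chunk_tokens)
        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
        # Decode tokens are reserved only once the whole prompt is in;
        # a truncated request is handed back so it re-enters next round.
        new_tokens = 0 if truncated else min(req.max_new_tokens, CLIP_MAX_NEW_TOKENS)
        return (req if truncated else None), new_tokens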
@@ -225,7 +199,7 @@ class PrefillAdder:
         self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,37 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-                return (tokens_left, tokens_occupied)
-
-            return None
-
-        # Quick Check
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
-                            break
-                    else:
-                        self.req_states.append(state)
-
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
-                bs = len(self.req_states) - i
-                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                    return False
-                tokens_freed += tokens_occupied
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
+                            break
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
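The rewritten admission test simulates decoding forward in time. `req_states` holds `(tokens_left, tokens_occupied)` pairs sorted by `tokens_left` ascending; between consecutive finish points, every still-running request (`bs` of them) consumes `decode_steps` tokens, and each request that finishes frees the tokens it occupied. The key change is that the simulation now starts from `cur_rem_tokens`, which is decremented as each prefill is admitted, instead of the stale `total_tokens` snapshot taken at construction. A standalone sketch of the same loop (names assumed; the real check reads `self.cur_rem_tokens` and `self.req_states`):

    # Standalone sketch of the decode-feasibility check. req_states is a list
    # of (tokens_left, tokens_occupied) pairs sorted by tokens_left ascending.
    def decode_is_feasible(cur_rem_tokens, req_states):
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(req_states):
            # Requests i..n-1 are still decoding; the next checkpoint is when
            # request i+1 (or this one, if it is last) exhausts its budget.
            decode_steps = (
                req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
            )
            bs = len(req_states) - i
            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
                return False  # the KV pool would run dry mid-decode
            tokens_freed += tokens_occupied  # request i finishes, frees its slots
        return True

    # Two requests with 10 and 50 tokens left, occupying 300 and 400 slots:
    print(decode_is_feasible(150, [(10, 300), (50, 400)]))  # True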
@@ -445,9 +445,6 @@ class ModelTpServer:
             num_mixed_running,
         )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
@@ -465,9 +462,6 @@ class ModelTpServer:
         )
 
         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
@@ -478,6 +472,10 @@ class ModelTpServer:
                     > self.max_loras_per_batch
                 ):
                     break
 
+            if adder.no_remaining_tokens():
+                break
+
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
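The two hunks above reorder the admission loop: the budget check now runs per candidate after the LoRA-batch constraint, and `init_next_round_input` is deferred until the request is actually considered. A schematic of the resulting loop; `lora_ok`, `init_input`, and `add_one_req` are stand-ins for the real calls, and the exact failure condition after `add_one_req` is elided in this diff:

    # Schematic of the admission loop after the change (names are stand-ins).
    def build_prefill_batch(waiting_queue, adder, lora_ok, init_input, add_one_req):
        for req in waiting_queue:
            if not lora_ok(req):          # batch-level LoRA constraint first
                break
            if adder.no_remaining_tokens():
                break                     # budget re-checked per candidate
            init_input(req)               # compute prefix/extend lengths lazily
            if not add_one_req(req):      # stop once a request cannot be added
                break
        return adder.can_run_list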
@@ -507,6 +505,11 @@ class ModelTpServer:
             else:
                 tree_cache_hit_rate = 0.0
 
+            num_used = self.max_total_num_tokens - (
+                self.token_to_kv_pool.available_size()
+                + self.tree_cache.evictable_size()
+            )
+
             if num_mixed_running > 0:
                 logger.info(
                     f"Prefill batch"
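The new `num_used` lines derive pool occupancy indirectly: whatever is neither free in the KV pool nor evictable from the prefix cache (`tree_cache`) is in active use, and the log's `token usage` is that count over the pool size. A sketch with made-up numbers standing in for `max_total_num_tokens`, `token_to_kv_pool.available_size()`, and `tree_cache.evictable_size()`:

    max_total_num_tokens = 200_000
    available_size = 120_000  # free slots in the KV pool
    evictable_size = 30_000   # cached prefixes that could be evicted

    # Everything neither free nor reclaimable is actively used.
    num_used = max_total_num_tokens - (available_size + evictable_size)
    print(f"token usage: {num_used / max_total_num_tokens:.2f}")  # 0.25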
@@ -515,6 +518,7 @@ class ModelTpServer:
                     f"#new-token: {adder.log_input_tokens}, "
                     f"#cached-token: {adder.log_hit_tokens}, "
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
                 )
             else:
@@ -524,6 +528,7 @@ class ModelTpServer:
                     f"#new-token: {adder.log_input_tokens}, "
                     f"#cached-token: {adder.log_hit_tokens}, "
                     f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"#running-req: {running_bs}, "
                     f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
                 )