Optimize schedule (#1339)
This commit is contained in:
@@ -108,18 +108,24 @@ class PrefillAdder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
|
running_batch: ScheduleBatch,
|
||||||
|
new_token_ratio: float,
|
||||||
rem_total_tokens: int,
|
rem_total_tokens: int,
|
||||||
rem_input_tokens: int,
|
rem_input_tokens: int,
|
||||||
rem_chunk_tokens: Optional[int],
|
rem_chunk_tokens: Optional[int],
|
||||||
mixed_with_decode_tokens: int = 0,
|
mixed_with_decode_tokens: int = 0,
|
||||||
):
|
):
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
|
self.running_batch = running_batch
|
||||||
|
self.new_token_ratio = new_token_ratio
|
||||||
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
||||||
|
self.total_tokens = rem_total_tokens
|
||||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||||
self.rem_chunk_tokens = rem_chunk_tokens
|
self.rem_chunk_tokens = rem_chunk_tokens
|
||||||
if self.rem_chunk_tokens is not None:
|
if self.rem_chunk_tokens is not None:
|
||||||
self.rem_chunk_tokens -= mixed_with_decode_tokens
|
self.rem_chunk_tokens -= mixed_with_decode_tokens
|
||||||
|
|
||||||
|
self.req_states = None
|
||||||
self.can_run_list = []
|
self.can_run_list = []
|
||||||
self.new_inflight_req = None
|
self.new_inflight_req = None
|
||||||
self.log_hit_tokens = 0
|
self.log_hit_tokens = 0
|
||||||
@@ -136,16 +142,14 @@ class PrefillAdder:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove_running_tokens(
|
def remove_running_tokens(self, running_batch: ScheduleBatch):
|
||||||
self, running_batch: ScheduleBatch, new_token_ratio: float
|
|
||||||
):
|
|
||||||
self.rem_total_tokens -= sum(
|
self.rem_total_tokens -= sum(
|
||||||
[
|
[
|
||||||
min(
|
min(
|
||||||
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
||||||
CLIP_MAX_NEW_TOKENS,
|
CLIP_MAX_NEW_TOKENS,
|
||||||
)
|
)
|
||||||
* new_token_ratio
|
* self.new_token_ratio
|
||||||
for r in running_batch.reqs
|
for r in running_batch.reqs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -161,7 +165,29 @@ class PrefillAdder:
|
|||||||
self.log_hit_tokens += prefix_len
|
self.log_hit_tokens += prefix_len
|
||||||
self.log_input_tokens += extend_input_len
|
self.log_input_tokens += extend_input_len
|
||||||
|
|
||||||
|
def add_inflight_req_ignore_eos(self, req: Req):
|
||||||
|
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||||
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
|
self.can_run_list.append(req)
|
||||||
|
|
||||||
|
self._prefill_one_req(
|
||||||
|
0,
|
||||||
|
req.extend_input_len,
|
||||||
|
(
|
||||||
|
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
|
||||||
|
if not truncated
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return if chunked prefill not finished
|
||||||
|
return req if truncated else None
|
||||||
|
|
||||||
def add_inflight_req(self, req: Req):
|
def add_inflight_req(self, req: Req):
|
||||||
|
if req.sampling_params.ignore_eos:
|
||||||
|
return self.add_inflight_req_ignore_eos(req)
|
||||||
|
|
||||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
@@ -190,7 +216,81 @@ class PrefillAdder:
|
|||||||
delta = self.tree_cache.dec_lock_ref(last_node)
|
delta = self.tree_cache.dec_lock_ref(last_node)
|
||||||
self.rem_total_tokens += delta
|
self.rem_total_tokens += delta
|
||||||
|
|
||||||
|
def add_one_req_ignore_eos(self, req: Req):
|
||||||
|
def get_req_state(r):
|
||||||
|
new_token_ratio = (
|
||||||
|
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
|
||||||
|
)
|
||||||
|
tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
|
||||||
|
r.output_ids
|
||||||
|
)
|
||||||
|
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
|
||||||
|
|
||||||
|
if tokens_left > 0:
|
||||||
|
return (tokens_left, tokens_occupied)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.req_states is None:
|
||||||
|
self.req_states = []
|
||||||
|
if self.running_batch is not None:
|
||||||
|
for r in self.running_batch.reqs:
|
||||||
|
state = get_req_state(r)
|
||||||
|
if state is not None:
|
||||||
|
self.req_states.append(state)
|
||||||
|
for r in self.can_run_list:
|
||||||
|
state = get_req_state(r)
|
||||||
|
if state is not None:
|
||||||
|
self.req_states.append(state)
|
||||||
|
state = get_req_state(req)
|
||||||
|
if state is not None:
|
||||||
|
self.req_states.append(state)
|
||||||
|
|
||||||
|
self.req_states.sort(key=lambda x: x[0])
|
||||||
|
else:
|
||||||
|
state = get_req_state(req)
|
||||||
|
if state is not None:
|
||||||
|
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||||
|
if tokens_left >= state[0]:
|
||||||
|
self.req_states.insert(i, state)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.req_states.append(state)
|
||||||
|
|
||||||
|
tokens_freed = 0
|
||||||
|
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
|
||||||
|
decode_steps = (
|
||||||
|
self.req_states[i + 1][0]
|
||||||
|
if i + 1 < len(self.req_states)
|
||||||
|
else tokens_left
|
||||||
|
)
|
||||||
|
bs = len(self.req_states) - i
|
||||||
|
if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
|
||||||
|
return False
|
||||||
|
tokens_freed += tokens_occupied
|
||||||
|
|
||||||
|
if req.extend_input_len <= self.rem_chunk_tokens:
|
||||||
|
self.can_run_list.append(req)
|
||||||
|
self._prefill_one_req(
|
||||||
|
0,
|
||||||
|
req.extend_input_len,
|
||||||
|
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Chunked prefill
|
||||||
|
trunc_len = self.rem_chunk_tokens
|
||||||
|
req.extend_input_len = trunc_len
|
||||||
|
req.fill_ids = req.fill_ids[:trunc_len]
|
||||||
|
self.can_run_list.append(req)
|
||||||
|
self.new_inflight_req = req
|
||||||
|
self._prefill_one_req(0, trunc_len, 0)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def add_one_req(self, req: Req):
|
def add_one_req(self, req: Req):
|
||||||
|
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||||
|
return self.add_one_req_ignore_eos(req)
|
||||||
|
|
||||||
total_tokens = req.extend_input_len + min(
|
total_tokens = req.extend_input_len + min(
|
||||||
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
|
req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
|
||||||
)
|
)
|
||||||
@@ -233,4 +333,4 @@ class PrefillAdder:
|
|||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||||
|
|
||||||
return True
|
return True and not self.no_remaining_tokens()
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
self.new_token_ratio = self.min_new_token_ratio
|
self.new_token_ratio = self.min_new_token_ratio
|
||||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||||
|
self.do_not_get_new_batch = False
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs: List):
|
def exposed_step(self, recv_reqs: List):
|
||||||
try:
|
try:
|
||||||
@@ -253,7 +254,13 @@ class ModelTpServer:
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def forward_step(self):
|
def forward_step(self):
|
||||||
new_batch = self.get_new_prefill_batch()
|
if self.current_inflight_req is not None:
|
||||||
|
self.do_not_get_new_batch = False
|
||||||
|
|
||||||
|
new_batch = (
|
||||||
|
self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
|
||||||
|
)
|
||||||
|
self.do_not_get_new_batch = False
|
||||||
|
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
# Run a new prefill batch
|
# Run a new prefill batch
|
||||||
@@ -409,6 +416,8 @@ class ModelTpServer:
|
|||||||
|
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
|
self.running_batch,
|
||||||
|
self.new_token_ratio,
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.chunked_prefill_size,
|
self.chunked_prefill_size,
|
||||||
@@ -416,7 +425,7 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.running_batch is not None:
|
if self.running_batch is not None:
|
||||||
adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
|
adder.remove_running_tokens(self.running_batch)
|
||||||
|
|
||||||
has_inflight = self.current_inflight_req is not None
|
has_inflight = self.current_inflight_req is not None
|
||||||
if self.current_inflight_req is not None:
|
if self.current_inflight_req is not None:
|
||||||
@@ -428,11 +437,12 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
|
if adder.no_remaining_tokens():
|
||||||
|
break
|
||||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||||
res = adder.add_one_req(req)
|
res = adder.add_one_req(req)
|
||||||
if (
|
if (
|
||||||
not res
|
not res
|
||||||
or adder.no_remaining_tokens()
|
|
||||||
or running_bs + len(adder.can_run_list) >= self.max_running_requests
|
or running_bs + len(adder.can_run_list) >= self.max_running_requests
|
||||||
):
|
):
|
||||||
break
|
break
|
||||||
@@ -700,6 +710,7 @@ class ModelTpServer:
|
|||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
|
has_finished = False
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
@@ -712,6 +723,7 @@ class ModelTpServer:
|
|||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
has_finished = True
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
req.output_token_logprobs.append(
|
req.output_token_logprobs.append(
|
||||||
@@ -720,6 +732,9 @@ class ModelTpServer:
|
|||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||||
|
|
||||||
|
if not has_finished:
|
||||||
|
self.do_not_get_new_batch = True
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
def handle_finished_requests(self, batch: ScheduleBatch):
|
def handle_finished_requests(self, batch: ScheduleBatch):
|
||||||
|
|||||||
Reference in New Issue
Block a user