Fix chunked prefill when ignore eos (#2290)
This commit is contained in:
@@ -142,7 +142,7 @@ class PrefillAdder:
|
|||||||
|
|
||||||
self.req_states = None
|
self.req_states = None
|
||||||
self.can_run_list = []
|
self.can_run_list = []
|
||||||
self.new_inflight_req = None
|
self.new_being_chunked_req = None
|
||||||
self.log_hit_tokens = 0
|
self.log_hit_tokens = 0
|
||||||
self.log_input_tokens = 0
|
self.log_input_tokens = 0
|
||||||
|
|
||||||
@@ -182,7 +182,7 @@ class PrefillAdder:
|
|||||||
self.log_hit_tokens += prefix_len
|
self.log_hit_tokens += prefix_len
|
||||||
self.log_input_tokens += extend_input_len
|
self.log_input_tokens += extend_input_len
|
||||||
|
|
||||||
def add_inflight_req(self, req: Req):
|
def add_being_chunked_req(self, req: Req):
|
||||||
truncated = req.extend_input_len > self.rem_chunk_tokens
|
truncated = req.extend_input_len > self.rem_chunk_tokens
|
||||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
@@ -269,10 +269,13 @@ class PrefillAdder:
|
|||||||
else:
|
else:
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
trunc_len = self.rem_chunk_tokens
|
trunc_len = self.rem_chunk_tokens
|
||||||
|
if trunc_len == 0:
|
||||||
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[:trunc_len]
|
req.fill_ids = req.fill_ids[:trunc_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_inflight_req = req
|
self.new_being_chunked_req = req
|
||||||
self._prefill_one_req(0, trunc_len, 0)
|
self._prefill_one_req(0, trunc_len, 0)
|
||||||
|
|
||||||
return self.budget_state()
|
return self.budget_state()
|
||||||
@@ -326,7 +329,7 @@ class PrefillAdder:
|
|||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
self.new_inflight_req = req
|
self.new_being_chunked_req = req
|
||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||||
|
|
||||||
|
|||||||
@@ -660,7 +660,7 @@ class Scheduler:
|
|||||||
|
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
|
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
|
||||||
if isinstance(self.tree_cache, RadixCache):
|
if isinstance(self.tree_cache, RadixCache):
|
||||||
self.tree_cache_metrics["total"] += (
|
self.tree_cache_metrics["total"] += (
|
||||||
adder.log_input_tokens + adder.log_hit_tokens
|
adder.log_input_tokens + adder.log_hit_tokens
|
||||||
@@ -684,14 +684,14 @@ class Scheduler:
|
|||||||
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
|
||||||
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
|
||||||
f"#running-req: {running_bs}, "
|
f"#running-req: {running_bs}, "
|
||||||
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
|
f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
self.stats.num_running_reqs = running_bs
|
self.stats.num_running_reqs = running_bs
|
||||||
self.stats.num_used_tokens = num_used
|
self.stats.num_used_tokens = num_used
|
||||||
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
|
||||||
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
|
self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
|
||||||
self.stats.cache_hit_rate = tree_cache_hit_rate
|
self.stats.cache_hit_rate = tree_cache_hit_rate
|
||||||
self.metrics_collector.log_stats(self.stats)
|
self.metrics_collector.log_stats(self.stats)
|
||||||
|
|
||||||
@@ -752,7 +752,7 @@ class Scheduler:
|
|||||||
# Move the chunked request out of the batch
|
# Move the chunked request out of the batch
|
||||||
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
|
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
|
||||||
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
self.tree_cache.cache_unfinished_req(self.being_chunked_req)
|
||||||
# Inflight request keeps its rid but will get a new req_pool_idx
|
# being chunked request keeps its rid but will get a new req_pool_idx
|
||||||
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
|
|
||||||
@@ -803,10 +803,10 @@ class Scheduler:
|
|||||||
running_bs if self.is_mixed_chunk else 0,
|
running_bs if self.is_mixed_chunk else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
has_inflight = self.being_chunked_req is not None
|
has_being_chunked = self.being_chunked_req is not None
|
||||||
if has_inflight:
|
if has_being_chunked:
|
||||||
self.being_chunked_req.init_next_round_input()
|
self.being_chunked_req.init_next_round_input()
|
||||||
self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
|
self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
|
||||||
|
|
||||||
if self.lora_paths:
|
if self.lora_paths:
|
||||||
lora_set = (
|
lora_set = (
|
||||||
@@ -848,16 +848,16 @@ class Scheduler:
|
|||||||
x for x in self.waiting_queue if x not in set(can_run_list)
|
x for x in self.waiting_queue if x not in set(can_run_list)
|
||||||
]
|
]
|
||||||
|
|
||||||
if adder.new_inflight_req is not None:
|
if adder.new_being_chunked_req is not None:
|
||||||
assert self.being_chunked_req is None
|
assert self.being_chunked_req is None
|
||||||
self.being_chunked_req = adder.new_inflight_req
|
self.being_chunked_req = adder.new_being_chunked_req
|
||||||
|
|
||||||
if self.being_chunked_req:
|
if self.being_chunked_req:
|
||||||
self.being_chunked_req.is_being_chunked += 1
|
self.being_chunked_req.is_being_chunked += 1
|
||||||
|
|
||||||
# Print stats
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
|
self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
|
||||||
|
|
||||||
# Create a new batch
|
# Create a new batch
|
||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
@@ -1030,7 +1030,7 @@ class Scheduler:
|
|||||||
if req.grammar is not None:
|
if req.grammar is not None:
|
||||||
req.grammar.accept_token(next_token_id)
|
req.grammar.accept_token(next_token_id)
|
||||||
else:
|
else:
|
||||||
# Inflight reqs' prefill is not finished
|
# being chunked reqs' prefill is not finished
|
||||||
req.is_being_chunked -= 1
|
req.is_being_chunked -= 1
|
||||||
|
|
||||||
if batch.next_batch_sampling_info:
|
if batch.next_batch_sampling_info:
|
||||||
@@ -1058,7 +1058,7 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
else:
|
else:
|
||||||
# Inflight reqs' prefill is not finished
|
# being chunked reqs' prefill is not finished
|
||||||
req.is_being_chunked -= 1
|
req.is_being_chunked -= 1
|
||||||
|
|
||||||
self.stream_output(batch.reqs)
|
self.stream_output(batch.reqs)
|
||||||
|
|||||||
Reference in New Issue
Block a user