Revert "Chunked prefill support" (#799)
@@ -77,10 +77,6 @@ class ModelTpServer:
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
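For context on what this revert removes: chunked prefill caps how many prompt tokens are prefilled per scheduling step at `chunked_prefill_size`, carrying the partially prefilled request across steps as `current_inflight_req`. A minimal standalone sketch of the chunking idea (the helper name is illustrative, not part of the codebase):

# Illustrative: split the uncached part of a prompt into prefill chunks of at
# most `chunked_prefill_size` tokens.
def split_into_prefill_chunks(input_ids, cached_len, chunked_prefill_size):
    chunks = []
    pos = cached_len  # tokens before `pos` are already in the KV cache
    while pos < len(input_ids):
        end = min(pos + chunked_prefill_size, len(input_ids))
        chunks.append(input_ids[pos:end])
        pos = end
    return chunks

# A 10-token prompt with 2 cached tokens and a chunk size of 3 is prefilled
# over chunks of lengths [3, 3, 2].
print([len(c) for c in split_into_prefill_chunks(list(range(10)), 2, 3)])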
@@ -161,7 +157,7 @@ class ModelTpServer:
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool

         # Init running status
-        self.waiting_queue: List[Req] = []
+        self.forward_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -224,7 +220,6 @@ class ModelTpServer:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
-            self.filter_out_inflight(new_batch)

             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -266,7 +261,7 @@ class ModelTpServer:
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}"
+                f"#queue-req: {len(self.forward_queue)}"
             )

     def check_memory(self):
@@ -333,10 +328,9 @@ class ModelTpServer:
             ),
             self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        self.waiting_queue.append(req)
+        self.forward_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[Batch]:
         # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -344,7 +338,7 @@ class ModelTpServer:
             return

         # Compute matched prefix length
-        for req in self.waiting_queue:
+        for req in self.forward_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -354,7 +348,7 @@ class ModelTpServer:
             req.last_node = last_node

         # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)

         # Add requests if there is available space
         can_run_list = []
@@ -373,33 +367,7 @@ class ModelTpServer:
                 ]
             )

-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
-            )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
-            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
-
-        for req in self.waiting_queue:
+        for req in self.forward_queue:
             if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
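One detail of the block removed above worth noting: an in-flight request that finishes its prefill in this step reserves room for its future output (`extend_input_len + max_new_tokens`), while a request that is truncated again only reserves its input tokens, since it cannot start decoding yet. A small, self-contained illustration of that accounting, with made-up numbers:

# Illustrative accounting for one chunked-prefill scheduling step
# (numbers are made up).
chunked_prefill_size = 512
extend_input_len = 300      # uncached prompt tokens left for this request
max_new_tokens = 128        # sampling budget for the request

truncated = extend_input_len > chunked_prefill_size
step_input = min(extend_input_len, chunked_prefill_size)
if not truncated:
    # Finished prefilling: reserve room for decoding as well.
    step_total = step_input + max_new_tokens
else:
    # Still in flight: only the prefilled chunk occupies the batch budget.
    step_total = step_input

print(step_input, step_total)  # 300 428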
@@ -441,36 +409,11 @@ class ModelTpServer:
                     break
                 else:
                     # Add this request to the running batch
-                    if (
-                        new_batch_input_tokens + req.extend_input_len
-                        <= self.chunked_prefill_size
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
+                    can_run_list.append(req)
+                    new_batch_total_tokens += (
+                        req.extend_input_len + req.sampling_params.max_new_tokens
+                    )
+                    new_batch_input_tokens += req.extend_input_len
             else:
                 break

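The branch removed above is the other half of the mechanism: when a new request would push the batch over `chunked_prefill_size`, it is cut down to the remaining budget and remembered as the in-flight request; if no budget is left, scheduling stops for this step. A condensed, self-contained sketch (function and variable names are illustrative, not the repo's API):

# Illustrative: decide how many prompt tokens of a new request fit into the
# current prefill batch. Returns (tokens_taken, becomes_inflight).
def fit_or_truncate(extend_input_len, batch_input_tokens, chunked_prefill_size):
    if batch_input_tokens + extend_input_len <= chunked_prefill_size:
        return extend_input_len, False   # fits entirely
    budget = chunked_prefill_size - batch_input_tokens
    if budget <= 0:
        return 0, False                  # batch already full, stop here
    return budget, True                  # truncated: carried over as inflight

print(fit_or_truncate(300, 100, 512))  # (300, False)
print(fit_or_truncate(600, 100, 512))  # (412, True)
print(fit_or_truncate(600, 512, 512))  # (0, False)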
@@ -497,7 +440,7 @@ class ModelTpServer:
             f"#cached-token: {hit_tokens}, "
             f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+            f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
         )

         # Return the new batch
@@ -507,7 +450,7 @@ class ModelTpServer:
             self.token_to_kv_pool,
             self.tree_cache,
         )
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

     def forward_prefill_batch(self, batch: Batch):
@@ -539,10 +482,9 @@ class ModelTpServer:
         # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
-            if req is not self.current_inflight_req:
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_ids[i])
+            req.check_finished()

             if req.return_logprob:
                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -603,7 +545,7 @@ class ModelTpServer:
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids),
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
@@ -611,10 +553,6 @@ class ModelTpServer:
             )
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
-
     def forward_decode_batch(self, batch: Batch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -628,7 +566,7 @@ class ModelTpServer:
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self.waiting_queue.extend(retracted_reqs)
+            self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -638,7 +576,7 @@ class ModelTpServer:
         if not self.disable_regex_jump_forward:
             # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.waiting_queue.extend(jump_forward_reqs)
+            self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return

@@ -773,18 +711,8 @@ class ModelTpServer:
         else:
             batch.reqs = []

-    def filter_out_inflight(self, batch: Batch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        unfinished_indices = list(range(len(batch.reqs)))
-        unfinished_indices.remove(batch.reqs.index(self.current_inflight_req))
-
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
-        if len(self.waiting_queue) == 0 and (
+        if len(self.forward_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
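`filter_out_inflight`, deleted above, kept a partially prefilled request out of the decode path: after a prefill step, the batch is filtered down to requests whose prompts are fully consumed, since only those can start emitting tokens. A minimal sketch over plain lists (the real code works on `Batch` indices):

# Illustrative: drop the in-flight request from a freshly prefilled batch so it
# does not take part in decoding until its whole prompt has been prefilled.
def filter_out_inflight(reqs, current_inflight_req):
    if current_inflight_req is None:
        return reqs
    return [r for r in reqs if r is not current_inflight_req]

reqs = [object(), object(), object()]
inflight = reqs[1]
filtered = filter_out_inflight(reqs, inflight)
print(len(filtered), inflight in filtered)  # 2 False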
@@ -797,20 +725,20 @@ class ModelTpServer:
         else:
             warnings.warn(
                 f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.waiting_queue)}, "
+                f"#queue-req: {len(self.forward_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )

     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
         to_del = None
-        for i, req in enumerate(self.waiting_queue):
+        for i, req in enumerate(self.forward_queue):
             if req.rid == recv_req.rid:
                 to_del = i
                 break

         if to_del is not None:
-            del self.waiting_queue[to_del]
+            del self.forward_queue[to_del]

         # Delete requests in the running batch
         if self.running_batch: