Fix getting stuck in get_new_prefill_batch (#948)
This commit is contained in:
@@ -364,12 +364,13 @@ class ModelTpServer:
|
|||||||
# Compute matched prefix length
|
# Compute matched prefix length
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
req.input_ids = req.origin_input_ids + req.output_ids
|
req.input_ids = req.origin_input_ids + req.output_ids
|
||||||
prefix_indices, last_node = self.tree_cache.match_prefix(
|
try_match_ids = req.input_ids
|
||||||
rid=req.rid,
|
|
||||||
key=req.input_ids,
|
|
||||||
)
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
prefix_indices = prefix_indices[: req.logprob_start_len]
|
try_match_ids = req.input_ids[: req.logprob_start_len]
|
||||||
|
# NOTE: the prefix_indices must always be aligned with last_node
|
||||||
|
prefix_indices, last_node = self.tree_cache.match_prefix(
|
||||||
|
rid=req.rid, key=try_match_ids
|
||||||
|
)
|
||||||
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
||||||
req.prefix_indices = prefix_indices
|
req.prefix_indices = prefix_indices
|
||||||
req.last_node = last_node
|
req.last_node = last_node
|
||||||
|
|||||||
Reference in New Issue
Block a user