diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 33acf98e8..d7dedc29d 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -364,12 +364,13 @@ class ModelTpServer: # Compute matched prefix length for req in self.waiting_queue: req.input_ids = req.origin_input_ids + req.output_ids - prefix_indices, last_node = self.tree_cache.match_prefix( - rid=req.rid, - key=req.input_ids, - ) + try_match_ids = req.input_ids if req.return_logprob: - prefix_indices = prefix_indices[: req.logprob_start_len] + try_match_ids = req.input_ids[: req.logprob_start_len] + # NOTE: the prefix_indices must always be aligned with last_node + prefix_indices, last_node = self.tree_cache.match_prefix( + rid=req.rid, key=try_match_ids + ) req.extend_input_len = len(req.input_ids) - len(prefix_indices) req.prefix_indices = prefix_indices req.last_node = last_node