diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 4bf700f51..9d5f99197 100644 --- a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -43,11 +43,14 @@ class PolicyScheduler: def calc_priority(self, waiting_queue: List[Req]): # Compute matched prefix length - for r in waiting_queue: - # NOTE: the prefix_indices must always be aligned with last_node - r.prefix_indices, r.last_node = self.tree_cache.match_prefix( - rid=r.rid, key=r.adjust_max_prefix_ids() - ) + prefix_computed = False + if self.policy in ["lpm", "dfs-weight"]: + for r in waiting_queue: + # NOTE: the prefix_indices must always be aligned with last_node + r.prefix_indices, r.last_node = self.tree_cache.match_prefix( + rid=r.rid, key=r.adjust_max_prefix_ids() + ) + prefix_computed = True if self.policy == "lpm": # Longest Prefix Match @@ -80,6 +83,8 @@ class PolicyScheduler: else: raise ValueError(f"Unknown schedule_policy: {self.policy}") + return prefix_computed + def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict): for child in cur_node.children.values(): self.calc_weight(child, node_to_weight) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a62e612b0..a461fa181 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -18,9 +18,8 @@ limitations under the License. import logging import warnings from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Union -import numpy as np import torch from flashinfer.sampling import top_k_top_p_sampling_from_probs @@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib from sglang.global_config import global_config from sglang.srt.constrained import RegexGuide from sglang.srt.constrained.jump_forward import JumpForwardMap +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool -from sglang.srt.mem_cache.radix_cache import RadixCache INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -164,8 +163,12 @@ class Req: def finished(self) -> bool: return self.finished_reason is not None - def init_next_round_input(self): + def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): self.fill_ids = self.origin_input_ids + self.output_ids + if tree_cache is not None: + self.prefix_indices, self.last_node = tree_cache.match_prefix( + rid=self.rid, key=self.adjust_max_prefix_ids() + ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): @@ -312,7 +315,7 @@ class ScheduleBatch: reqs: List[Req] req_to_token_pool: ReqToTokenPool token_to_kv_pool: BaseTokenToKVPool - tree_cache: RadixCache + tree_cache: BasePrefixCache # Batched arguments to model runner input_ids: torch.Tensor = None @@ -534,7 +537,7 @@ class ScheduleBatch: residual_size = max(0, residual_size) self.tree_cache.evict(residual_size, self.token_to_kv_pool.free) - req.prefix_indices = None + req.prefix_indices = [] req.last_node = None req.extend_input_len = 0 diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 49ee8c839..4c757737e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -369,7 +369,7 @@ class ModelTpServer: return None # Get priority queue - self.scheduler.calc_priority(self.waiting_queue) + prefix_computed = self.scheduler.calc_priority(self.waiting_queue) adder = PrefillAdder( self.tree_cache, @@ -383,13 +383,15 @@ class ModelTpServer: has_inflight = self.current_inflight_req is not None if self.current_inflight_req is not None: - self.current_inflight_req.init_next_round_input() + self.current_inflight_req.init_next_round_input( + None if prefix_computed else self.tree_cache + ) self.current_inflight_req = adder.add_inflight_req( self.current_inflight_req ) for req in self.waiting_queue: - req.init_next_round_input() + req.init_next_round_input(None if prefix_computed else self.tree_cache) res = adder.add_one_req(req) if ( not res diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 8ebe903c7..a1c685405 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache. import heapq import time from collections import defaultdict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, List, Optional import torch