Fix the prefix indices (#1037)
This commit is contained in:
@@ -43,11 +43,14 @@ class PolicyScheduler:
|
||||
|
||||
def calc_priority(self, waiting_queue: List[Req]):
|
||||
# Compute matched prefix length
|
||||
for r in waiting_queue:
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=r.adjust_max_prefix_ids()
|
||||
)
|
||||
prefix_computed = False
|
||||
if self.policy in ["lpm", "dfs-weight"]:
|
||||
for r in waiting_queue:
|
||||
# NOTE: the prefix_indices must always be aligned with last_node
|
||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||
rid=r.rid, key=r.adjust_max_prefix_ids()
|
||||
)
|
||||
prefix_computed = True
|
||||
|
||||
if self.policy == "lpm":
|
||||
# Longest Prefix Match
|
||||
@@ -80,6 +83,8 @@ class PolicyScheduler:
|
||||
else:
|
||||
raise ValueError(f"Unknown schedule_policy: {self.policy}")
|
||||
|
||||
return prefix_computed
|
||||
|
||||
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
|
||||
for child in cur_node.children.values():
|
||||
self.calc_weight(child, node_to_weight)
|
||||
|
||||
@@ -18,9 +18,8 @@ limitations under the License.
|
||||
import logging
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||
|
||||
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained import RegexGuide
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
@@ -164,8 +163,12 @@ class Req:
|
||||
def finished(self) -> bool:
|
||||
return self.finished_reason is not None
|
||||
|
||||
def init_next_round_input(self):
|
||||
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
|
||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||
if tree_cache is not None:
|
||||
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||
)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
@@ -312,7 +315,7 @@ class ScheduleBatch:
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
tree_cache: RadixCache
|
||||
tree_cache: BasePrefixCache
|
||||
|
||||
# Batched arguments to model runner
|
||||
input_ids: torch.Tensor = None
|
||||
@@ -534,7 +537,7 @@ class ScheduleBatch:
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||
|
||||
req.prefix_indices = None
|
||||
req.prefix_indices = []
|
||||
req.last_node = None
|
||||
req.extend_input_len = 0
|
||||
|
||||
|
||||
@@ -369,7 +369,7 @@ class ModelTpServer:
|
||||
return None
|
||||
|
||||
# Get priority queue
|
||||
self.scheduler.calc_priority(self.waiting_queue)
|
||||
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
|
||||
|
||||
adder = PrefillAdder(
|
||||
self.tree_cache,
|
||||
@@ -383,13 +383,15 @@ class ModelTpServer:
|
||||
|
||||
has_inflight = self.current_inflight_req is not None
|
||||
if self.current_inflight_req is not None:
|
||||
self.current_inflight_req.init_next_round_input()
|
||||
self.current_inflight_req.init_next_round_input(
|
||||
None if prefix_computed else self.tree_cache
|
||||
)
|
||||
self.current_inflight_req = adder.add_inflight_req(
|
||||
self.current_inflight_req
|
||||
)
|
||||
|
||||
for req in self.waiting_queue:
|
||||
req.init_next_round_input()
|
||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||
res = adder.add_one_req(req)
|
||||
if (
|
||||
not res
|
||||
|
||||
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user