Fix the prefix indices (#1037)

This commit is contained in:
Liangsheng Yin
2024-08-11 17:57:02 -07:00
committed by GitHub
parent d84c5e70f7
commit 7de6034534
4 changed files with 25 additions and 15 deletions

View File

@@ -43,11 +43,14 @@ class PolicyScheduler:
def calc_priority(self, waiting_queue: List[Req]):
# Compute matched prefix length
for r in waiting_queue:
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = False
if self.policy in ["lpm", "dfs-weight"]:
for r in waiting_queue:
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = True
if self.policy == "lpm":
# Longest Prefix Match
@@ -80,6 +83,8 @@ class PolicyScheduler:
else:
raise ValueError(f"Unknown schedule_policy: {self.policy}")
return prefix_computed
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
for child in cur_node.children.values():
self.calc_weight(child, node_to_weight)

View File

@@ -18,9 +18,8 @@ limitations under the License.
import logging
import warnings
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union
import numpy as np
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -164,8 +163,12 @@ class Req:
def finished(self) -> bool:
return self.finished_reason is not None
def init_next_round_input(self):
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self):
@@ -312,7 +315,7 @@ class ScheduleBatch:
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache
tree_cache: BasePrefixCache
# Batched arguments to model runner
input_ids: torch.Tensor = None
@@ -534,7 +537,7 @@ class ScheduleBatch:
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
req.prefix_indices = None
req.prefix_indices = []
req.last_node = None
req.extend_input_len = 0

View File

@@ -369,7 +369,7 @@ class ModelTpServer:
return None
# Get priority queue
self.scheduler.calc_priority(self.waiting_queue)
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
adder = PrefillAdder(
self.tree_cache,
@@ -383,13 +383,15 @@ class ModelTpServer:
has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None:
self.current_inflight_req.init_next_round_input()
self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
)
for req in self.waiting_queue:
req.init_next_round_input()
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res

View File

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, List, Optional
import torch