Fix the prefix indices (#1037)
This commit is contained in:
@@ -43,11 +43,14 @@ class PolicyScheduler:
|
|||||||
|
|
||||||
def calc_priority(self, waiting_queue: List[Req]):
|
def calc_priority(self, waiting_queue: List[Req]):
|
||||||
# Compute matched prefix length
|
# Compute matched prefix length
|
||||||
for r in waiting_queue:
|
prefix_computed = False
|
||||||
# NOTE: the prefix_indices must always be aligned with last_node
|
if self.policy in ["lpm", "dfs-weight"]:
|
||||||
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
for r in waiting_queue:
|
||||||
rid=r.rid, key=r.adjust_max_prefix_ids()
|
# NOTE: the prefix_indices must always be aligned with last_node
|
||||||
)
|
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
||||||
|
rid=r.rid, key=r.adjust_max_prefix_ids()
|
||||||
|
)
|
||||||
|
prefix_computed = True
|
||||||
|
|
||||||
if self.policy == "lpm":
|
if self.policy == "lpm":
|
||||||
# Longest Prefix Match
|
# Longest Prefix Match
|
||||||
@@ -80,6 +83,8 @@ class PolicyScheduler:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown schedule_policy: {self.policy}")
|
raise ValueError(f"Unknown schedule_policy: {self.policy}")
|
||||||
|
|
||||||
|
return prefix_computed
|
||||||
|
|
||||||
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
|
def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
|
||||||
for child in cur_node.children.values():
|
for child in cur_node.children.values():
|
||||||
self.calc_weight(child, node_to_weight)
|
self.calc_weight(child, node_to_weight)
|
||||||
|
|||||||
@@ -18,9 +18,8 @@ limitations under the License.
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
from flashinfer.sampling import top_k_top_p_sampling_from_probs
|
||||||
|
|
||||||
@@ -28,9 +27,9 @@ import sglang.srt.sampling.penaltylib as penaltylib
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.constrained import RegexGuide
|
from sglang.srt.constrained import RegexGuide
|
||||||
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
from sglang.srt.constrained.jump_forward import JumpForwardMap
|
||||||
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
@@ -164,8 +163,12 @@ class Req:
|
|||||||
def finished(self) -> bool:
|
def finished(self) -> bool:
|
||||||
return self.finished_reason is not None
|
return self.finished_reason is not None
|
||||||
|
|
||||||
def init_next_round_input(self):
|
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
|
||||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||||
|
if tree_cache is not None:
|
||||||
|
self.prefix_indices, self.last_node = tree_cache.match_prefix(
|
||||||
|
rid=self.rid, key=self.adjust_max_prefix_ids()
|
||||||
|
)
|
||||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||||
|
|
||||||
def adjust_max_prefix_ids(self):
|
def adjust_max_prefix_ids(self):
|
||||||
@@ -312,7 +315,7 @@ class ScheduleBatch:
|
|||||||
reqs: List[Req]
|
reqs: List[Req]
|
||||||
req_to_token_pool: ReqToTokenPool
|
req_to_token_pool: ReqToTokenPool
|
||||||
token_to_kv_pool: BaseTokenToKVPool
|
token_to_kv_pool: BaseTokenToKVPool
|
||||||
tree_cache: RadixCache
|
tree_cache: BasePrefixCache
|
||||||
|
|
||||||
# Batched arguments to model runner
|
# Batched arguments to model runner
|
||||||
input_ids: torch.Tensor = None
|
input_ids: torch.Tensor = None
|
||||||
@@ -534,7 +537,7 @@ class ScheduleBatch:
|
|||||||
residual_size = max(0, residual_size)
|
residual_size = max(0, residual_size)
|
||||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||||
|
|
||||||
req.prefix_indices = None
|
req.prefix_indices = []
|
||||||
req.last_node = None
|
req.last_node = None
|
||||||
req.extend_input_len = 0
|
req.extend_input_len = 0
|
||||||
|
|
||||||
|
|||||||
@@ -369,7 +369,7 @@ class ModelTpServer:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Get priority queue
|
# Get priority queue
|
||||||
self.scheduler.calc_priority(self.waiting_queue)
|
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
|
||||||
|
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
@@ -383,13 +383,15 @@ class ModelTpServer:
|
|||||||
|
|
||||||
has_inflight = self.current_inflight_req is not None
|
has_inflight = self.current_inflight_req is not None
|
||||||
if self.current_inflight_req is not None:
|
if self.current_inflight_req is not None:
|
||||||
self.current_inflight_req.init_next_round_input()
|
self.current_inflight_req.init_next_round_input(
|
||||||
|
None if prefix_computed else self.tree_cache
|
||||||
|
)
|
||||||
self.current_inflight_req = adder.add_inflight_req(
|
self.current_inflight_req = adder.add_inflight_req(
|
||||||
self.current_inflight_req
|
self.current_inflight_req
|
||||||
)
|
)
|
||||||
|
|
||||||
for req in self.waiting_queue:
|
for req in self.waiting_queue:
|
||||||
req.init_next_round_input()
|
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||||
res = adder.add_one_req(req)
|
res = adder.add_one_req(req)
|
||||||
if (
|
if (
|
||||||
not res
|
not res
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
|
|||||||
import heapq
|
import heapq
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user