diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 779bf988d..1261b6d0c 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -55,6 +55,7 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/flush_cache", api_key=self.api_key, verify=self.verify, + method="POST", ) self._assert_success(res) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a347b4d43..65c029a16 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -256,6 +256,7 @@ class Req: # Prefix info self.prefix_indices = [] + # Tokens to run prefill. input_tokens - shared_prefix_tokens. self.extend_input_len = 0 self.last_node = None @@ -316,6 +317,7 @@ class Req: def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: + # tree cache is None if the prefix is not computed with tree cache. self.prefix_indices, self.last_node = tree_cache.match_prefix( rid=self.rid, key=self.adjust_max_prefix_ids() ) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index abe7da9ea..1bb872fdf 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -20,9 +20,11 @@ from contextlib import contextmanager from enum import Enum, auto from typing import Dict, List, Optional +import torch + from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.radix_cache import TreeNode +from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large. 
# This can prevent the server from being too conservative. @@ -32,6 +34,13 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int( os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096") ) +# The threshold to apply in-batch prefix caching. +# If the value is too small, in-batch prefix caching is skipped whenever even a +# trivial prefix (e.g., "the") already matches the existing cache. +IN_BATCH_PREFIX_CACHING_THRESHOLD = int( + os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32") +) + class SchedulePolicy: def __init__(self, policy: str, tree_cache: BasePrefixCache): @@ -51,18 +60,50 @@ class SchedulePolicy: # Compute matched prefix length prefix_computed = False + # Maps rid -> request to deprioritize in the current run. + temporary_deprioritized = {} if policy == "lpm" or policy == "dfs-weight": + # Temporary radix tree used to find prefixes shared within this batch. + temp_radix = RadixCache(None, None, False) for r in waiting_queue: + prefix_ids = r.adjust_max_prefix_ids() # NOTE: the prefix_indices must always be aligned with last_node r.prefix_indices, r.last_node = self.tree_cache.match_prefix( - rid=r.rid, key=r.adjust_max_prefix_ids() + rid=r.rid, key=prefix_ids ) + # NOTE(sang): This logic is for in-batch prefix caching. + # If more than one request has only a small matching prefix in the + # existing cache, but those requests share the same prefix, we prefer + # to schedule only one of them so that we can increase the cache hit rate. + # We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because too small a + # threshold means we cannot use in-batch prefix caching for short prefixes. + # This is fairly common when the engine is long-running (e.g., imagine "the").
+ if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD: + in_batch_matching_prefixes, _ = temp_radix.match_prefix( + rid=r.rid, key=prefix_ids + ) + if ( + len(in_batch_matching_prefixes) + >= IN_BATCH_PREFIX_CACHING_THRESHOLD + ): + temporary_deprioritized[r.rid] = r + else: + temp_radix.insert(prefix_ids, torch.tensor(prefix_ids)) + prefix_computed = True if policy == "lpm": # Longest Prefix Match - waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) + def get_priority(r: Req): + score = 0 + if r.rid in temporary_deprioritized: + score = float("inf") + else: + score = -len(r.prefix_indices) + return score + + waiting_queue.sort(key=get_priority) elif policy == "fcfs": # first come first serve pass @@ -76,6 +117,7 @@ class SchedulePolicy: for req in waiting_queue: last_node_to_reqs[req.last_node].append(req) + # node -> # of requests for that node. node_to_weight = defaultdict(int) for node in last_node_to_reqs: node_to_weight[node] = len(last_node_to_reqs[node]) @@ -87,7 +129,9 @@ class SchedulePolicy: node_to_weight, last_node_to_reqs, waiting_queue, + temporary_deprioritized, ) + waiting_queue.extend(temporary_deprioritized.values()) else: raise ValueError(f"Unknown schedule_policy: {policy=}") @@ -101,15 +145,22 @@ class SchedulePolicy: def get_dfs_priority( self, cur_node: TreeNode, - node_to_priority: Dict, - last_node_to_reqs: Dict, + node_to_priority: Dict[TreeNode, int], + last_node_to_reqs: Dict[TreeNode, List[Req]], q: List, + temporary_deprioritized: Dict[str, Req], ): childs = [child for child in cur_node.children.values()] childs.sort(key=lambda x: -node_to_priority[x]) for child in childs: - self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) - q.extend(last_node_to_reqs[cur_node]) + self.get_dfs_priority( + child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized + ) + + for req in last_node_to_reqs[cur_node]: + if req.rid in temporary_deprioritized: + continue + q.append(req) class AddReqResult(Enum): 
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4680b042d..465dfbfc3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -713,7 +713,7 @@ class Scheduler: if crash_on_warnings(): raise ValueError(msg) - def get_next_batch_to_run(self): + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch if self.last_batch and self.last_batch.forward_mode.is_extend(): if self.being_chunked_req: diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 2808ca872..acdd2898f 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, List, Tuple class BasePrefixCache(ABC): @@ -10,7 +10,7 @@ class BasePrefixCache(ABC): pass @abstractmethod - def match_prefix(self, **kwargs): + def match_prefix(self, **kwargs) -> Tuple[List[int], int]: pass @abstractmethod diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 3c430aba3..ab8965a01 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -2,7 +2,7 @@ from __future__ import annotations """Cache for chunked prefill, used when RadixCache is disabled.""" -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache): def reset(self): self.entries = {} - def match_prefix(self, rid: int, key: List[int]): + def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]: if rid not in 
self.entries: return [], None diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 8cd8354b6..1673d4f0c 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache. import heapq import time from collections import defaultdict -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple import torch @@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache): self.root_node.lock_ref = 1 self.evictable_size_ = 0 - def match_prefix(self, key: List, **kwargs): + def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: + """Find the matching prefix from the radix tree. + Args: + key: A list of token IDs to find a matching prefix. + Returns: + A tuple of a tensor of matching prefix token IDs and + the last node that contains the prefix values. Note that + this API can modify the internal state of the radix tree. + The last node creates a new child if the prefix is shorter + than the last node's value.
+ """ if self.disable: return [], self.root_node diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 5689b097d..98e0f3f4f 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -79,7 +79,14 @@ class HttpResponse: return self.resp.status -def http_request(url, json=None, stream=False, api_key=None, verify=None): +def http_request( + url, + json=None, + stream=False, + api_key=None, + verify=None, + method: Optional[str] = None, +): """A faster version of requests.post with low-level urllib API.""" headers = {"Content-Type": "application/json; charset=utf-8"} @@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None): if stream: return requests.post(url, json=json, stream=True, headers=headers) else: - req = urllib.request.Request(url, headers=headers) + req = urllib.request.Request(url, headers=headers, method=method) if json is None: data = None else: