diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index e4a22242f..e252c2737 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -18,13 +18,15 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import List
+from typing import Dict, List
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode
 
 
 class PolicyScheduler:
-    def __init__(self, policy, tree_cache):
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"
@@ -72,12 +74,18 @@ class PolicyScheduler:
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
-    def calc_weight(self, cur_node, node_to_weight):
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
@@ -88,10 +96,10 @@
 class PrefillAdder:
     def __init__(
         self,
-        tree_cache,
-        rem_total_tokens,
-        rem_input_tokens,
-        rem_chunk_tokens,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: int,
     ):
         self.tree_cache = tree_cache
         self.rem_total_tokens = rem_total_tokens
@@ -151,7 +159,7 @@ class PrefillAdder:
         return req if truncated else None
 
     @contextmanager
-    def _lock_node(self, last_node):
+    def _lock_node(self, last_node: TreeNode):
         try:
             delta = self.tree_cache.inc_lock_ref(last_node)
             self.rem_total_tokens += delta
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 645c30a4f..49ee8c839 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -21,15 +21,17 @@ import os
 import pickle
 import time
 import warnings
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
+import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+# TODO: Rename "CI" to "SGLANG_IS_IN_CI".
+crash_on_warning = os.getenv("CI", "false") == "true"
+
+
 class ModelTpServer:
     def __init__(
         self,
@@ -198,7 +204,7 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
 
-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -247,7 +253,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_stats()
+                        self.print_decode_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -259,7 +265,7 @@ class ModelTpServer:
                 self.check_memory()
                 self.new_token_ratio = global_config.init_new_token_ratio
 
-    def print_stats(self):
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -276,7 +282,6 @@ class ModelTpServer:
         )
 
     def check_memory(self):
-        crash = os.getenv("CI", "false") == "true"
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -286,7 +291,7 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
@@ -295,7 +300,7 @@ class ModelTpServer:
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None
 
     def handle_generate_request(
         self,
@@ -511,7 +516,14 @@ class ModelTpServer:
 
         self.handle_finished_requests(batch)
 
-    def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
@@ -786,7 +798,11 @@ def run_tp_server(
 
 
 def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -801,7 +817,9 @@ def launch_tp_servers(
     return procs
 
 
-def broadcast_recv_input(data, rank, dist_group):
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0:
diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py
index fb2b7a627..2808ca872 100644
--- a/python/sglang/srt/mem_cache/base_prefix_cache.py
+++ b/python/sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Callable
 
 
 class BasePrefixCache(ABC):
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass
 
     @abstractmethod
@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
         pass
 
     def total_size(self):
-        raise NotImplementedError
+        raise NotImplementedError()
 
     def pretty_print(self):
-        raise NotImplementedError
+        raise NotImplementedError()
diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py
index 0b4448bff..c6e6507a0 100644
--- a/python/sglang/srt/mem_cache/chunk_cache.py
+++ b/python/sglang/srt/mem_cache/chunk_cache.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
 
 
 class ChunkCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool):
+    def __init__(
+        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+    ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry
 
-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
 
@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
         if req.rid in self.entries:
             del self.entries[req.rid]
 
-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = req.fill_ids
 
@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
     def insert(self):
         raise NotImplementedError
 
-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass
 
     def inc_lock_ref(self, node):
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 37ce4296d..68cefbbf9 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -16,7 +16,7 @@ limitations under the License.
"""Memory pool.""" import logging -from typing import List +from typing import List, Union import torch @@ -42,7 +42,7 @@ class ReqToTokenPool: return select_index - def free(self, free_index): + def free(self, free_index: Union[int, List[int]]): if isinstance(free_index, (int,)): self.free_slots.append(free_index) else: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 25a467304..8ebe903c7 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,6 +27,7 @@ from typing import TYPE_CHECKING import torch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -43,7 +46,7 @@ class TreeNode: return self.last_access_time < other.last_access_time -def _key_match(key0, key1): +def _key_match(key0: List, key1: List): i = 0 for k0, k1 in zip(key0, key1): if k0 != k1: @@ -53,7 +56,12 @@ def _key_match(key0, key1): class RadixCache(BasePrefixCache): - def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False): + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool: BaseTokenToKVPool, + disable: bool = False, + ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool = token_to_kv_pool self.disable = disable @@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache): self.root_node.lock_ref = 1 self.evictable_size_ = 0 - def match_prefix(self, key, **kwargs): + def match_prefix(self, key: List, **kwargs): if self.disable: return [], self.root_node @@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache): value = torch.tensor([], dtype=torch.int32) return value, last_node[0] - def insert(self, key, value=None): + def insert(self, key: List, value=None): if self.disable: return 0 @@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache): value = [x for x in key] return self._insert_helper(self.root_node, key, value) - def cache_finished_req(self, req: "Req", token_ids=None): + def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): """Cache request when it finishes.""" if token_ids is None: token_ids = (req.origin_input_ids + req.output_ids)[:-1] @@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache): self.req_to_token_pool.free(req.req_pool_idx) self.dec_lock_ref(req.last_node) - def cache_unfinished_req(self, req: "Req", token_ids=None): + def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None): """Cache request when it is unfinished.""" if self.disable: return @@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache): def total_size(self): return self._total_size_helper(self.root_node) - def evict(self, num_tokens, evict_callback): + def evict(self, num_tokens: int, evict_callback: Callable): if self.disable: return @@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache): ##### Internal Helper Functions ##### - def _match_prefix_helper(self, node, key, value, last_node): + def _match_prefix_helper( + self, node: TreeNode, key: List, value, last_node: TreeNode + ): node.last_access_time = time.time() if len(key) == 0: return @@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache): last_node[0] = child self._match_prefix_helper(child, key[prefix_len:], value, last_node) - def _split_node(self, key, child: 
TreeNode, split_len): + def _split_node(self, key, child: TreeNode, split_len: int): # new_node -> child new_node = TreeNode() new_node.children = {key[split_len:][0]: child} @@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache): new_node.parent.children[key[:split_len][0]] = new_node return new_node - def _insert_helper(self, node, key, value): + def _insert_helper(self, node: TreeNode, key: List, value): node.last_access_time = time.time() if len(key) == 0: return 0 @@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache): self.evictable_size_ += len(value) return 0 - def _print_helper(self, node: TreeNode, indent): + def _print_helper(self, node: TreeNode, indent: int): for _, child in node.children.items(): print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}") self._print_helper(child, indent=indent + 2) @@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache): del node.parent.children[k] self.evictable_size_ -= len(node.key) - def _total_size_helper(self, node): + def _total_size_helper(self, node: TreeNode): x = len(node.value) for child in node.children.values(): x += self._total_size_helper(child)