Query remaining memory dynamically for PrefillAdder (#2941)
This commit is contained in:
@@ -24,6 +24,7 @@ import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||
|
||||
# Clip the max_new_tokens estimate for requests whose max_new_tokens is very large.
|
||||
@@ -250,23 +251,24 @@ class PrefillAdder:
|
||||
def __init__(
|
||||
self,
|
||||
tree_cache: BasePrefixCache,
|
||||
token_to_kv_pool: BaseTokenToKVPool,
|
||||
running_batch: ScheduleBatch,
|
||||
new_token_ratio: float,
|
||||
rem_total_tokens: int,
|
||||
rem_input_tokens: int,
|
||||
rem_chunk_tokens: Optional[int],
|
||||
mixed_with_decode_tokens: int = 0,
|
||||
):
|
||||
self.tree_cache = tree_cache
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.running_batch = running_batch
|
||||
self.new_token_ratio = new_token_ratio
|
||||
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||
self.rem_chunk_tokens = rem_chunk_tokens
|
||||
if self.rem_chunk_tokens is not None:
|
||||
self.rem_chunk_tokens -= mixed_with_decode_tokens
|
||||
|
||||
self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
|
||||
self.rem_total_token_offset = mixed_with_decode_tokens
|
||||
self.cur_rem_token_offset = mixed_with_decode_tokens
|
||||
|
||||
self.req_states = None
|
||||
self.can_run_list = []
|
||||
@@ -275,8 +277,7 @@ class PrefillAdder:
|
||||
self.log_input_tokens = 0
|
||||
|
||||
if running_batch is not None:
|
||||
# Pre-remove the tokens which will be occupied by the running requests
|
||||
self.rem_total_tokens -= sum(
|
||||
self.rem_total_token_offset += sum(
|
||||
[
|
||||
min(
|
||||
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
||||
@@ -287,6 +288,22 @@ class PrefillAdder:
|
||||
]
|
||||
)
|
||||
|
||||
@property
def rem_total_tokens(self):
    """Return the remaining total-token budget, recomputed on every access.

    The value is the KV pool's ``available_size()`` plus the tree cache's
    ``evictable_size()``, minus ``rem_total_token_offset`` (the amount this
    adder has reserved so far).
    """
    free_slots = self.token_to_kv_pool.available_size()
    reclaimable = self.tree_cache.evictable_size()
    reserved = self.rem_total_token_offset
    return free_slots + reclaimable - reserved
|
||||
|
||||
@property
def cur_rem_tokens(self):
    """Return the tokens that can still be placed right now, recomputed on access.

    The value is the KV pool's ``available_size()`` plus the tree cache's
    ``evictable_size()``, minus ``cur_rem_token_offset``.
    """
    free_slots = self.token_to_kv_pool.available_size()
    reclaimable = self.tree_cache.evictable_size()
    consumed = self.cur_rem_token_offset
    return free_slots + reclaimable - consumed
|
||||
|
||||
def budget_state(self):
|
||||
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
|
||||
return AddReqResult.NO_TOKEN
|
||||
@@ -301,8 +318,8 @@ class PrefillAdder:
|
||||
def _prefill_one_req(
|
||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||
):
|
||||
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
||||
self.cur_rem_tokens -= extend_input_len
|
||||
self.rem_total_token_offset += extend_input_len + max_new_tokens
|
||||
self.cur_rem_token_offset += extend_input_len
|
||||
self.rem_input_tokens -= extend_input_len
|
||||
if self.rem_chunk_tokens is not None:
|
||||
self.rem_chunk_tokens -= extend_input_len
|
||||
@@ -332,12 +349,10 @@ class PrefillAdder:
|
||||
@contextmanager
|
||||
def _lock_node(self, last_node: TreeNode):
|
||||
try:
|
||||
delta = self.tree_cache.inc_lock_ref(last_node)
|
||||
self.rem_total_tokens += delta
|
||||
self.tree_cache.inc_lock_ref(last_node)
|
||||
yield None
|
||||
finally:
|
||||
delta = self.tree_cache.dec_lock_ref(last_node)
|
||||
self.rem_total_tokens += delta
|
||||
self.tree_cache.dec_lock_ref(last_node)
|
||||
|
||||
def add_one_req_ignore_eos(self, req: Req):
|
||||
def add_req_state(r, insert_sort=False):
|
||||
|
||||
@@ -891,9 +891,9 @@ class Scheduler:
|
||||
# Prefill policy
|
||||
adder = PrefillAdder(
|
||||
self.tree_cache,
|
||||
self.token_to_kv_pool,
|
||||
self.running_batch,
|
||||
self.new_token_ratio,
|
||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
||||
self.max_prefill_tokens,
|
||||
self.chunked_prefill_size,
|
||||
running_bs if self.is_mixed_chunk else 0,
|
||||
|
||||
Reference in New Issue
Block a user