Query remaining memory dynamically for PrefillAdder (#2941)
This commit is contained in:
@@ -24,6 +24,7 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||||
|
|
||||||
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
||||||
@@ -250,23 +251,24 @@ class PrefillAdder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
|
token_to_kv_pool: BaseTokenToKVPool,
|
||||||
running_batch: ScheduleBatch,
|
running_batch: ScheduleBatch,
|
||||||
new_token_ratio: float,
|
new_token_ratio: float,
|
||||||
rem_total_tokens: int,
|
|
||||||
rem_input_tokens: int,
|
rem_input_tokens: int,
|
||||||
rem_chunk_tokens: Optional[int],
|
rem_chunk_tokens: Optional[int],
|
||||||
mixed_with_decode_tokens: int = 0,
|
mixed_with_decode_tokens: int = 0,
|
||||||
):
|
):
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
|
self.token_to_kv_pool = token_to_kv_pool
|
||||||
self.running_batch = running_batch
|
self.running_batch = running_batch
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
|
|
||||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||||
self.rem_chunk_tokens = rem_chunk_tokens
|
self.rem_chunk_tokens = rem_chunk_tokens
|
||||||
if self.rem_chunk_tokens is not None:
|
if self.rem_chunk_tokens is not None:
|
||||||
self.rem_chunk_tokens -= mixed_with_decode_tokens
|
self.rem_chunk_tokens -= mixed_with_decode_tokens
|
||||||
|
|
||||||
self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
|
self.rem_total_token_offset = mixed_with_decode_tokens
|
||||||
|
self.cur_rem_token_offset = mixed_with_decode_tokens
|
||||||
|
|
||||||
self.req_states = None
|
self.req_states = None
|
||||||
self.can_run_list = []
|
self.can_run_list = []
|
||||||
@@ -275,8 +277,7 @@ class PrefillAdder:
|
|||||||
self.log_input_tokens = 0
|
self.log_input_tokens = 0
|
||||||
|
|
||||||
if running_batch is not None:
|
if running_batch is not None:
|
||||||
# Pre-remove the tokens which will be occupied by the running requests
|
self.rem_total_token_offset += sum(
|
||||||
self.rem_total_tokens -= sum(
|
|
||||||
[
|
[
|
||||||
min(
|
min(
|
||||||
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
||||||
@@ -287,6 +288,22 @@ class PrefillAdder:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rem_total_tokens(self):
|
||||||
|
return (
|
||||||
|
self.token_to_kv_pool.available_size()
|
||||||
|
+ self.tree_cache.evictable_size()
|
||||||
|
- self.rem_total_token_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cur_rem_tokens(self):
|
||||||
|
return (
|
||||||
|
self.token_to_kv_pool.available_size()
|
||||||
|
+ self.tree_cache.evictable_size()
|
||||||
|
- self.cur_rem_token_offset
|
||||||
|
)
|
||||||
|
|
||||||
def budget_state(self):
|
def budget_state(self):
|
||||||
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
|
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
|
||||||
return AddReqResult.NO_TOKEN
|
return AddReqResult.NO_TOKEN
|
||||||
@@ -301,8 +318,8 @@ class PrefillAdder:
|
|||||||
def _prefill_one_req(
|
def _prefill_one_req(
|
||||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||||
):
|
):
|
||||||
self.rem_total_tokens -= extend_input_len + max_new_tokens
|
self.rem_total_token_offset += extend_input_len + max_new_tokens
|
||||||
self.cur_rem_tokens -= extend_input_len
|
self.cur_rem_token_offset += extend_input_len
|
||||||
self.rem_input_tokens -= extend_input_len
|
self.rem_input_tokens -= extend_input_len
|
||||||
if self.rem_chunk_tokens is not None:
|
if self.rem_chunk_tokens is not None:
|
||||||
self.rem_chunk_tokens -= extend_input_len
|
self.rem_chunk_tokens -= extend_input_len
|
||||||
@@ -332,12 +349,10 @@ class PrefillAdder:
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def _lock_node(self, last_node: TreeNode):
|
def _lock_node(self, last_node: TreeNode):
|
||||||
try:
|
try:
|
||||||
delta = self.tree_cache.inc_lock_ref(last_node)
|
self.tree_cache.inc_lock_ref(last_node)
|
||||||
self.rem_total_tokens += delta
|
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
delta = self.tree_cache.dec_lock_ref(last_node)
|
self.tree_cache.dec_lock_ref(last_node)
|
||||||
self.rem_total_tokens += delta
|
|
||||||
|
|
||||||
def add_one_req_ignore_eos(self, req: Req):
|
def add_one_req_ignore_eos(self, req: Req):
|
||||||
def add_req_state(r, insert_sort=False):
|
def add_req_state(r, insert_sort=False):
|
||||||
|
|||||||
@@ -891,9 +891,9 @@ class Scheduler:
|
|||||||
# Prefill policy
|
# Prefill policy
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
|
self.token_to_kv_pool,
|
||||||
self.running_batch,
|
self.running_batch,
|
||||||
self.new_token_ratio,
|
self.new_token_ratio,
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
|
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
self.chunked_prefill_size,
|
self.chunked_prefill_size,
|
||||||
running_bs if self.is_mixed_chunk else 0,
|
running_bs if self.is_mixed_chunk else 0,
|
||||||
|
|||||||
Reference in New Issue
Block a user