Refine the add request reasons to avoid corner cases. (#1574)
This commit is contained in:
@@ -19,6 +19,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from enum import Enum, auto
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||||
@@ -104,6 +105,12 @@ class SchedulePolicy:
|
|||||||
q.extend(last_node_to_reqs[cur_node])
|
q.extend(last_node_to_reqs[cur_node])
|
||||||
|
|
||||||
|
|
||||||
|
class AddReqResult(Enum):
|
||||||
|
CONTINUE = auto() # Continue to add requests
|
||||||
|
NO_TOKEN = auto() # No token left
|
||||||
|
OTHER = auto() # Other reasons to stop adding requests
|
||||||
|
|
||||||
|
|
||||||
class PrefillAdder:
|
class PrefillAdder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -145,17 +152,16 @@ class PrefillAdder:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def no_remaining_tokens(self):
|
def budget_state(self):
|
||||||
return (
|
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
|
||||||
self.rem_total_tokens <= 0
|
return AddReqResult.NO_TOKEN
|
||||||
or self.rem_input_tokens <= 0
|
|
||||||
or (
|
if self.rem_input_tokens <= 0 or (
|
||||||
self.rem_chunk_tokens <= 0
|
self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
|
||||||
if self.rem_chunk_tokens is not None
|
):
|
||||||
else False
|
return AddReqResult.OTHER
|
||||||
)
|
|
||||||
or self.cur_rem_tokens <= 0
|
return AddReqResult.CONTINUE
|
||||||
)
|
|
||||||
|
|
||||||
def _prefill_one_req(
|
def _prefill_one_req(
|
||||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||||
@@ -239,7 +245,7 @@ class PrefillAdder:
|
|||||||
)
|
)
|
||||||
bs = len(self.req_states) - i
|
bs = len(self.req_states) - i
|
||||||
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
|
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
|
||||||
return False
|
return AddReqResult.NO_TOKEN
|
||||||
tokens_freed += tokens_occupied
|
tokens_freed += tokens_occupied
|
||||||
|
|
||||||
if req.extend_input_len <= self.rem_chunk_tokens:
|
if req.extend_input_len <= self.rem_chunk_tokens:
|
||||||
@@ -258,7 +264,7 @@ class PrefillAdder:
|
|||||||
self.new_inflight_req = req
|
self.new_inflight_req = req
|
||||||
self._prefill_one_req(0, trunc_len, 0)
|
self._prefill_one_req(0, trunc_len, 0)
|
||||||
|
|
||||||
return True
|
return self.budget_state()
|
||||||
|
|
||||||
def add_one_req(self, req: Req):
|
def add_one_req(self, req: Req):
|
||||||
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||||
@@ -271,14 +277,14 @@ class PrefillAdder:
|
|||||||
prefix_len = len(req.prefix_indices)
|
prefix_len = len(req.prefix_indices)
|
||||||
|
|
||||||
if total_tokens >= self.rem_total_tokens:
|
if total_tokens >= self.rem_total_tokens:
|
||||||
return False
|
return AddReqResult.NO_TOKEN
|
||||||
|
|
||||||
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
|
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||||
return False
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
with self._lock_node(req.last_node):
|
with self._lock_node(req.last_node):
|
||||||
if total_tokens > self.rem_total_tokens:
|
if total_tokens > self.rem_total_tokens:
|
||||||
return False
|
return AddReqResult.NO_TOKEN
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.rem_chunk_tokens is None
|
self.rem_chunk_tokens is None
|
||||||
@@ -297,7 +303,7 @@ class PrefillAdder:
|
|||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
trunc_len = self.rem_chunk_tokens
|
trunc_len = self.rem_chunk_tokens
|
||||||
if trunc_len == 0:
|
if trunc_len == 0:
|
||||||
return False
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
@@ -306,4 +312,4 @@ class PrefillAdder:
|
|||||||
self.tree_cache.inc_lock_ref(req.last_node)
|
self.tree_cache.inc_lock_ref(req.last_node)
|
||||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||||
|
|
||||||
return True and not self.no_remaining_tokens()
|
return self.budget_state()
|
||||||
|
|||||||
@@ -50,7 +50,11 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
Req,
|
Req,
|
||||||
ScheduleBatch,
|
ScheduleBatch,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
|
from sglang.srt.managers.schedule_policy import (
|
||||||
|
AddReqResult,
|
||||||
|
PrefillAdder,
|
||||||
|
SchedulePolicy,
|
||||||
|
)
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||||
@@ -493,16 +497,15 @@ class Scheduler:
|
|||||||
self.batch_is_full = True
|
self.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if adder.no_remaining_tokens():
|
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
|
||||||
self.batch_is_full = True
|
self.batch_is_full = True
|
||||||
break
|
break
|
||||||
|
|
||||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||||
res = adder.add_one_req(req)
|
res = adder.add_one_req(req)
|
||||||
if (
|
if res != AddReqResult.CONTINUE:
|
||||||
not res
|
if res == AddReqResult.NO_TOKEN:
|
||||||
or running_bs + len(adder.can_run_list) >= self.max_running_requests
|
self.batch_is_full = True
|
||||||
):
|
|
||||||
self.batch_is_full = True
|
|
||||||
break
|
break
|
||||||
|
|
||||||
can_run_list = adder.can_run_list
|
can_run_list = adder.can_run_list
|
||||||
|
|||||||
Reference in New Issue
Block a user