Refine the add request reasons to avoid corner cases. (#1574)
This commit is contained in:
@@ -19,6 +19,7 @@ import os
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
@@ -104,6 +105,12 @@ class SchedulePolicy:
|
||||
q.extend(last_node_to_reqs[cur_node])
|
||||
|
||||
|
||||
class AddReqResult(Enum):
|
||||
CONTINUE = auto() # Continue to add requests
|
||||
NO_TOKEN = auto() # No token left
|
||||
OTHER = auto() # Other reasons to stop adding requests
|
||||
|
||||
|
||||
class PrefillAdder:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -145,17 +152,16 @@ class PrefillAdder:
|
||||
]
|
||||
)
|
||||
|
||||
def no_remaining_tokens(self):
|
||||
return (
|
||||
self.rem_total_tokens <= 0
|
||||
or self.rem_input_tokens <= 0
|
||||
or (
|
||||
self.rem_chunk_tokens <= 0
|
||||
if self.rem_chunk_tokens is not None
|
||||
else False
|
||||
)
|
||||
or self.cur_rem_tokens <= 0
|
||||
)
|
||||
def budget_state(self):
|
||||
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if self.rem_input_tokens <= 0 or (
|
||||
self.rem_chunk_tokens is not None and self.rem_chunk_tokens <= 0
|
||||
):
|
||||
return AddReqResult.OTHER
|
||||
|
||||
return AddReqResult.CONTINUE
|
||||
|
||||
def _prefill_one_req(
|
||||
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
|
||||
@@ -239,7 +245,7 @@ class PrefillAdder:
|
||||
)
|
||||
bs = len(self.req_states) - i
|
||||
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
|
||||
return False
|
||||
return AddReqResult.NO_TOKEN
|
||||
tokens_freed += tokens_occupied
|
||||
|
||||
if req.extend_input_len <= self.rem_chunk_tokens:
|
||||
@@ -258,7 +264,7 @@ class PrefillAdder:
|
||||
self.new_inflight_req = req
|
||||
self._prefill_one_req(0, trunc_len, 0)
|
||||
|
||||
return True
|
||||
return self.budget_state()
|
||||
|
||||
def add_one_req(self, req: Req):
|
||||
if req.sampling_params.ignore_eos and self.tree_cache.disable:
|
||||
@@ -271,14 +277,14 @@ class PrefillAdder:
|
||||
prefix_len = len(req.prefix_indices)
|
||||
|
||||
if total_tokens >= self.rem_total_tokens:
|
||||
return False
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
|
||||
return False
|
||||
return AddReqResult.OTHER
|
||||
|
||||
with self._lock_node(req.last_node):
|
||||
if total_tokens > self.rem_total_tokens:
|
||||
return False
|
||||
return AddReqResult.NO_TOKEN
|
||||
|
||||
if (
|
||||
self.rem_chunk_tokens is None
|
||||
@@ -297,7 +303,7 @@ class PrefillAdder:
|
||||
# Chunked prefill
|
||||
trunc_len = self.rem_chunk_tokens
|
||||
if trunc_len == 0:
|
||||
return False
|
||||
return AddReqResult.OTHER
|
||||
|
||||
req.extend_input_len = trunc_len
|
||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||
@@ -306,4 +312,4 @@ class PrefillAdder:
|
||||
self.tree_cache.inc_lock_ref(req.last_node)
|
||||
self._prefill_one_req(prefix_len, trunc_len, 0)
|
||||
|
||||
return True and not self.no_remaining_tokens()
|
||||
return self.budget_state()
|
||||
|
||||
@@ -50,7 +50,11 @@ from sglang.srt.managers.schedule_batch import (
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
)
|
||||
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
|
||||
from sglang.srt.managers.schedule_policy import (
|
||||
AddReqResult,
|
||||
PrefillAdder,
|
||||
SchedulePolicy,
|
||||
)
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
@@ -493,16 +497,15 @@ class Scheduler:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
|
||||
if adder.no_remaining_tokens():
|
||||
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
|
||||
req.init_next_round_input(None if prefix_computed else self.tree_cache)
|
||||
res = adder.add_one_req(req)
|
||||
if (
|
||||
not res
|
||||
or running_bs + len(adder.can_run_list) >= self.max_running_requests
|
||||
):
|
||||
self.batch_is_full = True
|
||||
if res != AddReqResult.CONTINUE:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
self.batch_is_full = True
|
||||
break
|
||||
|
||||
can_run_list = adder.can_run_list
|
||||
|
||||
Reference in New Issue
Block a user