Support page size > 1 (#4356)
This commit is contained in:
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
|
||||
crash_on_warnings,
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
pyspy_dump_schedulers,
|
||||
set_gpu_proc_affinity,
|
||||
set_random_seed,
|
||||
@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
self.gpu_id = gpu_id
|
||||
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
|
||||
self.page_size = server_args.page_size
|
||||
|
||||
# Distributed rank info
|
||||
self.dp_size = server_args.dp_size
|
||||
@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
f"context_len={self.model_config.context_len}"
|
||||
)
|
||||
|
||||
# Init memory pool and cache
|
||||
self.init_memory_pool_and_cache()
|
||||
|
||||
# Init running status
|
||||
self.waiting_queue: List[Req] = []
|
||||
# The running decoding batch for continuous batching
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
|
||||
# The current forward batch
|
||||
self.cur_batch: Optional[ScheduleBatch] = None
|
||||
# The current forward batch
|
||||
# The last forward batch
|
||||
self.last_batch: Optional[ScheduleBatch] = None
|
||||
self.forward_ct = 0
|
||||
self.forward_ct_decode = 0
|
||||
self.num_generated_tokens = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.last_decode_stats_tic = time.time()
|
||||
self.last_prefill_stats_tic = time.time()
|
||||
self.return_health_check_ct = 0
|
||||
self.current_stream = torch.get_device_module(self.device).current_stream()
|
||||
if self.device == "cpu":
|
||||
@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
# Init schedule policy and new token estimation
|
||||
self.policy = SchedulePolicy(
|
||||
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
|
||||
self.schedule_policy,
|
||||
self.tree_cache,
|
||||
self.enable_hierarchical_cache,
|
||||
)
|
||||
assert (
|
||||
server_args.schedule_conservativeness >= 0
|
||||
@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
) / global_config.default_new_token_ratio_decay_steps
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
# Tell whether the current running batch is full so that we can skip
|
||||
# the check of whether to prefill new requests.
|
||||
# This is an optimization to reduce the overhead of the prefill check.
|
||||
self.batch_is_full = False
|
||||
|
||||
# Init watchdog thread
|
||||
self.watchdog_timeout = server_args.watchdog_timeout
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
page_size=self.page_size,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
|
||||
@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
# The largest context length (prefill + generation) of a single request
|
||||
self._largest_prefill_decode_len: int = 0
|
||||
self.last_gen_throughput: float = 0.0
|
||||
self.last_input_throughput: float = 0.0
|
||||
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
||||
self.spec_num_total_accepted_tokens = 0
|
||||
self.spec_num_total_forward_ct = 0
|
||||
@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
else:
|
||||
# When the server is idle, so self-check and re-init some states
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
elif batch is None:
|
||||
# When the server is idle, so self-check and re-init some states
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
for recv_req in recv_reqs:
|
||||
# If it is a health check generation request and there are running requests, ignore it.
|
||||
if is_health_check_generate_req(recv_req) and (
|
||||
self.chunked_req is not None or self.running_batch is not None
|
||||
self.chunked_req is not None or not self.running_batch.is_empty()
|
||||
):
|
||||
self.return_health_check_ct += 1
|
||||
continue
|
||||
@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
can_run_list: List[Req],
|
||||
running_bs: int,
|
||||
):
|
||||
gap_latency = time.time() - self.last_prefill_stats_tic
|
||||
self.last_prefill_stats_tic = time.time()
|
||||
self.last_input_throughput = self.num_prefill_tokens / gap_latency
|
||||
self.num_prefill_tokens = 0
|
||||
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.last_decode_stats_tic = time.time()
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
if memory_leak:
|
||||
msg = (
|
||||
"KV cache pool leak detected!"
|
||||
"KV cache pool leak detected! "
|
||||
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
|
||||
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
|
||||
f"{self.tree_cache.evictable_size()=}\n"
|
||||
)
|
||||
warnings.warn(msg)
|
||||
if crash_on_warnings():
|
||||
@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
)
|
||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
self.stats.num_running_reqs = num_running_reqs
|
||||
self.stats.num_used_tokens = num_used
|
||||
self.stats.token_usage = num_used / self.max_total_num_tokens
|
||||
@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||
self.batch_is_full = False
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
# Filter batch
|
||||
last_bs = self.last_batch.batch_size()
|
||||
self.last_batch.filter_batch()
|
||||
if self.last_batch.batch_size() < last_bs:
|
||||
self.batch_is_full = False
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
# Merge the new batch into the running batch
|
||||
if not self.last_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = self.last_batch
|
||||
else:
|
||||
# merge running_batch with prefill batch
|
||||
# Merge running_batch with prefill batch
|
||||
self.running_batch.merge_batch(self.last_batch)
|
||||
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
ret = new_batch
|
||||
else:
|
||||
# Run decode
|
||||
if self.running_batch is None:
|
||||
ret = None
|
||||
else:
|
||||
if not self.running_batch.is_empty():
|
||||
self.running_batch = self.update_running_batch(self.running_batch)
|
||||
ret = self.running_batch
|
||||
ret = self.running_batch if not self.running_batch.is_empty() else None
|
||||
else:
|
||||
ret = None
|
||||
|
||||
# Handle DP attention
|
||||
if self.server_args.enable_dp_attention:
|
||||
@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
# Handle the cases where prefill is not allowed
|
||||
if (
|
||||
self.batch_is_full or len(self.waiting_queue) == 0
|
||||
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
|
||||
) and self.chunked_req is None:
|
||||
return None
|
||||
|
||||
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||
running_bs = len(self.running_batch.reqs)
|
||||
if running_bs >= self.max_running_requests:
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
return None
|
||||
|
||||
if self.enable_hierarchical_cache:
|
||||
@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
running_bs if self.is_mixed_chunk else 0,
|
||||
)
|
||||
|
||||
is_chunked = self.chunked_req is not None
|
||||
if is_chunked:
|
||||
if self.chunked_req is not None:
|
||||
self.chunked_req.init_next_round_input()
|
||||
self.chunked_req = adder.add_chunked_req(self.chunked_req)
|
||||
|
||||
if self.lora_paths:
|
||||
lora_set = (
|
||||
set([req.lora_path for req in self.running_batch.reqs])
|
||||
if self.running_batch is not None
|
||||
else set([])
|
||||
)
|
||||
lora_set = set([req.lora_path for req in self.running_batch.reqs])
|
||||
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
if (
|
||||
@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
> self.max_loras_per_batch
|
||||
):
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
req.init_next_round_input(
|
||||
@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
if self.enable_hierarchical_cache:
|
||||
# Set batch_is_full after making sure there are requests that can be served
|
||||
self.batch_is_full = len(adder.can_run_list) > 0 or (
|
||||
self.running_batch.batch_is_full = len(
|
||||
adder.can_run_list
|
||||
) > 0 or (
|
||||
self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
)
|
||||
else:
|
||||
self.batch_is_full = True
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
# Update waiting queue
|
||||
@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
# Mixed-style chunked prefill
|
||||
if (
|
||||
self.is_mixed_chunk
|
||||
and self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
and not (new_batch.return_logprob or self.running_batch.return_logprob)
|
||||
):
|
||||
# TODO (lianmin): support return_logprob + mixed chunked prefill
|
||||
@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.running_batch.prepare_for_decode()
|
||||
new_batch.mix_with_running(self.running_batch)
|
||||
new_batch.decoding_reqs = self.running_batch.reqs
|
||||
self.running_batch = None
|
||||
self.running_batch = ScheduleBatch(
|
||||
reqs=[], batch_is_full=self.running_batch.batch_is_full
|
||||
)
|
||||
else:
|
||||
new_batch.decoding_reqs = None
|
||||
|
||||
@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
batch.filter_batch()
|
||||
if batch.is_empty():
|
||||
self.batch_is_full = False
|
||||
return None
|
||||
batch.batch_is_full = False
|
||||
return batch
|
||||
|
||||
# Check if decode out of memory
|
||||
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
|
||||
@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
)
|
||||
|
||||
if batch.batch_size() < initial_bs:
|
||||
self.batch_is_full = False
|
||||
batch.batch_is_full = False
|
||||
|
||||
# Update batch tensors
|
||||
batch.prepare_for_decode()
|
||||
@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
self.running_batch = None
|
||||
elif batch.forward_mode.is_extend():
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
elif batch.forward_mode.is_idle():
|
||||
@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
):
|
||||
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
|
||||
self.cur_batch = None
|
||||
self.last_batch = None
|
||||
self.tree_cache.reset()
|
||||
@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
logging.warning(
|
||||
f"Cache not flushed because there are pending requests. "
|
||||
f"#queue-req: {len(self.waiting_queue)}, "
|
||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||
f"#running-req: {len(self.running_batch.reqs)}"
|
||||
)
|
||||
if_success = False
|
||||
return if_success
|
||||
@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
def abort_request(self, recv_req: AbortReq):
|
||||
# Delete requests in the waiting queue
|
||||
to_del = None
|
||||
to_del = []
|
||||
for i, req in enumerate(self.waiting_queue):
|
||||
if req.rid == recv_req.rid:
|
||||
to_del = i
|
||||
if req.rid.startswith(recv_req.rid):
|
||||
to_del.append(i)
|
||||
break
|
||||
|
||||
if to_del is not None:
|
||||
del self.waiting_queue[to_del]
|
||||
# Sort in reverse order to avoid index issues when deleting
|
||||
for i in sorted(to_del, reverse=True):
|
||||
req = self.waiting_queue.pop(i)
|
||||
logger.debug(f"Abort queued request. {req.rid=}")
|
||||
return
|
||||
|
||||
# Delete requests in the running batch
|
||||
if self.running_batch:
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid == recv_req.rid and not req.finished():
|
||||
logger.debug(f"Abort running request. {req.rid=}")
|
||||
req.to_abort = True
|
||||
break
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
||||
logger.debug(f"Abort running request. {req.rid=}")
|
||||
req.to_abort = True
|
||||
return
|
||||
|
||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
Reference in New Issue
Block a user