Support page size > 1 (#4356)

Lianmin Zheng
2025-03-12 22:22:39 -07:00
committed by GitHub
parent 2f6bacee03
commit c76040e31b
23 changed files with 877 additions and 284 deletions

@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
crash_on_warnings,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
pyspy_dump_schedulers,
set_gpu_proc_affinity,
set_random_seed,
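For context, kill_itself_when_parent_died (per its name) makes the scheduler subprocess exit when its parent process dies, so a crashed launcher does not leave orphaned GPU workers. A hedged sketch of how such a helper is commonly implemented on Linux; this is an assumption about the utility, not the verbatim sglang code:

import ctypes
import signal

def kill_itself_when_parent_died():
    # Linux-only sketch: ask the kernel to deliver SIGKILL to this process
    # when its parent terminates (prctl PR_SET_PDEATHSIG).
    PR_SET_PDEATHSIG = 1  # from <sys/prctl.h>
    libc = ctypes.CDLL("libc.so.6", use_errno=True)
    libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)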
@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size
# Distributed rank info
self.dp_size = server_args.dp_size
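The scheduler now records server_args.page_size. With page_size > 1, KV-cache slots are handed out in fixed-size pages rather than one token at a time, so capacity math has to round token counts up to page boundaries. A minimal illustration of that rounding; the helper name is hypothetical:

def tokens_to_pages(num_tokens: int, page_size: int) -> int:
    # Ceiling division: a partially filled page still occupies a whole page.
    return -(-num_tokens // page_size)

assert tokens_to_pages(16, 16) == 1
assert tokens_to_pages(17, 16) == 2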
@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"context_len={self.model_config.context_len}"
)
# Init memory pool and cache
self.init_memory_pool_and_cache()
# Init running status
self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
# The current forward batch
self.cur_batch: Optional[ScheduleBatch] = None
# The current forward batch
# The last forward batch
self.last_batch: Optional[ScheduleBatch] = None
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.num_prefill_tokens = 0
self.last_decode_stats_tic = time.time()
self.last_prefill_stats_tic = time.time()
self.return_health_check_ct = 0
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
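A recurring theme in this diff: running_batch changes from Optional[ScheduleBatch] to an always-present, possibly empty batch, so the many `is None` checks below become `is_empty()` calls, and per-scheduler state such as batch_is_full can live on the batch itself. A simplified sketch of the pattern; the real ScheduleBatch carries far more state:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Batch:  # toy stand-in for ScheduleBatch
    reqs: List[object] = field(default_factory=list)
    batch_is_full: bool = False

    def is_empty(self) -> bool:
        return len(self.reqs) == 0

running_batch = Batch()                      # never None
num_running_reqs = len(running_batch.reqs)   # no `if running_batch else 0` guard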
@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Init schedule policy and new token estimation
self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
)
assert (
server_args.schedule_conservativeness >= 0
@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio
# Tell whether the current running batch is full so that we can skip
# the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check.
self.batch_is_full = False
# Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
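Passing page_size into RadixCache suggests cached prefixes can only be reused in whole-page units, since a partially filled page at the tail of a match cannot be shared. A plausible sketch of the truncation involved; this is an assumption about the internals, not the exact sglang code:

def page_aligned_match_len(matched_tokens: int, page_size: int) -> int:
    # Round the prefix match down to a page boundary.
    return (matched_tokens // page_size) * page_size

assert page_aligned_match_len(37, 16) == 32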
@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# When the server is idle, so self-check and re-init some states
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, so self-check and re-init some states
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and (
self.chunked_req is not None or self.running_batch is not None
self.chunked_req is not None or not self.running_batch.is_empty()
):
self.return_health_check_ct += 1
continue
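The same invariant shows up here: "is the server busy?" becomes an emptiness check instead of a None comparison. Distilled into a standalone predicate (hypothetical helper, for illustration only):

def server_is_busy(chunked_req, running_batch) -> bool:
    # Health-check generation requests are acknowledged without being
    # scheduled whenever real work is in flight.
    return chunked_req is not None or not running_batch.is_empty()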
@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.time() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.time()
self.last_input_throughput = self.num_prefill_tokens / gap_latency
self.num_prefill_tokens = 0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
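These lines add the input-side twin of the existing decode throughput metric: tokens prefilled since the last prefill batch, divided by the wall-clock gap between batches. Isolated as a sketch:

import time

class PrefillThroughput:
    def __init__(self):
        self.last_prefill_stats_tic = time.time()
        self.num_prefill_tokens = 0      # incremented as prefill tokens are scheduled
        self.last_input_throughput = 0.0

    def on_prefill_batch(self):
        gap_latency = time.time() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.time()
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0      # reset the window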
@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.last_decode_stats_tic = time.time()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_running_reqs = len(self.running_batch.reqs)
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
if memory_leak:
msg = (
"KV cache pool leak detected!"
"KV cache pool leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
)
warnings.warn(msg)
if crash_on_warnings():
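The expanded message spells out the quantities behind the leak check. Roughly, every KV slot should be either free in the allocator, evictable in the tree cache, or protected (in active use); anything else indicates a leak. As a sketch of that accounting, inferred from the variable names rather than quoted from the source:

def detect_kv_leak(allocator_free: int, evictable: int,
                   protected_size: int, max_total_num_tokens: int) -> bool:
    # Free + evictable + protected should cover the whole pool exactly.
    available_size = allocator_free + evictable
    return available_size + protected_size != max_total_num_tokens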
@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_running_reqs = len(self.running_batch.reqs)
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.tree_cache.cache_unfinished_req(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.batch_is_full = False
self.running_batch.batch_is_full = False
# Filter batch
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch()
if self.last_batch.batch_size() < last_bs:
self.batch_is_full = False
self.running_batch.batch_is_full = False
# Merge the new batch into the running batch
if not self.last_batch.is_empty():
if self.running_batch is None:
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:
# merge running_batch with prefill batch
# Merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch)
new_batch = self.get_new_batch_prefill()
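With the persistent batch, the prefill-to-decode handoff is uniform: adopt the finished prefill batch if the running batch is empty, otherwise merge into it. A toy run using the simplified Batch from the earlier sketch, with merge_batch assumed to just extend reqs (the real merge also concatenates device tensors):

def merge_batch(dst: Batch, src: Batch) -> None:
    dst.reqs.extend(src.reqs)

running_batch = Batch()
last_batch = Batch(reqs=["req-a", "req-b"])

if not last_batch.is_empty():
    if running_batch.is_empty():
        running_batch = last_batch          # adopt the prefill batch wholesale
    else:
        merge_batch(running_batch, last_batch)

assert len(running_batch.reqs) == 2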
@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
ret = new_batch
else:
# Run decode
if self.running_batch is None:
ret = None
else:
if not self.running_batch.is_empty():
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch
ret = self.running_batch if not self.running_batch.is_empty() else None
else:
ret = None
# Handle DP attention
if self.server_args.enable_dp_attention:
@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
) and self.chunked_req is None:
return None
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
running_bs = len(self.running_batch.reqs)
if running_bs >= self.max_running_requests:
self.batch_is_full = True
self.running_batch.batch_is_full = True
return None
if self.enable_hierarchical_cache:
@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
running_bs if self.is_mixed_chunk else 0,
)
is_chunked = self.chunked_req is not None
if is_chunked:
if self.chunked_req is not None:
self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths:
lora_set = (
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)
lora_set = set([req.lora_path for req in self.running_batch.reqs])
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
> self.max_loras_per_batch
):
self.batch_is_full = True
self.running_batch.batch_is_full = True
break
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
self.batch_is_full = True
self.running_batch.batch_is_full = True
break
req.init_next_round_input(
@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin):
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served
self.batch_is_full = len(adder.can_run_list) > 0 or (
self.running_batch.batch_is_full = len(
adder.can_run_list
) > 0 or (
self.running_batch is not None
and not self.running_batch.is_empty()
)
else:
self.batch_is_full = True
self.running_batch.batch_is_full = True
break
# Update waiting queue
@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Mixed-style chunked prefill
if (
self.is_mixed_chunk
and self.running_batch is not None
and not self.running_batch.is_empty()
and not (new_batch.return_logprob or self.running_batch.return_logprob)
):
# TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.running_batch.prepare_for_decode()
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None
self.running_batch = ScheduleBatch(
reqs=[], batch_is_full=self.running_batch.batch_is_full
)
else:
new_batch.decoding_reqs = None
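When mixed chunked prefill absorbs the running batch, the batch is reset to an empty one rather than to None, and batch_is_full is carried over so the admission-control shortcut survives the reset. In terms of the toy Batch above:

# The running batch's requests moved into new_batch; keep the flag.
running_batch = Batch(reqs=[], batch_is_full=running_batch.batch_is_full)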
@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
batch.filter_batch()
if batch.is_empty():
self.batch_is_full = False
return None
batch.batch_is_full = False
return batch
# Check if decode out of memory
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
if batch.batch_size() < initial_bs:
self.batch_is_full = False
batch.batch_is_full = False
# Update batch tensors
batch.prepare_for_decode()
@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset()
@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
logging.warning(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.waiting_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
f"#running-req: {len(self.running_batch.reqs)}"
)
if_success = False
return if_success
@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin):
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = None
to_del = []
for i, req in enumerate(self.waiting_queue):
if req.rid == recv_req.rid:
to_del = i
if req.rid.startswith(recv_req.rid):
to_del.append(i)
break
if to_del is not None:
del self.waiting_queue[to_del]
# Sort in reverse order to avoid index issues when deleting
for i in sorted(to_del, reverse=True):
req = self.waiting_queue.pop(i)
logger.debug(f"Abort queued request. {req.rid=}")
return
# Delete requests in the running batch
if self.running_batch:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid and not req.finished():
logger.debug(f"Abort running request. {req.rid=}")
req.to_abort = True
break
for req in self.running_batch.reqs:
if req.rid.startswith(recv_req.rid) and not req.finished():
logger.debug(f"Abort running request. {req.rid=}")
req.to_abort = True
return
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
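abort_request also switches from rid equality to rid prefix matching (one abort can now cover a family of related requests whose ids share a prefix) and collects every match instead of stopping at the first. Deleting the collected indices in reverse keeps the not-yet-deleted indices valid. A self-contained demonstration of the deletion pattern:

waiting_queue = ["req-1-a", "req-1-b", "req-2", "req-1-c"]
to_del = [i for i, rid in enumerate(waiting_queue) if rid.startswith("req-1")]

# Pop from the back first: removing index 3 does not shift indices 0 and 1.
for i in sorted(to_del, reverse=True):
    waiting_queue.pop(i)

assert waiting_queue == ["req-2"]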