Add a new event loop (#1677)
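Overview (from the diff below): this commit adds an experimental overlap schedule to the SRT scheduler. A second event loop, event_loop_overlap, launches the next batch on the GPU before the previous batch's results are processed on the CPU. Supporting changes: penalizer accumulation moves into decode preparation; cache_finished_req on both ChunkCache and RadixCache gains a free_delta parameter to release the extra KV slot claimed by an in-flight decode step; a new --enable-overlap-schedule server flag; and an MMLU-based regression test.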
@@ -736,6 +736,10 @@ class ScheduleBatch:
         self.input_ids = self.output_ids
         self.seq_lens.add_(1)
         self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )

         # Alloc mem
         bs = len(self.reqs)
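With the overlap loop, penalizer accumulation moves out of result processing (see the deletions in process_batch_result_prefill and process_batch_result_decode further down) and into decode preparation above, where the previous step's sampled tokens have just become self.input_ids. A minimal sketch of what this kind of accumulation does, using an illustrative frequency-penalty counter rather than sglang's real penalizer_orchestrator API:

from collections import Counter
from typing import List

class FreqPenalizer:
    def __init__(self, penalty: float = 0.1):
        self.penalty = penalty
        self.counts = Counter()  # token id -> times sampled so far

    def cumulate_output_tokens(self, token_ids: List[int]) -> None:
        # Called once per decode step with the batch's freshly sampled tokens.
        self.counts.update(token_ids)

    def apply(self, token_id: int, logit: float) -> float:
        # Subtract a penalty proportional to how often the token appeared.
        return logit - self.penalty * self.counts[token_id]

pen = FreqPenalizer()
pen.cumulate_output_tokens([5, 7, 5])
assert pen.apply(5, 0.0) == -0.2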
@@ -20,6 +20,7 @@ import logging
 import os
 import time
 import warnings
+from collections import deque
 from types import SimpleNamespace
 from typing import List, Optional, Union
@@ -192,9 +193,20 @@ class Scheduler:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

+        if self.server_args.enable_overlap_schedule:
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+        else:
+            cache_finished_req = self.tree_cache.cache_finished_req
+        self.cache_finished_req = cache_finished_req
+
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
+        self.cur_batch: Optional[ScheduleBatch] = None
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
         self.num_generated_tokens = 0
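The wrapper above exists because, under the overlap schedule, a request can be detected as finished while the next decode step for it is already in flight on the GPU; that in-flight step has claimed one extra slot in req_to_token that must also be freed. free_delta is 1 exactly when the finished request belongs to the batch currently executing (cur_batch). A toy model of that bookkeeping, with illustrative names only:

# Toy model of the free_delta decision, assuming `n_tokens` KV slots are
# attributed to the request at finish time.
def slots_to_free(n_tokens: int, req_in_current_batch: bool) -> int:
    # Mirrors free_delta = int(self.running_batch and req in self.cur_batch.reqs):
    # the in-flight decode step allocated one extra slot for this request.
    free_delta = 1 if req_in_current_batch else 0
    return n_tokens + free_delta

assert slots_to_free(10, req_in_current_batch=False) == 10
assert slots_to_free(10, req_in_current_batch=True) == 11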
@@ -279,6 +291,32 @@ class Scheduler:

             self.last_batch = batch

+    @torch.inference_mode()
+    def event_loop_overlap(self):
+        result_queue = deque()
+
+        self.last_batch = None
+        self.running_batch = None
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+
+            batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
+            if batch:
+                result = self.run_batch(batch)
+                result_queue.append((batch.copy(), result))
+
+            if self.last_batch:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+            elif batch is None:
+                self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self):
         if self.tp_rank == 0:
             recv_reqs = []
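event_loop_overlap is a one-step pipeline: run_batch only launches GPU work, so while batch N executes, the CPU processes the result of batch N-1 that is still waiting in result_queue. A self-contained sketch of the same pattern, with strings standing in for run_batch and process_batch_result (the real loop additionally syncs tensors and handles memory checks):

from collections import deque

def overlap_loop(batches):
    """One-step-delayed result handling, mirroring event_loop_overlap.

    A formatted string stands in for run_batch (which launches GPU work
    and returns without waiting) and the print stands in for
    process_batch_result (CPU-side bookkeeping for the *previous* batch).
    """
    result_queue = deque()
    last_batch = None
    for batch in batches:
        if batch is not None:
            result = f"result-of-{batch}"  # launch step N, don't wait
            result_queue.append((batch, result))
        if last_batch is not None:
            done_batch, done_result = result_queue.popleft()
            print(f"process {done_result} while {batch} runs")  # step N-1
        last_batch = batch

overlap_loop(["b0", "b1", "b2", None])  # trailing None drains the queue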
@@ -705,11 +743,6 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids = result
-            if batch.sampling_info.penalizer_orchestrator:
-                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                    next_token_ids
-                )
-
             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
@@ -742,7 +775,7 @@ class Scheduler:
                 req.check_finished()

                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
@@ -771,7 +804,7 @@ class Scheduler:
                 req.check_finished()

                 if req.finished():
-                    self.tree_cache.cache_finished_req(req)
+                    self.cache_finished_req(req)
                 else:
                     self.tree_cache.cache_unfinished_req(req)
@@ -779,10 +812,6 @@ class Scheduler:

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
-        if batch.sampling_info.penalizer_orchestrator:
-            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
-                next_token_ids
-            )
         self.num_generated_tokens += len(batch.reqs)

         # Move logprobs to cpu
@@ -796,6 +825,9 @@ class Scheduler:

         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if self.server_args.enable_overlap_schedule and req.finished():
+                continue
+
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
             req.check_finished()
@@ -806,7 +838,7 @@ class Scheduler:
            )

            if req.finished():
-               self.tree_cache.cache_finished_req(req)
+               self.cache_finished_req(req)

            if req.return_logprob:
                req.output_token_logprobs.append(
@@ -1027,7 +1059,7 @@ class Scheduler:
            for req in self.running_batch.reqs:
                if req.rid == recv_req.rid and not req.finished():
                    req.finished_reason = FINISH_ABORT()
-                   self.tree_cache.cache_finished_req(req)
+                   self.cache_finished_req(req)
                    break

    def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1072,7 +1104,10 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop_normal()
+        if server_args.enable_overlap_schedule:
+            scheduler.event_loop_overlap()
+        else:
+            scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
@@ -38,12 +38,16 @@ class ChunkCache(BasePrefixCache):
         max_prefix_len = len(key)
         return entry.value[:max_prefix_len], entry

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         if token_ids is None:
-            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)
+
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, : token_id_len + free_delta
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
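ChunkCache's finished path above (and the unfinished path in the next hunk) now compute a slice bound instead of materializing a token list: len((req.origin_input_ids + req.output_ids)[:-1]) is just len(origin) + len(output) - 1, so the concatenation can be skipped. A quick check of the identity this relies on:

# The length identity behind token_id_len: dropping the last element of a
# concatenation leaves len(a) + len(b) - 1 items (given a + b is non-empty).
a = [1, 2, 3]  # stand-in for req.origin_input_ids
b = [4, 5]     # stand-in for req.output_ids
assert len((a + b)[:-1]) == len(a) + len(b) - 1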
@@ -53,10 +57,12 @@ class ChunkCache(BasePrefixCache):

     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-            token_ids = req.fill_ids
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)
+
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :token_id_len
         ]

         if req.rid not in self.entries:
@@ -97,22 +97,38 @@ class RadixCache(BasePrefixCache):
         value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+    def cache_finished_req(
+        self, req: Req, token_ids: Optional[List[int]] = None, free_delta: int = 0
+    ):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : token_ids_len + free_delta
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]

-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+        if free_delta:
+            self.token_to_kv_pool.free(
+                self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, len(token_ids) : len(token_ids) + 1
+                ]
+            )

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
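In the radix path, the tree only ever takes ownership of the first len(token_ids) slots, so the extra slot claimed by an in-flight decode step is freed explicitly when free_delta is set. A sketch of which row positions each free covers, using a plain list in place of the pool tensor (all indices illustrative):

# Which req_to_token positions each free touches in cache_finished_req.
row = list(range(12))              # KV indices in one request's row
token_ids_len = 10                 # tokens handed to the radix tree
prefix_len, new_prefix_len = 4, 7  # previously shared vs. newly inserted

dup_refs = row[prefix_len:new_prefix_len]     # refs now owned by the tree
extra = row[token_ids_len:token_ids_len + 1]  # in-flight decode slot
assert dup_refs == [4, 5, 6] and extra == [10]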
@@ -528,6 +528,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
             kill_child_process(pid, including_parent=False)
             return

+    # print(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
@@ -113,6 +113,7 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
@@ -572,6 +573,11 @@ class ServerArgs:
             action="store_true",
             help="Disable the logit penalizer (e.g., frequency and repetition penalty).",
         )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
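The feature stays opt-in: enable_overlap_schedule defaults to False, and the scheduler is switched over per server launch, for example (model path is a placeholder):

python3 -m sglang.launch_server --model-path <model> --enable-overlap-schedule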
@@ -584,6 +584,7 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):

 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
@@ -17,6 +17,7 @@ suites = {
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_openai_server.py",
+        "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
         "test_retract_decode.py",
         "test_server_args.py",
test/srt/test_overlap_schedule.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+"""
+Usage:
+SGLANG_IS_IN_CI=true python3 -m unittest test_overlap_schedule.TestOverlapSchedule.test_radix_attention_chunked_prefill
+SGLANG_IS_IN_CI=true python3 test_overlap_schedule.py
+"""
+
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestOverlapSchedule(unittest.TestCase):
+    def run_mmlu(self, disable_radix_cache, chunked_prefill_size=32):
+        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+        if disable_radix_cache:
+            other_args += ["--disable-radix-cache"]
+        other_args += ["--enable-overlap-schedule"]
+
+        model = DEFAULT_MODEL_NAME_FOR_TEST
+        base_url = DEFAULT_URL_FOR_TEST
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+        args = SimpleNamespace(
+            base_url=base_url,
+            model=model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        try:
+            metrics = run_eval(args)
+            assert metrics["score"] >= 0.65
+        finally:
+            kill_child_process(process.pid)
+
+    def test_no_radix_attention_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=32)
+
+    def test_no_radix_attention_no_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=True, chunked_prefill_size=-1)
+
+    def test_radix_attention_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=32)
+
+    def test_radix_attention_no_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, chunked_prefill_size=-1)
+
+
+if __name__ == "__main__":
+    unittest.main()
+    # @unittest.skip("did not support")