From bf98d2e3770fd23d2e6cb7b95c7a9af82e2ddc7e Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 21 Apr 2025 12:12:56 -0700 Subject: [PATCH] [PD] Support prefill overlap + Ensure no race condition (#5609) --- python/sglang/srt/disaggregation/prefill.py | 89 +++++++++++++++++--- python/sglang/srt/managers/schedule_batch.py | 5 ++ python/sglang/srt/managers/scheduler.py | 5 +- python/sglang/srt/server_args.py | 2 - scripts/playground/disaggregation/cli.py | 24 +++++- 5 files changed, 107 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 568c9973e..48743ef1f 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -20,6 +20,7 @@ Life cycle of a request in the prefill server from __future__ import annotations import logging +from collections import deque from typing import TYPE_CHECKING, List, Optional import torch @@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin: # Otherwise, it hangs under high concurrency self.running_batch.batch_is_full = False + @torch.no_grad() + def event_loop_overlap_disagg_prefill(self): + self.result_queue = deque() + + while True: + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + self.waiting_queue.extend( + self.disagg_prefill_pending_queue.pop_bootstrapped() + ) + self.process_prefill_chunk() + batch = self.get_new_batch_prefill() + self.cur_batch = batch + + if batch: + result = self.run_batch(batch) + self.result_queue.append((batch.copy(), result)) + + if self.last_batch: + tmp_batch, tmp_result = self.result_queue.popleft() + self.process_batch_result_disagg_prefill(tmp_batch, tmp_result) + + if len(self.disagg_prefill_inflight_queue) > 0: + self.process_disagg_prefill_inflight_queue() + + if batch is None and len(self.disagg_prefill_inflight_queue) == 0: + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + + self.last_batch = 
batch + # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it + # Otherwise, it hangs under high concurrency + self.running_batch.batch_is_full = False + def process_batch_result_disagg_prefill( self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult ) -> None: @@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin: Adapted from process_batch_result_prefill """ - next_token_ids = result.next_token_ids.tolist() + ( + logits_output, + next_token_ids, + extend_input_len_per_req, + extend_logprob_start_len_per_req, + bid, + ) = ( + result.logits_output, + result.next_token_ids, + result.extend_input_len_per_req, + result.extend_logprob_start_len_per_req, + result.bid, + ) + + # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue + if self.enable_overlap: + # wait + _, next_token_ids = self.tp_worker.resolve_batch_result(bid) + else: + next_token_ids = result.next_token_ids.tolist() for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True): req: Req @@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin: # being chunked reqs' prefill is not finished req.is_chunked -= 1 - # TODO: Not sure if this is necessary - if batch.next_batch_sampling_info: - batch.next_batch_sampling_info.update_regex_vocab_mask() - # We need to remove this for overlap schedule. - self.current_stream.synchronize() - batch.next_batch_sampling_info.sampling_info_done.set() + if self.enable_overlap: + self.send_kv_chunk(req, end_idx=req.tmp_end_idx) def process_disagg_prefill_inflight_queue(self: Scheduler) -> None: """ @@ -276,20 +326,37 @@ class SchedulerDisaggregationPrefillMixin: # only finished requests to running_batch. 
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req) - self.send_kv_chunk(self.chunked_req) + if ( + self.enable_overlap + ): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved + self.chunked_req.tmp_end_idx = min( + len(self.chunked_req.fill_ids), + len(self.chunked_req.origin_input_ids), + ) + else: + self.send_kv_chunk(self.chunked_req) # chunked request keeps its rid but will get a new req_pool_idx self.req_to_token_pool.free(self.chunked_req.req_pool_idx) self.running_batch.batch_is_full = False def send_kv_chunk( - self: Scheduler, req: Req, token_id: Optional[int] = None + self: Scheduler, + req: Req, + token_id: Optional[int] = None, + end_idx: Optional[int] = None, ) -> None: """ Send a prefilled chunk to the decode server """ page_size = self.token_to_kv_pool_allocator.page_size start_idx = req.start_send_idx - end_idx = min(len(req.fill_ids), len(req.origin_input_ids)) + # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule, + # the resolved length is not the same as fill_ids's length + end_idx = ( + end_idx + if end_idx is not None + else min(len(req.fill_ids), len(req.origin_input_ids)) + ) last_chunk = token_id is not None if (not last_chunk) and ( @@ -302,7 +369,7 @@ class SchedulerDisaggregationPrefillMixin: req.start_send_idx = end_idx kv_indices = ( - self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx] + self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx] .cpu() .numpy() ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ddacb7441..baf3adeb1 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -539,6 +539,11 @@ class Req: # The first output_id transferred from prefill instance. 
self.transferred_output_id: Optional[int] = None + # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap + # This is because kv is not ready in `process_prefill_chunk`. + # We use `tmp_end_idx` to store the end index of the kv cache to send. + self.tmp_end_idx: int = -1 + @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 53c5ea4f9..303c22059 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2014,7 +2014,10 @@ def run_scheduler_process( else: scheduler.event_loop_normal() elif disaggregation_mode == DisaggregationMode.PREFILL: - scheduler.event_loop_normal_disagg_prefill() + if scheduler.enable_overlap: + scheduler.event_loop_overlap_disagg_prefill() + else: + scheduler.event_loop_normal_disagg_prefill() elif disaggregation_mode == DisaggregationMode.DECODE: if scheduler.enable_overlap: scheduler.event_loop_overlap_disagg_decode() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index dee87dbb2..ec3f11437 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -388,8 +388,6 @@ class ServerArgs: if self.disaggregation_mode == "prefill": self.disable_cuda_graph = True logger.warning("Cuda graph is disabled for prefill server") - self.disable_overlap_schedule = True - logger.warning("Overlap scheduler is disabled for prefill server") elif self.disaggregation_mode == "decode": self.disable_radix_cache = True logger.warning("KV cache is forced as chunk cache for decode server") diff --git a/scripts/playground/disaggregation/cli.py b/scripts/playground/disaggregation/cli.py index 5bcc5629e..721a6dd5e 100644 --- a/scripts/playground/disaggregation/cli.py +++ b/scripts/playground/disaggregation/cli.py @@ -1,13 +1,29 @@ -prompt = [0] * 
431 - import json import requests +prompt = """ +According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI. + +For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion. + +Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024. + +According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year. + +Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion. + + +Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere. 
+ +Give your honest take on the above text: +""" + response = requests.post( "http://0.0.0.0:8000/generate", - json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}}, + json={"text": prompt, "sampling_params": {"temperature": 0}}, ) -# print("Response content (raw):", response.content) +response_json = response.json() +print(response_json["text"])