From d2e0881a34e8002fd242c7240bf94105829d7307 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 23 May 2025 12:03:05 -0700 Subject: [PATCH] [PD] support spec decode (#6507) Co-authored-by: SangBin Cho --- .pre-commit-config.yaml | 2 +- python/sglang/srt/disaggregation/decode.py | 12 +- .../srt/disaggregation/mooncake/conn.py | 3 + .../mooncake/transfer_engine.py | 3 +- python/sglang/srt/disaggregation/prefill.py | 13 ++ python/sglang/srt/managers/scheduler.py | 17 +++ test/srt/run_suite.py | 2 +- test/srt/test_disaggregation.py | 143 +++++++++++++++++- 8 files changed, 190 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b02e38c4..337e5d7fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: hooks: - id: isort - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.2 + rev: v0.11.7 hooks: - id: ruff args: [--select=F401, --fixable=F401] diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 62f0dfd42..77c307ead 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import ( from sglang.srt.managers.schedule_batch import FINISH_ABORT from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo logger = logging.getLogger(__name__) @@ -76,6 +76,7 @@ class DecodePreallocQueue: self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator, + draft_token_to_kv_pool: Optional[KVCache], req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, metadata_buffers: 
List[torch.Tensor], aux_dtype: torch.dtype, @@ -91,6 +92,7 @@ class DecodePreallocQueue: self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache() + self.draft_token_to_kv_pool = draft_token_to_kv_pool self.is_mla_backend = is_mla_backend(self.token_to_kv_pool) self.aux_dtype = aux_dtype self.metadata_buffers = metadata_buffers @@ -119,6 +121,14 @@ class DecodePreallocQueue: self.token_to_kv_pool.get_contiguous_buf_infos() ) + if self.draft_token_to_kv_pool is not None: + draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = ( + self.draft_token_to_kv_pool.get_contiguous_buf_infos() + ) + kv_data_ptrs += draft_kv_data_ptrs + kv_data_lens += draft_kv_data_lens + kv_item_lens += draft_kv_item_lens + kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 9e894d1e3..57c426f25 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -51,6 +51,7 @@ def group_concurrent_contiguous( return src_groups, dst_groups +# prefill @dataclasses.dataclass class TransferKVChunk: room: int @@ -60,6 +61,7 @@ class TransferKVChunk: prefill_aux_index: Optional[int] +# decode @dataclasses.dataclass class TransferInfo: room: int @@ -93,6 +95,7 @@ class TransferInfo: ) +# decode @dataclasses.dataclass class KVArgsRegisterInfo: room: str diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index 8c9f910b3..1f3c44bcc 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -61,7 +61,8 @@ class MooncakeTransferEngine: self, session_id: str, buffer: int, peer_buffer_address: 
int, length: int ) -> int: """Synchronously transfer data to the specified address.""" - + # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair + # later: based on the cached queue pair to send data ret = self.engine.transfer_sync_write( session_id, buffer, peer_buffer_address, length ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 4e346ea88..6075914e5 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -61,6 +61,7 @@ class PrefillBootstrapQueue: def __init__( self, token_to_kv_pool: KVCache, + draft_token_to_kv_pool: Optional[KVCache], req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, metadata_buffers: List[torch.Tensor], aux_dtype: torch.dtype, @@ -72,6 +73,8 @@ class PrefillBootstrapQueue: scheduler: Scheduler, ): self.token_to_kv_pool = token_to_kv_pool + self.draft_token_to_kv_pool = draft_token_to_kv_pool + self.is_mla_backend = is_mla_backend(token_to_kv_pool) self.aux_dtype = aux_dtype @@ -98,6 +101,16 @@ class PrefillBootstrapQueue: self.token_to_kv_pool.get_contiguous_buf_infos() ) + if self.draft_token_to_kv_pool is not None: + # We should also transfer draft model kv cache. The indices are + # always shared with a target model. 
+ draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = ( + self.draft_token_to_kv_pool.get_contiguous_buf_infos() + ) + kv_data_ptrs += draft_kv_data_ptrs + kv_data_lens += draft_kv_data_lens + kv_item_lens += draft_kv_item_lens + kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 60f39b1a5..db282da51 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -591,6 +591,11 @@ class Scheduler( self.disagg_decode_prealloc_queue = DecodePreallocQueue( req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + draft_token_to_kv_pool=( + None + if self.draft_worker is None + else self.draft_worker.model_runner.token_to_kv_pool + ), req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, metadata_buffers=metadata_buffers, aux_dtype=aux_dtype, @@ -624,6 +629,11 @@ class Scheduler( self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue( token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), + draft_token_to_kv_pool=( + None + if self.draft_worker is None + else self.draft_worker.model_runner.token_to_kv_pool + ), req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, metadata_buffers=metadata_buffers, aux_dtype=aux_dtype, @@ -1409,6 +1419,13 @@ class Scheduler( self.running_batch.batch_is_full = True break + if self.disaggregation_mode == DisaggregationMode.PREFILL: + # In prefill mode, prealloc queue and transfer queue can also take memory, + # so we need to check against the actual available size. 
+ if len(adder.can_run_list) >= self.req_to_token_pool.available_size(): + self.running_batch.batch_is_full = True + break + req.init_next_round_input( None if prefix_computed else self.tree_cache, self.enable_hierarchical_cache, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3eb03d90f..dcb4970fc 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -115,7 +115,7 @@ suites = { # TestFile("test_deepep_intranode.py", 50), # TestFile("test_deepep_low_latency.py", 50), # TestFile("test_moe_deepep_eval_accuracy_large.py", 250), - # TestFile("test_disaggregation.py", 210), # disabled since we have different_tp test + TestFile("test_disaggregation.py", 210), TestFile("test_disaggregation_different_tp.py", 210), TestFile("test_full_deepseek_v3.py", 250), ], diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index c46a0c29c..3a9996a78 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -8,6 +8,8 @@ import requests from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -17,7 +19,9 @@ from sglang.test.test_utils import ( ) -class TestDisaggregationMooncake(CustomTestCase): +# skip the test because we have different_tp test +@unittest.skip("skip the test because we have different_tp test") +class TestDisaggregationAccuracy(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -65,6 +69,8 @@ class TestDisaggregationMooncake(CustomTestCase): str(cls.base_port + 100), "--tp", "4", + # "--disaggregation-ib-device", + # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -87,6 +93,8 @@ class 
TestDisaggregationMooncake(CustomTestCase): "4", "--base-gpu-id", "4", + # "--disaggregation-ib-device", + # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -136,5 +144,138 @@ class TestDisaggregationMooncake(CustomTestCase): self.assertGreater(metrics["accuracy"], 0.62) +class TestDisaggregationSpecAccuracy(CustomTestCase): + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST + cls.base_host = "127.0.0.1" + cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1]) + cls.lb_url = DEFAULT_URL_FOR_TEST + cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}" + cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}" + cls.spec_args = [ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + cls.draft_model, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "16", + "--cuda-graph-max-bs", + "8", + ] + + run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) + run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) + + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + lb_command = [ + "python3", + "-m", + "sglang.srt.disaggregation.mini_lb", + "--prefill", + cls.prefill_url, + "--decode", + cls.decode_url, + "--host", + cls.base_host, + "--port", + str(cls.base_port), + ] + + print("Starting load balancer:", " ".join(lb_command)) + cls.process_lb = subprocess.Popen( + lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + cls.wait_server_ready(cls.lb_url + "/health") + + @classmethod + def wait_server_ready(cls, url, timeout=60): + start_time = time.perf_counter() + while True: + try: + response = requests.get(url) + if response.status_code == 200: + print(f"Server {url} is ready") + 
return + except Exception: + pass + + if time.perf_counter() - start_time > timeout: + raise RuntimeError(f"Server {url} failed to start in {timeout}s") + time.sleep(1) + + @classmethod + def start_prefill(cls): + + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--host", + cls.base_host, + "--port", + str(cls.base_port + 100), + "--tp", + "4", + # "--disaggregation-ib-device", + # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3", + ] + cls.spec_args + + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--host", + cls.base_host, + "--port", + str(cls.base_port + 200), + "--tp", + "4", + "--base-gpu-id", + "4", + # "--disaggregation-ib-device", + # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + ] + cls.spec_args + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=4, # TODO: 128 crashes the decode + host="http://127.0.0.1", + port=int(self.lb_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Evaluation metrics: {metrics}") + + self.assertGreater(metrics["accuracy"], 0.20) + + if __name__ == "__main__": unittest.main()