[PD] Release initial code (#4654)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: makro
Co-authored-by: dhou-xai
Authored by Byron Hsu on 2025-03-21 14:47:47 -07:00, committed by GitHub
commit c7c7dbebbe (parent 417fc72f6f)
10 changed files with 1410 additions and 9 deletions

@@ -37,6 +37,19 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
)
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
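Note: DisaggregationMode is defined in sglang.srt.disaggregation.utils and is not part of this diff. Judging from how the scheduler constructs it from server_args.disaggregation_mode and compares it against NULL/PREFILL/DECODE below, a minimal sketch might look like this; the string values are assumptions, not taken from the commit.

from enum import Enum


class DisaggregationMode(Enum):
    # Assumed string values, mirroring server_args.disaggregation_mode.
    NULL = "null"        # regular, non-disaggregated serving
    PREFILL = "prefill"  # this worker runs prefill and ships KV cache out
    DECODE = "decode"    # this worker receives KV cache and runs decode only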
@@ -137,7 +150,11 @@ class EmbeddingBatchResult:
    bid: int


-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
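# The two disaggregation mixins presumably supply the prefill/decode-specific helpers
# used later in this file (e.g. process_batch_result_disagg_prefill, process_prefill_chunk,
# process_decode_queue, get_next_disagg_decode_batch_to_run), which are not defined here.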
"""A scheduler that manages a tensor parallel GPU worker."""
def __init__(
@@ -389,6 +406,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
            ]
        )
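        # Determine this worker's PD-disaggregation role (NULL = regular serving,
        # PREFILL, or DECODE) from the server args, then build the matching queues.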
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

    def init_tokenizer(self):
        server_args = self.server_args
@@ -489,6 +511,73 @@ class Scheduler(SchedulerOutputProcessorMixin):
                },
            )

    def init_disaggregation(self):
        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
        ):  # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            aux_dtype = torch.int32
            # A list of metadata buffers. Each buffer has shape (b, metadata_size),
            # where b corresponds to the maximum number of running requests. The last
            # dimension * dtype.itemsize should be larger than 64 bytes to work with
            # RDMA, so we pad it.
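            # With aux_dtype=torch.int32 (4 bytes per element), 16 entries per row is
            # exactly 16 * 4 = 64 bytes, meeting that minimum.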
            output_id_buffer = torch.zeros(
                (buffer_size, 16), dtype=aux_dtype, device="cpu"
            )
            metadata_buffers = [output_id_buffer]

            # The decode requests that are polling for their KV cache transfer.
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
            )

            # The decode requests that are pending KV pre-allocation.
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
            )

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            aux_dtype = torch.int32
            # A list of metadata buffers. Each buffer has shape (b, metadata_size),
            # where b corresponds to the maximum number of running requests. The last
            # dimension * dtype.itemsize should be larger than 64 bytes to work with
            # RDMA, so we pad it.
            output_id_buffer = torch.zeros(
                (buffer_size, 16), dtype=aux_dtype, device="cpu"
            )
            metadata_buffers = [output_id_buffer]

            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            )
            # The prefill requests that are in the middle of sending their KV cache.
            self.disagg_prefill_infight_queue: List[Req] = []
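Note: ReqToMetadataIdxAllocator also lives in sglang.srt.disaggregation.utils and is not shown in this diff. As a rough mental model of what the code above relies on, a free-list of row indices into the metadata buffers, one might sketch it as follows; the class and method names here are hypothetical.

from collections import deque
from typing import Optional


class ReqToMetadataIdxAllocatorSketch:
    """Hypothetical stand-in for sglang.srt.disaggregation.utils.ReqToMetadataIdxAllocator."""

    def __init__(self, size: int):
        # One slot per metadata-buffer row; every row starts out free.
        self.free_slots: deque[int] = deque(range(size))

    def alloc(self) -> Optional[int]:
        # Hand out a free row index, or None if every row is in use.
        return self.free_slots.popleft() if self.free_slots else None

    def free(self, idx: int) -> None:
        # Return the row once the request's KV/metadata transfer is finished.
        self.free_slots.append(idx)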
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop."""
@@ -549,6 +638,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
            self.last_batch = batch

    @torch.no_grad()
    def event_loop_normal_disagg_prefill(self):
        """A normal scheduler loop for prefill worker in disaggregation mode."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
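            # Requests whose KV-transfer bootstrap has completed move from the
            # pending queue into the regular waiting queue for prefill scheduling.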
            self.waiting_queue.extend(
                self.disagg_prefill_pending_queue.pop_bootstrapped()
            )
            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result_disagg_prefill(batch, result)

            if len(self.disagg_prefill_infight_queue) > 0:
                self.process_disagg_prefill_infight_queue()

            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
            # Otherwise, it hangs under high concurrency
            self.running_batch.batch_is_full = False

    @torch.no_grad()
    def event_loop_normal_disagg_decode(self):
        """A normal scheduler loop for decode worker in disaggregation mode."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            # polling and allocating kv cache
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch

            if batch:
                # Generate fake extend output.
                if batch.forward_mode.is_extend():
                    # Note: Logprobs should be handled on the prefill engine.
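                    # The prefill worker already ran this forward pass and shipped its
                    # KV cache (and output ids via the metadata buffers) over, so the
                    # decode worker only streams the output instead of calling run_batch.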
                    self.stream_output(
                        batch.reqs, [False for _ in range(len(batch.reqs))]
                    )
                else:
                    result = self.run_batch(batch)
                    self.process_batch_result(batch, result)

            if batch is None and (
                len(self.disagg_decode_transfer_queue.queue)
                + len(self.disagg_decode_prealloc_queue.queue)
                == 0
            ):
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
        if self.attn_tp_rank == 0:
@@ -778,10 +931,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+        else:
+            self.waiting_queue.append(req)
+
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)

    def handle_embedding_request(
        self,
@@ -1814,10 +1977,18 @@ def run_scheduler_process(
"max_req_input_len": scheduler.max_req_input_len,
}
)
if scheduler.enable_overlap:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
if disaggregation_mode == DisaggregationMode.NULL:
if scheduler.enable_overlap:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()
elif disaggregation_mode == DisaggregationMode.PREFILL:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE:
scheduler.event_loop_normal_disagg_decode()
except Exception:
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")