[PD] Release initial code (#4654)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: Ying1123 <sqy1415@gmail.com>
Co-authored-by: merrymercy <lianminzheng@gmail.com>
Co-authored-by: makro
Co-authored-by: dhou-xai
@@ -37,6 +37,19 @@ from torch.distributed import barrier
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -137,7 +150,11 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""
 
     def __init__(
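The disaggregation behavior is layered onto Scheduler through mixins rather than a subclass per mode, so a single class serves all three modes and the mixin methods resolve shared state on the instance. A minimal sketch of the pattern, assuming the real mixins in sglang.srt.disaggregation.prefill/decode follow the usual cooperative-mixin shape (illustrative only):

    class SchedulerDisaggregationPrefillMixin:
        # Methods assume they are mixed into Scheduler and may freely use its
        # attributes (waiting_queue, disagg_prefill_infight_queue, ...).
        def process_disagg_prefill_infight_queue(self) -> None:
            # Poll in-flight KV sends; retire requests whose transfer finished.
            ...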
@@ -389,6 +406,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
             ]
         )
 
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args
 
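DisaggregationMode is constructed directly from the string-valued server argument, which implies a string-backed enum. A minimal sketch consistent with the NULL/PREFILL/DECODE comparisons used throughout this commit (assumed member values; not the actual sglang.srt.disaggregation.utils source):

    from enum import Enum

    class DisaggregationMode(Enum):
        NULL = "null"        # default: no disaggregation, one worker does both
        PREFILL = "prefill"  # this worker only prefills and ships KV cache
        DECODE = "decode"    # this worker receives KV cache and only decodes

    # Enum lookup-by-value is what makes the constructor call above work:
    # DisaggregationMode("prefill") is DisaggregationMode.PREFILL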
@@ -489,6 +511,73 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )
 
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
@@ -549,6 +638,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
             self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
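The decode loop relies on a two-stage handoff that process_decode_queue drives: requests wait in disagg_decode_prealloc_queue until KV slots are reserved, move to disagg_decode_transfer_queue while the KV cache streams over, and only then join waiting_queue. A schematic of that method under assumed drain-method names (pop_preallocated / pop_transferred are guesses at the queue APIs, not confirmed by this diff):

    def process_decode_queue(self) -> None:
        # Stage 1: reserve req_to_token / KV pool space for arrived requests.
        preallocated = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(preallocated)

        # Stage 2: poll KV transfers; finished requests can start decoding.
        transferred = self.disagg_decode_transfer_queue.pop_transferred()
        self.waiting_queue.extend(transferred)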
@@ -778,10 +931,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
         self._add_request_to_queue(req)
 
     def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+
+        else:
+            self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)
 
     def handle_embedding_request(
         self,
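Note the widened _extend_requests_to_queue signature: on a decode worker even retracted requests must re-reserve KV space, so they flow back through the pre-allocation queue instead of straight into waiting_queue. A hypothetical call site illustrating the new parameter (retract_decode is assumed here to return the retracted requests; it is not part of this diff):

    # During memory pressure on the decode worker:
    retracted_reqs = batch.retract_decode()  # assumed helper returning List[Req]
    self._extend_requests_to_queue(retracted_reqs, is_retracted=True)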
@@ -1814,10 +1977,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        if scheduler.enable_overlap:
-            scheduler.event_loop_overlap()
-        else:
-            scheduler.event_loop_normal()
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
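One consequence of this dispatch: overlap scheduling stays NULL-mode only for now, and each disaggregated worker commits to its single "normal" loop for the life of the process. The same contract condensed into a selector function for readability (a sketch, not code from this commit):

    def pick_event_loop(scheduler):
        mode = scheduler.disaggregation_mode
        if mode == DisaggregationMode.NULL:
            # Only the non-disaggregated path supports the overlap scheduler.
            return (
                scheduler.event_loop_overlap
                if scheduler.enable_overlap
                else scheduler.event_loop_normal
            )
        if mode == DisaggregationMode.PREFILL:
            return scheduler.event_loop_normal_disagg_prefill
        return scheduler.event_loop_normal_disagg_decode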