diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 77c307ead..92531aa73 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -36,6 +36,7 @@ from sglang.srt.disaggregation.utils import ( DisaggregationMode, FakeBootstrapHost, KVClassType, + MetadataBuffers, ReqToMetadataIdxAllocator, TransferBackend, get_kv_class, @@ -78,8 +79,7 @@ class DecodePreallocQueue: token_to_kv_pool_allocator: TokenToKVPoolAllocator, draft_token_to_kv_pool: Optional[KVCache], req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, - metadata_buffers: List[torch.Tensor], - aux_dtype: torch.dtype, + metadata_buffers: MetadataBuffers, scheduler: Scheduler, transfer_queue: DecodeTransferQueue, tree_cache: BasePrefixCache, @@ -94,7 +94,6 @@ class DecodePreallocQueue: self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache() self.draft_token_to_kv_pool = draft_token_to_kv_pool self.is_mla_backend = is_mla_backend(self.token_to_kv_pool) - self.aux_dtype = aux_dtype self.metadata_buffers = metadata_buffers self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.scheduler = scheduler @@ -133,15 +132,9 @@ class DecodePreallocQueue: kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens - kv_args.aux_data_ptrs = [ - output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers - ] - kv_args.aux_data_lens = [ - metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers - ] - kv_args.aux_item_lens = [ - metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers - ] + kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( + self.metadata_buffers.get_buf_infos() + ) kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) @@ -211,7 +204,18 @@ class DecodePreallocQueue: indices_to_remove = set() allocatable_tokens = self._allocatable_tokens() + # First, remove all failed requests from the queue for i, decode_req in enumerate(self.queue): + if isinstance(decode_req.req.finished_reason, FINISH_ABORT): + self.scheduler.stream_output( + [decode_req.req], decode_req.req.return_logprob + ) + indices_to_remove.add(i) + + for i, decode_req in enumerate(self.queue): + if i in indices_to_remove: + continue + if not decode_req.waiting_for_input: continue @@ -331,7 +335,7 @@ class DecodeTransferQueue: self, gloo_group: ProcessGroup, req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, - metadata_buffers: torch.Tensor, + metadata_buffers: MetadataBuffers, scheduler: Scheduler, tree_cache: BasePrefixCache, ): @@ -342,11 +346,11 @@ class DecodeTransferQueue: self.scheduler = scheduler self.tree_cache = tree_cache - def add(self, req_conn: DecodeRequest) -> None: - self.queue.append(req_conn) + def add(self, decode_req: DecodeRequest) -> None: + self.queue.append(decode_req) - def extend(self, req_conns) -> None: - self.queue.extend(req_conns) + def extend(self, decode_reqs: List[DecodeRequest]) -> None: + self.queue.extend(decode_reqs) def pop_transferred(self) -> List[DecodeRequest]: if not self.queue: @@ -356,14 +360,6 @@ class DecodeTransferQueue: [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) - # First, remove all failed requests from the queue - for i, decode_req in enumerate(self.queue): - if isinstance(decode_req.req.finished_reason, FINISH_ABORT): - self.scheduler.stream_output( - [decode_req.req], decode_req.req.return_logprob - ) - indices_to_remove.add(i) - transferred_reqs = [] indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): @@ -387,16 +383,37 @@ class DecodeTransferQueue: indices_to_remove.add(i) continue elif poll == KVPoll.Success: - # pop and push it to waiting queue + idx = decode_req.metadata_buffer_index - assert len(decode_req.req.output_ids) == 0 - output_id_buffer = self.metadata_buffers[0] - # the last dimension is padded by the same values. - output_id = output_id_buffer[idx][0].item() - assert len(decode_req.req.output_ids) == 0 - assert decode_req.req.transferred_output_id is None - decode_req.req.transferred_output_id = output_id - transferred_reqs.append(decode_req) + ( + output_id, + output_token_logprobs_val, + output_token_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, + ) = self.metadata_buffers.get_buf(idx) + + decode_req.req.output_ids.append(output_id[0].item()) + + if decode_req.req.return_logprob: + decode_req.req.output_token_logprobs_val.append( + output_token_logprobs_val[0].item() + ) + decode_req.req.output_token_logprobs_idx.append( + output_token_logprobs_idx[0].item() + ) + decode_req.req.output_top_logprobs_val.append( + output_top_logprobs_val[ + : decode_req.req.top_logprobs_num + ].tolist() + ) + decode_req.req.output_top_logprobs_idx.append( + output_top_logprobs_idx[ + : decode_req.req.top_logprobs_num + ].tolist() + ) + + transferred_reqs.append(decode_req.req) indices_to_remove.add(i) elif poll in [ KVPoll.Bootstrapping, @@ -451,7 +468,9 @@ class SchedulerDisaggregationDecodeMixin: # Generate fake extend output. if batch.forward_mode.is_extend(): # Note: Logprobs should be handled on the prefill engine. - self.stream_output(batch.reqs, False) + self.stream_output( + batch.reqs, any(req.return_logprob for req in batch.reqs) + ) if prepare_dp_attn_flag: self._prepare_idle_batch_and_run(None) else: @@ -497,7 +516,9 @@ class SchedulerDisaggregationDecodeMixin: # Generate fake extend output. if batch.forward_mode.is_extend(): # Note: Logprobs should be handled on the prefill engine. - self.stream_output(batch.reqs, False) + self.stream_output( + batch.reqs, any(req.return_logprob for req in batch.reqs) + ) if prepare_dp_attn_flag: batch_, result = self._prepare_idle_batch_and_run( None, delay_process=True @@ -618,15 +639,8 @@ class SchedulerDisaggregationDecodeMixin: def process_decode_queue(self: Scheduler): req_conns = self.disagg_decode_prealloc_queue.pop_preallocated() - - def _num_pre_alloc(req): - return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0) - - self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns) self.disagg_decode_transfer_queue.extend(req_conns) alloc_reqs = ( self.disagg_decode_transfer_queue.pop_transferred() ) # the requests which kv has arrived - self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs) - - self.waiting_queue.extend([req.req for req in alloc_reqs]) + self.waiting_queue.extend(alloc_reqs) diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index 0c2e763c2..c05e8231d 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import torch -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo logger = logging.getLogger(__name__) @@ -76,6 +76,11 @@ class ScheduleBatchDisaggregationDecodeMixin: self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) self.out_cache_loc = out_cache_loc self.seq_lens_sum = sum(seq_lens) + + if self.return_logprob: + self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + self.token_ids_logprobs = [r.token_ids_logprob for r in reqs] + self.extend_num_tokens = extend_num_tokens self.prefix_lens = [len(r.prefix_indices) for r in reqs] self.extend_lens = [r.extend_input_len for r in reqs] @@ -94,12 +99,41 @@ class ScheduleBatchDisaggregationDecodeMixin: """Assign the buffered last input id to schedule batch""" self.output_ids = [] for req in self.reqs: - if req.output_ids and len(req.output_ids) > 0: - # resumed retracted req - self.output_ids.append(req.output_ids[-1]) - else: - assert req.transferred_output_id is not None - req.output_ids.append(req.transferred_output_id) - self.output_ids.append(req.transferred_output_id) + self.output_ids.append(req.output_ids[-1]) self.tree_cache.cache_unfinished_req(req) self.output_ids = torch.tensor(self.output_ids, device=self.device) + + # Simulate the eagle run. We add mock data to hidden states for the + # ease of implementation now meaning the first token will have acc rate + # of 0. + if not self.spec_algorithm.is_none(): + + b = len(self.reqs) + topk_p = torch.arange( + b * server_args.speculative_eagle_topk, + 0, + -1, + device=self.device, + dtype=torch.float32, + ) + topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk) + topk_p /= b * server_args.speculative_eagle_topk + topk_index = torch.arange( + b * server_args.speculative_eagle_topk, device=self.device + ) + topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk) + + # local import to avoid circular import + from sglang.srt.speculative.eagle_utils import EagleDraftInput + + spec_info = EagleDraftInput( + topk_p=topk_p, + topk_index=topk_index, + hidden_states=torch.ones( + (b, model_config.hidden_size), device=self.device + ), + verified_id=self.output_ids, + ) + spec_info.prepare_for_extend(self) + spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + self.spec_info = spec_info diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index 013b77326..f6de3a884 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -73,11 +73,27 @@ class MiniLoadBalancer: session.post(f"{prefill_server}/{endpoint}", json=modified_request), session.post(f"{decode_server}/{endpoint}", json=modified_request), ] + # Wait for both responses to complete. Prefill should end first. - _, decode_response = await asyncio.gather(*tasks) + prefill_response, decode_response = await asyncio.gather(*tasks) + + if "return_logprob" in modified_request: + + prefill_json = await prefill_response.json() + ret_json = await decode_response.json() + + # merge `meta_info.input_token_logprobs` from prefill to decode + if "meta_info" in ret_json: + if "input_token_logprobs" in ret_json["meta_info"]: + ret_json["meta_info"]["input_token_logprobs"] = ( + prefill_json["meta_info"]["input_token_logprobs"] + + ret_json["meta_info"]["input_token_logprobs"] + ) + else: + ret_json = await decode_response.json() return ORJSONResponse( - content=await decode_response.json(), + content=ret_json, status_code=decode_response.status, ) @@ -92,30 +108,47 @@ class MiniLoadBalancer: total=3600 ) # Add timeout for request reliability ) as session: - try: - # Create the tasks for both prefill and decode requests - tasks = [ - session.post( - f"{prefill_server}/{endpoint}", json=modified_request - ), - session.post( - f"{decode_server}/{endpoint}", json=modified_request - ), - ] - # Wait for both responses to complete. Since this is streaming, they return immediately. - prefill_response, decode_response = await asyncio.gather(*tasks) + # Create the tasks for both prefill and decode requests + tasks = [ + session.post(f"{prefill_server}/generate", json=modified_request), + session.post(f"{decode_server}/generate", json=modified_request), + ] + # Wait for both responses to complete. Since this is streaming, they return immediately. + prefill_response, decode_response = await asyncio.gather(*tasks) + + if modified_request.get("return_logprob", False): + prefill_chunks = [] + async for chunk in prefill_response.content: + prefill_chunks.append(chunk) + + first_prefill_chunk = ( + prefill_chunks[0].decode("utf-8")[5:].strip("\n") + ) + first_prefill_chunk_json = orjson.loads(first_prefill_chunk) + + async for chunk in decode_response.content: + # Note: This is inefficient + # merge prefill input_token_logprobs, output_token_logprobs to decode + decoded_chunk = chunk.decode("utf-8") + if ( + decoded_chunk + and decoded_chunk.startswith("data:") + and "[DONE]" not in decoded_chunk + ): + ret_json = orjson.loads(decoded_chunk[5:].strip("\n")) + ret_json["meta_info"]["input_token_logprobs"] = ( + first_prefill_chunk_json["meta_info"][ + "input_token_logprobs" + ] + + ret_json["meta_info"]["input_token_logprobs"] + ) + + yield b"data: " + orjson.dumps(ret_json) + b"\n\n" + else: + yield chunk + else: async for chunk in decode_response.content: yield chunk - except Exception as e: - error_msg = { - "error": {"message": f"Stream processing error: {str(e)}"} - } - yield b"data: " + orjson.dumps( - error_msg, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - finally: - if prefill_response is not None: - await prefill_response.release() return StreamingResponse( stream_results(), diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 6075914e5..0ed04f06a 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import ( DisaggregationMode, FakeBootstrapHost, KVClassType, + MetadataBuffers, ReqToMetadataIdxAllocator, TransferBackend, get_kv_class, @@ -63,8 +64,7 @@ class PrefillBootstrapQueue: token_to_kv_pool: KVCache, draft_token_to_kv_pool: Optional[KVCache], req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, - metadata_buffers: List[torch.Tensor], - aux_dtype: torch.dtype, + metadata_buffers: MetadataBuffers, tp_rank: int, tp_size: int, bootstrap_port: int, @@ -76,7 +76,6 @@ class PrefillBootstrapQueue: self.draft_token_to_kv_pool = draft_token_to_kv_pool self.is_mla_backend = is_mla_backend(token_to_kv_pool) - self.aux_dtype = aux_dtype self.metadata_buffers = metadata_buffers self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator @@ -116,15 +115,9 @@ class PrefillBootstrapQueue: kv_args.kv_item_lens = kv_item_lens # Define req -> input ids buffer - kv_args.aux_data_ptrs = [ - metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers - ] - kv_args.aux_data_lens = [ - metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers - ] - kv_args.aux_item_lens = [ - metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers - ] + kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = ( + self.metadata_buffers.get_buf_infos() + ) kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER) @@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin: launch_done: Optional[threading.Event] = None, ) -> None: """ - Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue + Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue Adapted from process_batch_result_prefill """ - ( logits_output, next_token_ids, @@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin: result.extend_logprob_start_len_per_req, ) + logprob_pt = 0 # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue if self.enable_overlap: # wait - _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done) + logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result( + launch_done + ) else: next_token_ids = result.next_token_ids.tolist() - - for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True): + if batch.return_logprob: + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.tolist() + ) + if logits_output.input_token_logprobs is not None: + logits_output.input_token_logprobs = tuple( + logits_output.input_token_logprobs.tolist() + ) + for i, (req, next_token_id) in enumerate( + zip(batch.reqs, next_token_ids, strict=True) + ): req: Req if req.is_chunked <= 0: # There is no output_ids for prefill req.output_ids.append(next_token_id) self.tree_cache.cache_unfinished_req(req) # update the tree and lock - self.send_kv_chunk(req, token_id=next_token_id) self.disagg_prefill_inflight_queue.append(req) + if req.return_logprob: + assert extend_logprob_start_len_per_req is not None + assert extend_input_len_per_req is not None + extend_logprob_start_len = extend_logprob_start_len_per_req[i] + extend_input_len = extend_input_len_per_req[i] + num_input_logprobs = extend_input_len - extend_logprob_start_len + self.add_logprob_return_values( + i, + req, + logprob_pt, + next_token_ids, + num_input_logprobs, + logits_output, + ) + logprob_pt += num_input_logprobs + self.send_kv_chunk(req, last_chunk=True) + + if req.grammar is not None: + req.grammar.accept_token(next_token_id) + req.grammar.finished = req.finished() else: # being chunked reqs' prefill is not finished req.is_chunked -= 1 + if req.return_logprob: + extend_logprob_start_len = extend_logprob_start_len_per_req[i] + extend_input_len = extend_input_len_per_req[i] + if extend_logprob_start_len < extend_input_len: + # Update input logprobs. + num_input_logprobs = extend_input_len - extend_logprob_start_len + self.add_input_logprob_return_values( + i, + req, + logits_output, + logprob_pt, + num_input_logprobs, + last_prefill_chunk=False, + ) + logprob_pt += num_input_logprobs + if self.enable_overlap: - self.send_kv_chunk(req, end_idx=req.tmp_end_idx) + self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx) + + # We need to remove the sync in the following function for overlap schedule. + self.set_next_batch_sampling_info_done(batch) def process_disagg_prefill_inflight_queue(self: Scheduler) -> None: """ @@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin: ) # Stream requests which have finished transfer - self.stream_output(done_reqs, False, None) + self.stream_output( + done_reqs, + any(req.return_logprob for req in done_reqs), + None, + ) self.disagg_prefill_inflight_queue = undone_reqs @@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin: def send_kv_chunk( self: Scheduler, req: Req, - token_id: Optional[int] = None, + last_chunk: bool = False, end_idx: Optional[int] = None, ) -> None: """ @@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin: """ page_size = self.token_to_kv_pool_allocator.page_size start_idx = req.start_send_idx - # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule, - # the resolved length is not the same as fill_ids's length end_idx = ( end_idx if end_idx is not None else min(len(req.fill_ids), len(req.origin_input_ids)) ) - last_chunk = token_id is not None - if not last_chunk: # if not the last chunk and the last page is partial, delay the last partial page to the next send end_idx = end_idx - end_idx % page_size - # Update next start_send_idx - req.start_send_idx = end_idx - kv_indices = ( self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx] .cpu() .numpy() ) - if last_chunk is True: - self.disagg_prefill_bootstrap_queue.store_prefill_results( - req.metadata_buffer_index, token_id - ) + req.start_send_idx = end_idx + if last_chunk: + self.disagg_metadata_buffers.set_buf(req) page_indices = kv_to_page_indices(kv_indices, page_size) if len(page_indices) == 0: logger.info( f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" ) return - req.disagg_kv_sender.send(page_indices) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 5a5eb4780..74923cd89 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -6,7 +6,7 @@ import random import warnings from collections import deque from enum import Enum -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional import numpy as np import requests @@ -15,6 +15,9 @@ import torch.distributed as dist from sglang.srt.utils import get_ip +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + FakeBootstrapHost = "2.2.2.2" # env var for testing failure, convert to float explicitly @@ -196,3 +199,83 @@ def prepare_abort(req: Req, error_message: str, status_code=None): req.input_top_logprobs_idx = [] req.input_token_ids_logprobs_val = [] req.input_token_ids_logprobs_idx = [] + + +class MetadataBuffers: + def __init__(self, size: int, max_top_logprobs_num: int = 128): + # TODO: abort top_logprobs_num > 128 in PD + + # We transfer the metadata of first output token to decode + # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes + self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu") + self.output_token_logprobs_val = torch.zeros( + (size, 16), dtype=torch.float32, device="cpu" + ) + self.output_token_logprobs_idx = torch.zeros( + (size, 16), dtype=torch.int32, device="cpu" + ) + self.output_top_logprobs_val = torch.zeros( + (size, max_top_logprobs_num), dtype=torch.float32, device="cpu" + ) + self.output_top_logprobs_idx = torch.zeros( + (size, max_top_logprobs_num), dtype=torch.int32, device="cpu" + ) + + def get_buf_infos(self): + ptrs = [ + self.output_ids.data_ptr(), + self.output_token_logprobs_val.data_ptr(), + self.output_token_logprobs_idx.data_ptr(), + self.output_top_logprobs_val.data_ptr(), + self.output_top_logprobs_idx.data_ptr(), + ] + data_lens = [ + self.output_ids.nbytes, + self.output_token_logprobs_val.nbytes, + self.output_token_logprobs_idx.nbytes, + self.output_top_logprobs_val.nbytes, + self.output_top_logprobs_idx.nbytes, + ] + item_lens = [ + self.output_ids[0].nbytes, + self.output_token_logprobs_val[0].nbytes, + self.output_token_logprobs_idx[0].nbytes, + self.output_top_logprobs_val[0].nbytes, + self.output_top_logprobs_idx[0].nbytes, + ] + return ptrs, data_lens, item_lens + + def get_buf(self, idx: int): + return ( + self.output_ids[idx], + self.output_token_logprobs_val[idx], + self.output_token_logprobs_idx[idx], + self.output_top_logprobs_val[idx], + self.output_top_logprobs_idx[idx], + ) + + def set_buf(self, req: Req): + + self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] + if req.return_logprob: + if req.output_token_logprobs_val: # not none or empty list + self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( + req.output_token_logprobs_val[0] + ) + if req.output_token_logprobs_idx: # not none or empty list + self.output_token_logprobs_idx[req.metadata_buffer_index][0] = ( + req.output_token_logprobs_idx[0] + ) + + if req.output_top_logprobs_val: # not none or empty list + self.output_top_logprobs_val[req.metadata_buffer_index][ + : len(req.output_top_logprobs_val[0]) + ] = torch.tensor( + req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu" + ) + if req.output_top_logprobs_idx: # not none or empty list + self.output_top_logprobs_idx[req.metadata_buffer_index][ + : len(req.output_top_logprobs_idx[0]) + ] = torch.tensor( + req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" + ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 4f86ac5dd..be6c3e99f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -607,9 +607,6 @@ class Req: self.tmp_end_idx: int = -1 self.metadata_buffer_index: int = -1 - # The first output_id transferred from prefill instance. - self.transferred_output_id: Optional[int] = None - @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index db282da51..2d9b840ae 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import ( ) from sglang.srt.disaggregation.utils import ( DisaggregationMode, + MetadataBuffers, ReqToMetadataIdxAllocator, TransferBackend, prepare_abort, @@ -569,20 +570,13 @@ class Scheduler( req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( buffer_size ) - aux_dtype = torch.int32 - # A list of metadata buffers. The shape is (b, metadata_size) where - # b corresponds to a max running requests. The last shape * dtype.itemsize - # should be larger than 64 bytes to work with RDMA, so we pad it. - output_id_buffer = torch.zeros( - (buffer_size, 16), dtype=aux_dtype, device="cpu" - ) - metadata_buffers = [output_id_buffer] + self.disagg_metadata_buffers = MetadataBuffers(buffer_size) # The decode requests polling kv cache self.disagg_decode_transfer_queue = DecodeTransferQueue( gloo_group=self.attn_tp_cpu_group, req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, - metadata_buffers=metadata_buffers, + metadata_buffers=self.disagg_metadata_buffers, scheduler=self, tree_cache=self.tree_cache, ) @@ -597,8 +591,7 @@ class Scheduler( else self.draft_worker.model_runner.token_to_kv_pool ), req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, - metadata_buffers=metadata_buffers, - aux_dtype=aux_dtype, + metadata_buffers=self.disagg_metadata_buffers, scheduler=self, transfer_queue=self.disagg_decode_transfer_queue, tree_cache=self.tree_cache, @@ -618,14 +611,7 @@ class Scheduler( req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator( buffer_size ) - aux_dtype = torch.int32 - # A list of metadata buffers. The shape is (b, metadata_size) where - # b corresponds to a max running requests. The last shape * dtype.itemsize - # should be larger than 64 bytes to work with RDMA, so we pad it. - output_id_buffer = torch.zeros( - (buffer_size, 16), dtype=aux_dtype, device="cpu" - ) - metadata_buffers = [output_id_buffer] + self.disagg_metadata_buffers = MetadataBuffers(buffer_size) self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue( token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), @@ -635,8 +621,7 @@ class Scheduler( else self.draft_worker.model_runner.token_to_kv_pool ), req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, - metadata_buffers=metadata_buffers, - aux_dtype=aux_dtype, + metadata_buffers=self.disagg_metadata_buffers, tp_rank=self.tp_rank, tp_size=self.tp_size, bootstrap_port=self.server_args.disaggregation_bootstrap_port, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 22a6a47e7..fe213b9ee 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -485,7 +485,6 @@ def popen_launch_pd_server( api_key: Optional[str] = None, other_args: list[str] = (), env: Optional[dict] = None, - return_stdout_stderr: Optional[tuple] = None, ): _, host, port = base_url.split(":") host = host[2:] @@ -515,42 +514,9 @@ def popen_launch_pd_server( print(f"command={' '.join(command)}") - if return_stdout_stderr: - process = subprocess.Popen( - command, - stdout=return_stdout_stderr[0], - stderr=return_stdout_stderr[1], - env=env, - text=True, - ) - else: - process = subprocess.Popen(command, stdout=None, stderr=None, env=env) + process = subprocess.Popen(command, stdout=None, stderr=None, env=env) - start_time = time.perf_counter() - with requests.Session() as session: - while time.perf_counter() - start_time < timeout: - try: - headers = { - "Content-Type": "application/json; charset=utf-8", - "Authorization": f"Bearer {api_key}", - } - response = session.get( - f"{base_url}/health", - headers=headers, - ) - if response.status_code == 200: - return process - except requests.RequestException: - pass - - return_code = process.poll() - if return_code is not None: - raise Exception(f"Server unexpectedly exits ({return_code=}).") - - time.sleep(10) - - kill_process_tree(process.pid) - raise TimeoutError("Server failed to start within the timeout period.") + return process def run_with_timeout( diff --git a/scripts/playground/disaggregation/cli-logprob.py b/scripts/playground/disaggregation/cli-logprob.py new file mode 100644 index 000000000..2dcfd3d4e --- /dev/null +++ b/scripts/playground/disaggregation/cli-logprob.py @@ -0,0 +1,22 @@ +prompt = "The capital of taiwan is " + +import json + +import requests + +response = requests.post( + "http://0.0.0.0:8000/generate", + json={ + "text": prompt, + "sampling_params": {"temperature": 0}, + "return_logprob": True, + "return_input_logprob": True, + "logprob_start_len": 0, + }, +) + +j = response.json() +input_logprobs = j["meta_info"]["input_token_logprobs"] +output_logprobs = j["meta_info"]["output_token_logprobs"] + +print(len(input_logprobs), len(output_logprobs)) diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index 3a9996a78..fda00e249 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -1,7 +1,9 @@ +import os import subprocess import time import unittest from types import SimpleNamespace +from urllib.parse import urlparse import requests @@ -25,15 +27,22 @@ class TestDisaggregationAccuracy(CustomTestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_host = "127.0.0.1" - cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1]) - cls.lb_url = DEFAULT_URL_FOR_TEST - cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}" - cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}" + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") - run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) - run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + # Block until both cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") @@ -48,7 +57,7 @@ class TestDisaggregationAccuracy(CustomTestCase): "--host", cls.base_host, "--port", - str(cls.base_port), + cls.lb_port, ] print("Starting load balancer:", " ".join(lb_command)) @@ -63,14 +72,10 @@ class TestDisaggregationAccuracy(CustomTestCase): "--trust-remote-code", "--disaggregation-mode", "prefill", - "--host", - cls.base_host, - "--port", - str(cls.base_port + 100), "--tp", - "4", - # "--disaggregation-ib-device", - # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3", + "1", + "--disaggregation-ib-device", + "mlx5_roce0", ] cls.process_prefill = popen_launch_pd_server( cls.model, @@ -85,16 +90,12 @@ class TestDisaggregationAccuracy(CustomTestCase): "--trust-remote-code", "--disaggregation-mode", "decode", - "--host", - cls.base_host, - "--port", - str(cls.base_port + 200), "--tp", - "4", + "1", "--base-gpu-id", - "4", - # "--disaggregation-ib-device", - # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "1", + "--disaggregation-ib-device", + "mlx5_roce1", ] cls.process_decode = popen_launch_pd_server( cls.model, @@ -128,6 +129,9 @@ class TestDisaggregationAccuracy(CustomTestCase): except Exception as e: print(f"Error killing process {process.pid}: {e}") + # wait for 5 seconds + time.sleep(5) + def test_gsm8k(self): args = SimpleNamespace( num_shots=5, @@ -135,45 +139,63 @@ class TestDisaggregationAccuracy(CustomTestCase): num_questions=200, max_new_tokens=512, parallel=128, - host="http://127.0.0.1", - port=int(self.lb_url.split(":")[-1]), + host=f"http://{self.base_host}", + port=int(self.lb_port), ) metrics = run_eval_few_shot_gsm8k(args) print(f"Evaluation metrics: {metrics}") self.assertGreater(metrics["accuracy"], 0.62) + def test_logprob(self): + prompt = "The capital of taiwan is " + response = requests.post( + self.lb_url + "/generate", + json={ + "text": prompt, + "sampling_params": {"temperature": 0}, + "return_logprob": True, + "return_input_logprob": True, + "logprob_start_len": 0, + }, + ) -class TestDisaggregationSpecAccuracy(CustomTestCase): + j = response.json() + completion_tokens = j["meta_info"]["completion_tokens"] + input_logprobs = j["meta_info"]["input_token_logprobs"] + output_logprobs = j["meta_info"]["output_token_logprobs"] + assert ( + len(output_logprobs) == completion_tokens + ), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}" + assert ( + len(input_logprobs) > 0 + ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}" + + +class TestDisaggregationMooncakeFailure(CustomTestCase): @classmethod def setUpClass(cls): - super().setUpClass() - cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST - cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST - cls.base_host = "127.0.0.1" - cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1]) - cls.lb_url = DEFAULT_URL_FOR_TEST - cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}" - cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}" - cls.spec_args = [ - "--speculative-algorithm", - "EAGLE", - "--speculative-draft-model-path", - cls.draft_model, - "--speculative-num-steps", - "3", - "--speculative-eagle-topk", - "4", - "--speculative-num-draft-tokens", - "16", - "--cuda-graph-max-bs", - "8", - ] + # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure + os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05" - run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) - run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH) + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.decode_url + "/health") @@ -188,7 +210,149 @@ class TestDisaggregationSpecAccuracy(CustomTestCase): "--host", cls.base_host, "--port", - str(cls.base_port), + cls.lb_port, + ] + + print("Starting load balancer:", " ".join(lb_command)) + cls.process_lb = subprocess.Popen( + lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + cls.wait_server_ready(cls.lb_url + "/health") + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--tp", + "1", + "--disaggregation-ib-device", + "mlx5_roce0", + ] + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--tp", + "1", + "--base-gpu-id", + "1", + "--disaggregation-ib-device", + "mlx5_roce1", + ] + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + ) + + @classmethod + def wait_server_ready(cls, url, timeout=60): + start_time = time.perf_counter() + while True: + try: + response = requests.get(url) + if response.status_code == 200: + print(f"Server {url} is ready") + return + except Exception: + pass + + if time.perf_counter() - start_time > timeout: + raise RuntimeError(f"Server {url} failed to start in {timeout}s") + time.sleep(1) + + @classmethod + def tearDownClass(cls): + # unset DISAGGREGATION_TEST_FAILURE_PROB + os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB") + for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process {process.pid}: {e}") + + # wait for 5 seconds + time.sleep(5) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{self.base_host}", + port=int(self.lb_port), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"Evaluation metrics: {metrics}") + # Expect lots of failure but the server cannot crash + + +class TestDisaggregationMooncakeSpec(CustomTestCase): + + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST + parsed_url = urlparse(DEFAULT_URL_FOR_TEST) + cls.base_host = parsed_url.hostname + base_port = str(parsed_url.port) + cls.lb_port = base_port + cls.prefill_port = f"{int(base_port) + 100}" + cls.decode_port = f"{int(base_port) + 200}" + cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}" + cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}" + cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}" + cls.spec_args = [ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + cls.draft_model, + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "16", + "--cuda-graph-max-bs", + "8", + ] + print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}") + + # Non blocking start servers + cls.start_prefill() + cls.start_decode() + + # Block until both + cls.wait_server_ready(cls.prefill_url + "/health") + cls.wait_server_ready(cls.decode_url + "/health") + + lb_command = [ + "python3", + "-m", + "sglang.srt.disaggregation.mini_lb", + "--prefill", + cls.prefill_url, + "--decode", + cls.decode_url, + "--host", + cls.base_host, + "--port", + cls.lb_port, ] print("Starting load balancer:", " ".join(lb_command)) @@ -215,21 +379,15 @@ class TestDisaggregationSpecAccuracy(CustomTestCase): @classmethod def start_prefill(cls): - prefill_args = [ "--trust-remote-code", "--disaggregation-mode", "prefill", - "--host", - cls.base_host, - "--port", - str(cls.base_port + 100), "--tp", - "4", - # "--disaggregation-ib-device", - # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3", + "2", + "--disaggregation-ib-device", + "mlx5_roce0,mlx5_roce1", ] + cls.spec_args - cls.process_prefill = popen_launch_pd_server( cls.model, cls.prefill_url, @@ -243,16 +401,12 @@ class TestDisaggregationSpecAccuracy(CustomTestCase): "--trust-remote-code", "--disaggregation-mode", "decode", - "--host", - cls.base_host, - "--port", - str(cls.base_port + 200), "--tp", - "4", + "2", "--base-gpu-id", - "4", - # "--disaggregation-ib-device", - # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7", + "2", + "--disaggregation-ib-device", + "mlx5_roce2,mlx5_roce3", ] + cls.spec_args cls.process_decode = popen_launch_pd_server( cls.model, @@ -261,15 +415,43 @@ class TestDisaggregationSpecAccuracy(CustomTestCase): other_args=decode_args, ) + @classmethod + def wait_server_ready(cls, url, timeout=60): + start_time = time.perf_counter() + while True: + try: + response = requests.get(url) + if response.status_code == 200: + print(f"Server {url} is ready") + return + except Exception: + pass + + if time.perf_counter() - start_time > timeout: + raise RuntimeError(f"Server {url} failed to start in {timeout}s") + time.sleep(1) + + @classmethod + def tearDownClass(cls): + for process in [cls.process_lb, cls.process_decode, cls.process_prefill]: + if process: + try: + kill_process_tree(process.pid) + except Exception as e: + print(f"Error killing process {process.pid}: {e}") + + # wait for 5 seconds + time.sleep(5) + def test_gsm8k(self): args = SimpleNamespace( num_shots=5, data_path=None, num_questions=200, max_new_tokens=512, - parallel=4, # TODO: 128 crashes the decode - host="http://127.0.0.1", - port=int(self.lb_url.split(":")[-1]), + parallel=2, + host=f"http://{self.base_host}", + port=int(self.lb_port), ) metrics = run_eval_few_shot_gsm8k(args) print(f"Evaluation metrics: {metrics}")