diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 0aef85ba5..84c3a79d8 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import ( is_mla_backend, kv_to_page_indices, poll_and_all_reduce, + prepare_abort, ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator @@ -178,7 +179,17 @@ class DecodePreallocQueue: elif poll == KVPoll.WaitingForInput: decode_req.waiting_for_input = True elif poll == KVPoll.Failed: - raise Exception("Handshake failed") + error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" + try: + decode_req.kv_receiver.failure_exception() + except Exception as e: + error_message += f" with exception {e}" + logger.error(error_message) + prepare_abort( + decode_req.req, + error_message, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) def pop_preallocated(self) -> List[DecodeRequest]: """Pop the preallocated requests from the pending queue (FIFO).""" @@ -333,7 +344,24 @@ class DecodeTransferQueue: indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): if poll == KVPoll.Failed: - raise Exception("Transfer failed") + error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" + try: + decode_req.kv_receiver.failure_exception() + except Exception as e: + error_message += f" with exception {e}" + logger.error(error_message) + prepare_abort( + decode_req.req, + error_message, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + self.scheduler.stream_output( + [decode_req.req], decode_req.req.return_logprob + ) + # unlock the kv cache or it will have memory leak + self.tree_cache.cache_finished_req(decode_req.req) + indices_to_remove.add(i) + continue elif poll == KVPoll.Success: # pop and push it to waiting queue idx = decode_req.metadata_buffer_index diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 4b843e02e..9e894d1e3 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender): return self.kv_mgr.check_status(self.bootstrap_room) def failure_exception(self): + # TODO: raise a real exception raise Exception("Fake KVSender Exception") @@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver): return self.kv_mgr.check_status(self.bootstrap_room) def failure_exception(self): + # TODO: raise a real exception raise Exception("Fake KVReceiver Exception") diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 83fe5a838..2572210d5 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import ( kv_to_page_indices, kv_to_page_num, poll_and_all_reduce, + prepare_abort, ) from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch @@ -157,7 +158,18 @@ class PrefillBootstrapQueue: if poll == KVPoll.Bootstrapping: continue elif poll == KVPoll.Failed: - raise Exception("Bootstrap failed") + error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" + try: + req.disagg_kv_sender.failure_exception() + except Exception as e: + error_message += f" with exception {e}" + logger.error(error_message) + prepare_abort( + req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR + ) + self.scheduler.stream_output([req], req.return_logprob) + indices_to_remove.add(i) + continue # KV.WaitingForInput num_kv_indices = len(req.origin_input_ids) @@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin: # FIXME: clean up req's data in transfer engine done_reqs.append(req) elif poll == KVPoll.Failed: - raise Exception("Transferring failed") + error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" + try: + req.disagg_kv_sender.failure_exception() + except Exception as e: + error_message += f" with exception {e}" + logger.warning(error_message) + self.tree_cache.cache_finished_req(req) # unlock the tree + prepare_abort( + req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR + ) + done_reqs.append(req) for req in done_reqs: self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free( diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 2d795c0f0..69348d4ff 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool: from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool return isinstance(target_kv_pool, MLATokenToKVPool) + + +def prepare_abort(req: Req, error_message: str, status_code=None): + from sglang.srt.managers.schedule_batch import FINISH_ABORT + + # populate finish metadata and stream output + req.finished_reason = FINISH_ABORT(error_message, status_code) + + if req.return_logprob: + req.input_token_logprobs_val = [] + req.input_token_logprobs_idx = [] + req.input_top_logprobs_val = [] + req.input_top_logprobs_idx = [] + req.input_token_ids_logprobs_val = [] + req.input_token_ids_logprobs_idx = [] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 72a3f7246..62af0c6b3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import ( DisaggregationMode, ReqToMetadataIdxAllocator, TransferBackend, + prepare_abort, ) from sglang.srt.distributed import get_pp_group, get_world_group from sglang.srt.hf_transformers_utils import ( @@ -935,6 +936,18 @@ class Scheduler( ) req.tokenizer = self.tokenizer + if self.disaggregation_mode != DisaggregationMode.NULL: + # Invalid request for disaggregated mode + if recv_req.bootstrap_room is None: + error_message = ( + f"Invalid request: Disaggregated request received without " + f"boostrap room id. {req.rid=}" + ) + logger.error(error_message) + prepare_abort(req, error_message) + self.stream_output([req], req.return_logprob) + return + if ( recv_req.session_params is not None and recv_req.session_params.id is not None