diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py
index c5baa6988..bb0b47471 100644
--- a/python/sglang/srt/disaggregation/mooncake/conn.py
+++ b/python/sglang/srt/disaggregation/mooncake/conn.py
@@ -992,6 +992,15 @@ class MooncakeKVSender(BaseKVSender):
             )
         raise KVTransferError(self.bootstrap_room, failure_reason)

+    def abort(self):
+        """Record the abort as a failure for this room and mark the sender failed."""
+        self.kv_mgr.record_failure(
+            self.bootstrap_room,
+            "Aborted by AbortReq.",
+        )
+        # Explicitly set the status to failure since this request has been aborted
+        self.conclude_state = KVPoll.Failed
+

 class MooncakeKVReceiver(BaseKVReceiver):
     _ctx = zmq.Context()
@@ -1305,6 +1314,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
             )
         raise KVTransferError(self.bootstrap_room, failure_reason)

+    def abort(self):
+        """Record the abort as a failure for this room and mark the receiver failed."""
+        self.kv_mgr.record_failure(
+            self.bootstrap_room,
+            "Aborted by AbortReq.",
+        )
+        # Explicitly set the status to failure since this request has been aborted
+        self.conclude_state = KVPoll.Failed
+

 class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     def __init__(self, port: int):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 0be67eaca..ecfce1392 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -2440,6 +2440,37 @@ class Scheduler(
                     req.grammar.cancel()
                 req.set_finish_with_abort("Aborted by AbortReq.")

+        # Delete requests not in the waiting queue when PD disaggregation is enabled
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Abort requests that have not yet been bootstrapped
+            for req in self.disagg_prefill_bootstrap_queue.queue:
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+            # Abort in-flight requests
+            for req in self.disagg_prefill_inflight_queue:
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort inflight queue request. {req.rid=}")
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Abort requests that have not yet finished preallocation
+            for decode_req in self.disagg_decode_prealloc_queue.queue:
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
+            # Abort requests waiting for kvcache to release tree cache
+            for decode_req in self.disagg_decode_transfer_queue.queue:
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
         # Delete requests in the running batch
         if self.cur_batch is self.running_batch or self.cur_batch is None:
             reqs = self.running_batch.reqs