[PD] Abort request if transfer fails (#6504)
This commit is contained in:
@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
is_mla_backend,
|
||||
kv_to_page_indices,
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
@@ -178,7 +179,17 @@ class DecodePreallocQueue:
|
||||
elif poll == KVPoll.WaitingForInput:
|
||||
decode_req.waiting_for_input = True
|
||||
elif poll == KVPoll.Failed:
|
||||
raise Exception("Handshake failed")
|
||||
error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||
try:
|
||||
decode_req.kv_receiver.failure_exception()
|
||||
except Exception as e:
|
||||
error_message += f" with exception {e}"
|
||||
logger.error(error_message)
|
||||
prepare_abort(
|
||||
decode_req.req,
|
||||
error_message,
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
def pop_preallocated(self) -> List[DecodeRequest]:
|
||||
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
||||
@@ -333,7 +344,24 @@ class DecodeTransferQueue:
|
||||
indices_to_remove = set()
|
||||
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||
if poll == KVPoll.Failed:
|
||||
raise Exception("Transfer failed")
|
||||
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||
try:
|
||||
decode_req.kv_receiver.failure_exception()
|
||||
except Exception as e:
|
||||
error_message += f" with exception {e}"
|
||||
logger.error(error_message)
|
||||
prepare_abort(
|
||||
decode_req.req,
|
||||
error_message,
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
self.scheduler.stream_output(
|
||||
[decode_req.req], decode_req.req.return_logprob
|
||||
)
|
||||
# unlock the kv cache or it will have memory leak
|
||||
self.tree_cache.cache_finished_req(decode_req.req)
|
||||
indices_to_remove.add(i)
|
||||
continue
|
||||
elif poll == KVPoll.Success:
|
||||
# pop and push it to waiting queue
|
||||
idx = decode_req.metadata_buffer_index
|
||||
|
||||
@@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender):
|
||||
return self.kv_mgr.check_status(self.bootstrap_room)
|
||||
|
||||
def failure_exception(self):
|
||||
# TODO: raise a real exception
|
||||
raise Exception("Fake KVSender Exception")
|
||||
|
||||
|
||||
@@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
return self.kv_mgr.check_status(self.bootstrap_room)
|
||||
|
||||
def failure_exception(self):
|
||||
# TODO: raise a real exception
|
||||
raise Exception("Fake KVReceiver Exception")
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
kv_to_page_indices,
|
||||
kv_to_page_num,
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
|
||||
@@ -157,7 +158,18 @@ class PrefillBootstrapQueue:
|
||||
if poll == KVPoll.Bootstrapping:
|
||||
continue
|
||||
elif poll == KVPoll.Failed:
|
||||
raise Exception("Bootstrap failed")
|
||||
error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
|
||||
try:
|
||||
req.disagg_kv_sender.failure_exception()
|
||||
except Exception as e:
|
||||
error_message += f" with exception {e}"
|
||||
logger.error(error_message)
|
||||
prepare_abort(
|
||||
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
self.scheduler.stream_output([req], req.return_logprob)
|
||||
indices_to_remove.add(i)
|
||||
continue
|
||||
|
||||
# KV.WaitingForInput
|
||||
num_kv_indices = len(req.origin_input_ids)
|
||||
@@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# FIXME: clean up req's data in transfer engine
|
||||
done_reqs.append(req)
|
||||
elif poll == KVPoll.Failed:
|
||||
raise Exception("Transferring failed")
|
||||
error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
|
||||
try:
|
||||
req.disagg_kv_sender.failure_exception()
|
||||
except Exception as e:
|
||||
error_message += f" with exception {e}"
|
||||
logger.warning(error_message)
|
||||
self.tree_cache.cache_finished_req(req) # unlock the tree
|
||||
prepare_abort(
|
||||
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
)
|
||||
done_reqs.append(req)
|
||||
|
||||
for req in done_reqs:
|
||||
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
||||
|
||||
@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool:
|
||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||
|
||||
return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||
|
||||
|
||||
def prepare_abort(req: Req, error_message: str, status_code=None):
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||
|
||||
# populate finish metadata and stream output
|
||||
req.finished_reason = FINISH_ABORT(error_message, status_code)
|
||||
|
||||
if req.return_logprob:
|
||||
req.input_token_logprobs_val = []
|
||||
req.input_token_logprobs_idx = []
|
||||
req.input_top_logprobs_val = []
|
||||
req.input_top_logprobs_idx = []
|
||||
req.input_token_ids_logprobs_val = []
|
||||
req.input_token_ids_logprobs_idx = []
|
||||
|
||||
@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
DisaggregationMode,
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
@@ -935,6 +936,18 @@ class Scheduler(
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
if self.disaggregation_mode != DisaggregationMode.NULL:
|
||||
# Invalid request for disaggregated mode
|
||||
if recv_req.bootstrap_room is None:
|
||||
error_message = (
|
||||
f"Invalid request: Disaggregated request received without "
|
||||
f"boostrap room id. {req.rid=}"
|
||||
)
|
||||
logger.error(error_message)
|
||||
prepare_abort(req, error_message)
|
||||
self.stream_output([req], req.return_logprob)
|
||||
return
|
||||
|
||||
if (
|
||||
recv_req.session_params is not None
|
||||
and recv_req.session_params.id is not None
|
||||
|
||||
Reference in New Issue
Block a user