[PD] Abort request if transfer fails (#6504)
This commit is contained in:
@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
is_mla_backend,
|
is_mla_backend,
|
||||||
kv_to_page_indices,
|
kv_to_page_indices,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
@@ -178,7 +179,17 @@ class DecodePreallocQueue:
|
|||||||
elif poll == KVPoll.WaitingForInput:
|
elif poll == KVPoll.WaitingForInput:
|
||||||
decode_req.waiting_for_input = True
|
decode_req.waiting_for_input = True
|
||||||
elif poll == KVPoll.Failed:
|
elif poll == KVPoll.Failed:
|
||||||
raise Exception("Handshake failed")
|
error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||||
|
try:
|
||||||
|
decode_req.kv_receiver.failure_exception()
|
||||||
|
except Exception as e:
|
||||||
|
error_message += f" with exception {e}"
|
||||||
|
logger.error(error_message)
|
||||||
|
prepare_abort(
|
||||||
|
decode_req.req,
|
||||||
|
error_message,
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
|
||||||
def pop_preallocated(self) -> List[DecodeRequest]:
|
def pop_preallocated(self) -> List[DecodeRequest]:
|
||||||
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
||||||
@@ -333,7 +344,24 @@ class DecodeTransferQueue:
|
|||||||
indices_to_remove = set()
|
indices_to_remove = set()
|
||||||
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
if poll == KVPoll.Failed:
|
if poll == KVPoll.Failed:
|
||||||
raise Exception("Transfer failed")
|
error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
||||||
|
try:
|
||||||
|
decode_req.kv_receiver.failure_exception()
|
||||||
|
except Exception as e:
|
||||||
|
error_message += f" with exception {e}"
|
||||||
|
logger.error(error_message)
|
||||||
|
prepare_abort(
|
||||||
|
decode_req.req,
|
||||||
|
error_message,
|
||||||
|
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||||
|
)
|
||||||
|
self.scheduler.stream_output(
|
||||||
|
[decode_req.req], decode_req.req.return_logprob
|
||||||
|
)
|
||||||
|
# unlock the kv cache or it will have memory leak
|
||||||
|
self.tree_cache.cache_finished_req(decode_req.req)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
continue
|
||||||
elif poll == KVPoll.Success:
|
elif poll == KVPoll.Success:
|
||||||
# pop and push it to waiting queue
|
# pop and push it to waiting queue
|
||||||
idx = decode_req.metadata_buffer_index
|
idx = decode_req.metadata_buffer_index
|
||||||
|
|||||||
@@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender):
|
|||||||
return self.kv_mgr.check_status(self.bootstrap_room)
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
|
# TODO: raise a real exception
|
||||||
raise Exception("Fake KVSender Exception")
|
raise Exception("Fake KVSender Exception")
|
||||||
|
|
||||||
|
|
||||||
@@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
return self.kv_mgr.check_status(self.bootstrap_room)
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
||||||
|
|
||||||
def failure_exception(self):
|
def failure_exception(self):
|
||||||
|
# TODO: raise a real exception
|
||||||
raise Exception("Fake KVReceiver Exception")
|
raise Exception("Fake KVReceiver Exception")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
kv_to_page_indices,
|
kv_to_page_indices,
|
||||||
kv_to_page_num,
|
kv_to_page_num,
|
||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||||
|
|
||||||
@@ -157,7 +158,18 @@ class PrefillBootstrapQueue:
|
|||||||
if poll == KVPoll.Bootstrapping:
|
if poll == KVPoll.Bootstrapping:
|
||||||
continue
|
continue
|
||||||
elif poll == KVPoll.Failed:
|
elif poll == KVPoll.Failed:
|
||||||
raise Exception("Bootstrap failed")
|
error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
|
||||||
|
try:
|
||||||
|
req.disagg_kv_sender.failure_exception()
|
||||||
|
except Exception as e:
|
||||||
|
error_message += f" with exception {e}"
|
||||||
|
logger.error(error_message)
|
||||||
|
prepare_abort(
|
||||||
|
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
)
|
||||||
|
self.scheduler.stream_output([req], req.return_logprob)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
continue
|
||||||
|
|
||||||
# KV.WaitingForInput
|
# KV.WaitingForInput
|
||||||
num_kv_indices = len(req.origin_input_ids)
|
num_kv_indices = len(req.origin_input_ids)
|
||||||
@@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
# FIXME: clean up req's data in transfer engine
|
# FIXME: clean up req's data in transfer engine
|
||||||
done_reqs.append(req)
|
done_reqs.append(req)
|
||||||
elif poll == KVPoll.Failed:
|
elif poll == KVPoll.Failed:
|
||||||
raise Exception("Transferring failed")
|
error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
|
||||||
|
try:
|
||||||
|
req.disagg_kv_sender.failure_exception()
|
||||||
|
except Exception as e:
|
||||||
|
error_message += f" with exception {e}"
|
||||||
|
logger.warning(error_message)
|
||||||
|
self.tree_cache.cache_finished_req(req) # unlock the tree
|
||||||
|
prepare_abort(
|
||||||
|
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
)
|
||||||
|
done_reqs.append(req)
|
||||||
|
|
||||||
for req in done_reqs:
|
for req in done_reqs:
|
||||||
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
||||||
|
|||||||
@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool:
|
|||||||
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
|
||||||
|
|
||||||
return isinstance(target_kv_pool, MLATokenToKVPool)
|
return isinstance(target_kv_pool, MLATokenToKVPool)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_abort(req: Req, error_message: str, status_code=None):
|
||||||
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||||
|
|
||||||
|
# populate finish metadata and stream output
|
||||||
|
req.finished_reason = FINISH_ABORT(error_message, status_code)
|
||||||
|
|
||||||
|
if req.return_logprob:
|
||||||
|
req.input_token_logprobs_val = []
|
||||||
|
req.input_token_logprobs_idx = []
|
||||||
|
req.input_top_logprobs_val = []
|
||||||
|
req.input_top_logprobs_idx = []
|
||||||
|
req.input_token_ids_logprobs_val = []
|
||||||
|
req.input_token_ids_logprobs_idx = []
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
DisaggregationMode,
|
DisaggregationMode,
|
||||||
ReqToMetadataIdxAllocator,
|
ReqToMetadataIdxAllocator,
|
||||||
TransferBackend,
|
TransferBackend,
|
||||||
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.distributed import get_pp_group, get_world_group
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||||
from sglang.srt.hf_transformers_utils import (
|
from sglang.srt.hf_transformers_utils import (
|
||||||
@@ -935,6 +936,18 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
if self.disaggregation_mode != DisaggregationMode.NULL:
|
||||||
|
# Invalid request for disaggregated mode
|
||||||
|
if recv_req.bootstrap_room is None:
|
||||||
|
error_message = (
|
||||||
|
f"Invalid request: Disaggregated request received without "
|
||||||
|
f"boostrap room id. {req.rid=}"
|
||||||
|
)
|
||||||
|
logger.error(error_message)
|
||||||
|
prepare_abort(req, error_message)
|
||||||
|
self.stream_output([req], req.return_logprob)
|
||||||
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
recv_req.session_params is not None
|
recv_req.session_params is not None
|
||||||
and recv_req.session_params.id is not None
|
and recv_req.session_params.id is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user