Fix a bug in abort & Improve docstrings for abort (#6931)
This commit is contained in:
@@ -2041,10 +2041,23 @@ class Scheduler(
|
|||||||
|
|
||||||
# Sort in reverse order to avoid index issues when deleting
|
# Sort in reverse order to avoid index issues when deleting
|
||||||
for i in reversed(to_del):
|
for i in reversed(to_del):
|
||||||
|
# Abort method 1: directly pop from the queue
|
||||||
|
# This only works for requests that have not started anything.
|
||||||
|
# We still need to send something back to TokenizerManager to clean up the state.
|
||||||
req = self.waiting_queue.pop(i)
|
req = self.waiting_queue.pop(i)
|
||||||
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
|
||||||
logger.debug(f"Abort queued request. {req.rid=}")
|
logger.debug(f"Abort queued request. {req.rid=}")
|
||||||
|
|
||||||
|
# Delete the requests in the grammar queue
|
||||||
|
for req in self.grammar_queue:
|
||||||
|
# Abort method 2: call `set_finish_with_abort`
|
||||||
|
# The request will still run one prefill forward pass.
|
||||||
|
# In this case, we change the input_ids to be only one token to make this prefill cheap.
|
||||||
|
if req.rid.startswith(recv_req.rid):
|
||||||
|
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
||||||
|
req.grammar.cancel()
|
||||||
|
req.set_finish_with_abort("Aborted by AbortReq.")
|
||||||
|
|
||||||
# Delete requests in the running batch
|
# Delete requests in the running batch
|
||||||
if self.cur_batch is self.running_batch or self.cur_batch is None:
|
if self.cur_batch is self.running_batch or self.cur_batch is None:
|
||||||
reqs = self.running_batch.reqs
|
reqs = self.running_batch.reqs
|
||||||
@@ -2053,17 +2066,12 @@ class Scheduler(
|
|||||||
|
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
if req.rid.startswith(recv_req.rid) and not req.finished():
|
||||||
|
# Abort method 3: set `to_abort=True`
|
||||||
|
# The request will still run one decode forward pass.
|
||||||
|
# Then we reuse all existing code to clean up the KV cache allocation.
|
||||||
logger.debug(f"Abort running request. {req.rid=}")
|
logger.debug(f"Abort running request. {req.rid=}")
|
||||||
# We must use to_abort because it is in a running batch
|
|
||||||
req.to_abort = True
|
req.to_abort = True
|
||||||
|
|
||||||
# Delete the requests in the grammar queue
|
|
||||||
for req in self.grammar_queue:
|
|
||||||
if req.rid.startswith(recv_req.rid):
|
|
||||||
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
|
||||||
req.grammar.cancel()
|
|
||||||
req.set_finish_with_abort("Aborted by AbortReq.")
|
|
||||||
|
|
||||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|||||||
@@ -1419,7 +1419,7 @@ class TokenizerManager:
|
|||||||
asyncio.create_task(asyncio.to_thread(background_task))
|
asyncio.create_task(asyncio.to_thread(background_task))
|
||||||
|
|
||||||
def _handle_abort_req(self, recv_obj):
|
def _handle_abort_req(self, recv_obj):
|
||||||
self.rid_to_state.pop(recv_obj.rid)
|
self.rid_to_state.pop(recv_obj.rid, None)
|
||||||
|
|
||||||
def _handle_open_session_req_output(self, recv_obj):
|
def _handle_open_session_req_output(self, recv_obj):
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
|
|||||||
Reference in New Issue
Block a user