Support updating weights at once by stopping all requests (#6698)

Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
This commit is contained in:
Albert
2025-07-03 13:26:06 +08:00
committed by GitHub
parent b044400dd3
commit d3c275b117
7 changed files with 190 additions and 13 deletions

View File

@@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
async def abort_request(obj: AbortReq, request: Request):
"""Abort a request."""
try:
_global_state.tokenizer_manager.abort_request(rid=obj.rid)
_global_state.tokenizer_manager.abort_request(
rid=obj.rid, abort_all=obj.abort_all
)
return Response(status_code=200)
except Exception as e:
return _create_error_response(e)

View File

@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
matched_stop: Union[None, int, str] = None
hidden_states: Optional[object] = None
@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
delta: DeltaMessage
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
finish_reason: Optional[
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
]
] = None
matched_stop: Union[None, int, str] = None

View File

@@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput:
model_path: str
# The format to load the weights
load_format: Optional[str] = None
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
@dataclass
@@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput:
group_name: str = "weight_update_group"
# Whether to flush the cache after updating weights
flush_cache: bool = True
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
@dataclass
@@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput:
load_format: Optional[str] = None
# Whether to flush the cache after updating weights
flush_cache: bool = True
# Whether to abort all requests before updating weights
abort_all_requests: bool = False
@dataclass
@@ -858,7 +864,9 @@ class SlowDownReqOutput:
@dataclass
class AbortReq:
# The request id
rid: str
rid: str = ""
# Whether to abort all requests
abort_all: bool = False
@dataclass

View File

@@ -2211,7 +2211,7 @@ class Scheduler(
# Delete requests in the waiting queue
to_del = []
for i, req in enumerate(self.waiting_queue):
if req.rid.startswith(recv_req.rid):
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
to_del.append(i)
# Sort in reverse order to avoid index issues when deleting
@@ -2228,7 +2228,7 @@ class Scheduler(
# Abort method 2: call `set_finish_with_abort`
# The request will still run one prefill forward pass.
# In this case, we change the input_ids to be only one token to make this prefill cheap.
if req.rid.startswith(recv_req.rid):
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
logger.debug(f"Abort grammar queue request. {req.rid=}")
if req.grammar:
req.grammar.cancel()
@@ -2241,7 +2241,9 @@ class Scheduler(
reqs = self.running_batch.reqs + self.cur_batch.reqs
for req in reqs:
if req.rid.startswith(recv_req.rid) and not req.finished():
if not req.finished() and (
recv_req.abort_all or req.rid.startswith(recv_req.rid)
):
# Abort method 3: set `to_abort=True`
# The request will still run one decode forward pass.
# Then we reuse all existing code to clean up the KV cache allocation.

View File

@@ -846,10 +846,10 @@ class TokenizerManager:
async def flush_cache(self) -> FlushCacheReqOutput:
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
def abort_request(self, rid: str):
if rid not in self.rid_to_state:
def abort_request(self, rid: str = "", abort_all: bool = False):
if not abort_all and rid not in self.rid_to_state:
return
req = AbortReq(rid)
req = AbortReq(rid, abort_all)
self.send_to_scheduler.send_pyobj(req)
if self.enable_metrics:
@@ -914,6 +914,9 @@ class TokenizerManager:
obj.load_format = self.server_args.load_format
logger.info("Start update_weights. Load format=%s", obj.load_format)
if obj.abort_all_requests:
self.abort_request(abort_all=True)
if True: # Keep this redundant check to simplify some internal code sync
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
@@ -969,6 +972,9 @@ class TokenizerManager:
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
@@ -985,6 +991,9 @@ class TokenizerManager:
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
if obj.abort_all_requests:
self.abort_request(abort_all=True)
# This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
@@ -1619,7 +1628,23 @@ class TokenizerManager:
self.crash_dump_request_list.popleft()
def _handle_abort_req(self, recv_obj):
self.rid_to_state.pop(recv_obj.rid, None)
state = self.rid_to_state[recv_obj.rid]
state.finished = True
state.out_list.append(
{
"text": "",
"meta_info": {
"id": recv_obj.rid,
"finish_reason": {
"type": "abort",
"message": "Abort before prefill",
},
"prompt_tokens": 0,
"completion_tokens": 0,
},
}
)
state.event.set()
def _handle_open_session_req_output(self, recv_obj):
self.session_futures[recv_obj.session_id].set_result(