Support updating weights at once by stopping all requests (#6698)
Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com> Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
This commit is contained in:
@@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||
async def abort_request(obj: AbortReq, request: Request):
|
||||
"""Abort a request."""
|
||||
try:
|
||||
_global_state.tokenizer_manager.abort_request(rid=obj.rid)
|
||||
_global_state.tokenizer_manager.abort_request(
|
||||
rid=obj.rid, abort_all=obj.abort_all
|
||||
)
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
|
||||
matched_stop: Union[None, int, str] = None
|
||||
hidden_states: Optional[object] = None
|
||||
|
||||
@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||
finish_reason: Optional[
|
||||
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
|
||||
Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
||||
]
|
||||
] = None
|
||||
matched_stop: Union[None, int, str] = None
|
||||
|
||||
|
||||
@@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput:
|
||||
model_path: str
|
||||
# The format to load the weights
|
||||
load_format: Optional[str] = None
|
||||
# Whether to abort all requests before updating weights
|
||||
abort_all_requests: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput:
|
||||
group_name: str = "weight_update_group"
|
||||
# Whether to flush the cache after updating weights
|
||||
flush_cache: bool = True
|
||||
# Whether to abort all requests before updating weights
|
||||
abort_all_requests: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput:
|
||||
load_format: Optional[str] = None
|
||||
# Whether to flush the cache after updating weights
|
||||
flush_cache: bool = True
|
||||
# Whether to abort all requests before updating weights
|
||||
abort_all_requests: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -858,7 +864,9 @@ class SlowDownReqOutput:
|
||||
@dataclass
|
||||
class AbortReq:
|
||||
# The request id
|
||||
rid: str
|
||||
rid: str = ""
|
||||
# Whether to abort all requests
|
||||
abort_all: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2211,7 +2211,7 @@ class Scheduler(
|
||||
# Delete requests in the waiting queue
|
||||
to_del = []
|
||||
for i, req in enumerate(self.waiting_queue):
|
||||
if req.rid.startswith(recv_req.rid):
|
||||
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
|
||||
to_del.append(i)
|
||||
|
||||
# Sort in reverse order to avoid index issues when deleting
|
||||
@@ -2228,7 +2228,7 @@ class Scheduler(
|
||||
# Abort method 2: call `set_finish_with_abort`
|
||||
# The request will still run one prefill forward pass.
|
||||
# In this case, we change the input_ids to be only one token to make this prefill cheap.
|
||||
if req.rid.startswith(recv_req.rid):
|
||||
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
|
||||
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
||||
if req.grammar:
|
||||
req.grammar.cancel()
|
||||
@@ -2241,7 +2241,9 @@ class Scheduler(
|
||||
reqs = self.running_batch.reqs + self.cur_batch.reqs
|
||||
|
||||
for req in reqs:
|
||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
||||
if not req.finished() and (
|
||||
recv_req.abort_all or req.rid.startswith(recv_req.rid)
|
||||
):
|
||||
# Abort method 3: set `to_abort=True`
|
||||
# The request will still run one decode forward pass.
|
||||
# Then we reuse all existing code to clean up the KV cache allocation.
|
||||
|
||||
@@ -846,10 +846,10 @@ class TokenizerManager:
|
||||
async def flush_cache(self) -> FlushCacheReqOutput:
|
||||
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
|
||||
|
||||
def abort_request(self, rid: str):
|
||||
if rid not in self.rid_to_state:
|
||||
def abort_request(self, rid: str = "", abort_all: bool = False):
|
||||
if not abort_all and rid not in self.rid_to_state:
|
||||
return
|
||||
req = AbortReq(rid)
|
||||
req = AbortReq(rid, abort_all)
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
if self.enable_metrics:
|
||||
@@ -914,6 +914,9 @@ class TokenizerManager:
|
||||
obj.load_format = self.server_args.load_format
|
||||
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
||||
|
||||
if obj.abort_all_requests:
|
||||
self.abort_request(abort_all=True)
|
||||
|
||||
if True: # Keep this redundant check to simplify some internal code sync
|
||||
# Hold the lock if it is not async. This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
@@ -969,6 +972,9 @@ class TokenizerManager:
|
||||
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
||||
|
||||
if obj.abort_all_requests:
|
||||
self.abort_request(abort_all=True)
|
||||
|
||||
# This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
@@ -985,6 +991,9 @@ class TokenizerManager:
|
||||
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
|
||||
|
||||
if obj.abort_all_requests:
|
||||
self.abort_request(abort_all=True)
|
||||
|
||||
# This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
@@ -1619,7 +1628,23 @@ class TokenizerManager:
|
||||
self.crash_dump_request_list.popleft()
|
||||
|
||||
def _handle_abort_req(self, recv_obj):
|
||||
self.rid_to_state.pop(recv_obj.rid, None)
|
||||
state = self.rid_to_state[recv_obj.rid]
|
||||
state.finished = True
|
||||
state.out_list.append(
|
||||
{
|
||||
"text": "",
|
||||
"meta_info": {
|
||||
"id": recv_obj.rid,
|
||||
"finish_reason": {
|
||||
"type": "abort",
|
||||
"message": "Abort before prefill",
|
||||
},
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
)
|
||||
state.event.set()
|
||||
|
||||
def _handle_open_session_req_output(self, recv_obj):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
|
||||
Reference in New Issue
Block a user