Support updating weights at once by stopping all requests (#6698)
Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com>
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
@@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
 async def abort_request(obj: AbortReq, request: Request):
     """Abort a request."""
     try:
-        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        _global_state.tokenizer_manager.abort_request(
+            rid=obj.rid, abort_all=obj.abort_all
+        )
         return Response(status_code=200)
     except Exception as e:
         return _create_error_response(e)
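
For reference, a sketch of exercising the extended endpoint from a client. The endpoint path and payload fields come from this commit's hunks and tests; the base URL and request id are placeholders:

    import requests

    BASE_URL = "http://localhost:30000"  # placeholder; use your server's address

    # Abort a single request by id (existing behavior).
    requests.post(BASE_URL + "/abort_request", json={"rid": "my-request-id"})

    # Abort every in-flight request (new in this commit).
    requests.post(BASE_URL + "/abort_request", json={"abort_all": True})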
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
+    finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
     matched_stop: Union[None, int, str] = None
     hidden_states: Optional[object] = None
@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[
-        Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+        Literal[
+            "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+        ]
     ] = None
     matched_stop: Union[None, int, str] = None
 
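
With "abort" now a valid finish_reason in both streaming schemas, clients can tell aborted generations apart from normal completions. A minimal sketch, assuming `choice` is the parsed JSON of one streamed choice:

    def classify_choice(choice: dict) -> str:
        """Label a parsed streaming choice by its finish_reason."""
        reason = choice.get("finish_reason")
        if reason == "abort":
            return "aborted (e.g. by /abort_request or a weight update)"
        if reason is None:
            return "still generating"
        return f"finished: {reason}"

    print(classify_choice({"index": 0, "text": "Paris", "finish_reason": "abort"}))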
@@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput:
     model_path: str
     # The format to load the weights
     load_format: Optional[str] = None
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput:
     group_name: str = "weight_update_group"
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput:
     load_format: Optional[str] = None
     # Whether to flush the cache after updating weights
     flush_cache: bool = True
+    # Whether to abort all requests before updating weights
+    abort_all_requests: bool = False
 
 
 @dataclass
@@ -858,7 +864,9 @@ class SlowDownReqOutput:
 @dataclass
 class AbortReq:
     # The request id
-    rid: str
+    rid: str = ""
+    # Whether to abort all requests
+    abort_all: bool = False
 
 
 @dataclass
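
Note that `rid` gains a default value so the dataclass can be constructed without naming a specific request. A sketch of the two call shapes (the dataclass body is copied from the definition above; the rid is a placeholder):

    from dataclasses import dataclass

    @dataclass
    class AbortReq:
        rid: str = ""
        abort_all: bool = False

    targeted = AbortReq(rid="req-123")     # abort one request (old behavior)
    everything = AbortReq(abort_all=True)  # abort all requests (new behavior)
    print(targeted, everything)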
@@ -2211,7 +2211,7 @@ class Scheduler(
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 to_del.append(i)
 
         # Sort in reverse order to avoid index issues when deleting
@@ -2228,7 +2228,7 @@ class Scheduler(
             # Abort method 2: call `set_finish_with_abort`
             # The request will still run one prefill forward pass.
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
-            if req.rid.startswith(recv_req.rid):
+            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
                 if req.grammar:
                     req.grammar.cancel()
@@ -2241,7 +2241,9 @@ class Scheduler(
             reqs = self.running_batch.reqs + self.cur_batch.reqs
 
         for req in reqs:
-            if req.rid.startswith(recv_req.rid) and not req.finished():
+            if not req.finished() and (
+                recv_req.abort_all or req.rid.startswith(recv_req.rid)
+            ):
                 # Abort method 3: set `to_abort=True`
                 # The request will still run one decode forward pass.
                 # Then we reuse all existing code to clean up the KV cache allocation.
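
The waiting-queue branch feeds the reverse-order delete noted in the context line above; deleting from the highest index down keeps the remaining indices valid (with `abort_all` set, every index is selected). A standalone sketch of that pattern, with a plain list standing in for the Scheduler's waiting queue:

    waiting_queue = ["req-a-1", "req-b-1", "req-a-2"]

    to_del = [i for i, rid in enumerate(waiting_queue) if rid.startswith("req-a")]
    for i in sorted(to_del, reverse=True):  # back to front, so indices stay valid
        del waiting_queue[i]
    print(waiting_queue)  # ['req-b-1']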
@@ -846,10 +846,10 @@ class TokenizerManager:
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
 
-    def abort_request(self, rid: str):
-        if rid not in self.rid_to_state:
+    def abort_request(self, rid: str = "", abort_all: bool = False):
+        if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid)
+        req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
 
         if self.enable_metrics:
@@ -914,6 +914,9 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)
 
+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
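
End to end, the new flag lets a single HTTP call drain the server and swap weights. A sketch mirroring the test added at the bottom of this commit (the URL and checkpoint path are placeholders):

    import requests

    BASE_URL = "http://localhost:30000"  # placeholder

    ret = requests.post(
        BASE_URL + "/update_weights_from_disk",
        json={
            "model_path": "/path/to/new/checkpoint",  # placeholder
            "abort_all_requests": True,  # new: abort in-flight requests first
        },
    ).json()
    assert ret["success"]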
@@ -969,6 +972,9 @@ class TokenizerManager:
             self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
         ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
 
+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
@@ -985,6 +991,9 @@ class TokenizerManager:
             self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
         ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
 
+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
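
All three update paths use the same ordering: send the abort before taking `model_update_lock.writer_lock`, presumably because in-flight requests hold the reader side of that lock, so aborting lets them finish with an "abort" result and release it, at which point the writer can acquire the lock. A toy sketch of that ordering, with stand-in names rather than the real TokenizerManager API:

    import asyncio

    writer_lock = asyncio.Lock()  # stand-in for model_update_lock.writer_lock

    async def update_weights(abort_all_requests: bool) -> None:
        if abort_all_requests:
            print("abort_request(abort_all=True)")  # abort first, outside the lock
        async with writer_lock:  # weight sync runs with no requests in progress
            print("swap weights here")

    asyncio.run(update_weights(True))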
@@ -1619,7 +1628,23 @@ class TokenizerManager:
                 self.crash_dump_request_list.popleft()
 
     def _handle_abort_req(self, recv_obj):
-        self.rid_to_state.pop(recv_obj.rid, None)
+        state = self.rid_to_state[recv_obj.rid]
+        state.finished = True
+        state.out_list.append(
+            {
+                "text": "",
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": {
+                        "type": "abort",
+                        "message": "Abort before prefill",
+                    },
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                },
+            }
+        )
+        state.event.set()
 
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
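
The rewritten handler no longer silently drops the rid: it marks the state finished, appends a synthetic output whose finish_reason type is "abort", and wakes whoever is awaiting the request. It assumes each `rid_to_state` entry looks roughly like this (a hypothetical minimal shape, not the real state class):

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class ReqState:  # hypothetical minimal shape, for illustration only
        out_list: list = field(default_factory=list)  # outputs handed to the waiter
        finished: bool = False                        # generation is over
        event: asyncio.Event = field(default_factory=asyncio.Event)  # wakes the waiter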
@@ -1,11 +1,20 @@
 import json
 import multiprocessing
 import time
 import unittest
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import requests
 
-from sglang.test.test_utils import CustomTestCase, run_and_check_memory_leak
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+    run_and_check_memory_leak,
+)
 
 
 class TestAbort(CustomTestCase):
@@ -50,5 +59,56 @@ class TestAbort(CustomTestCase):
         )
 
 
+class TestAbortAll(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--max-running-requests", 8],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def _run_decode(self):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16000,
+                    "ignore_eos": True,
+                },
+            },
+        )
+        return response.json()
+
+    def test_abort_all(self):
+        num_requests = 32
+        with ThreadPoolExecutor(num_requests) as executor:
+            futures = [executor.submit(self._run_decode) for _ in range(num_requests)]
+
+            # Ensure decoding has started before aborting
+            time.sleep(2)
+
+            requests.post(
+                self.base_url + "/abort_request",
+                json={
+                    "abort_all": True,
+                },
+            )
+
+            for future in as_completed(futures):
+                self.assertEqual(
+                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
+                )
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -1,6 +1,8 @@
 import json
+import random
 import time
 import unittest
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import requests
 
@@ -153,6 +155,82 @@ class TestServerUpdateWeightsFromDisk(CustomTestCase):
         self.assertEqual(origin_response[:32], updated_response[:32])
 
 
+class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--max-running-requests", 8],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def run_decode(self, max_new_tokens=32):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                    "ignore_eos": True,
+                },
+            },
+        )
+        return response.json()
+
+    def get_model_info(self):
+        response = requests.get(self.base_url + "/get_model_info")
+        model_path = response.json()["model_path"]
+        print(json.dumps(response.json()))
+        return model_path
+
+    def run_update_weights(self, model_path, abort_all_requests=False):
+        response = requests.post(
+            self.base_url + "/update_weights_from_disk",
+            json={
+                "model_path": model_path,
+                "abort_all_requests": abort_all_requests,
+            },
+        )
+        ret = response.json()
+        print(json.dumps(ret))
+        return ret
+
+    def test_update_weights_abort_all_requests(self):
+        origin_model_path = self.get_model_info()
+        print(f"[Server Mode] origin_model_path: {origin_model_path}")
+
+        num_requests = 32
+        with ThreadPoolExecutor(num_requests) as executor:
+            futures = [
+                executor.submit(self.run_decode, 16000) for _ in range(num_requests)
+            ]
+
+            # Ensure decoding has started before updating weights
+            time.sleep(2)
+
+            new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
+            ret = self.run_update_weights(new_model_path, abort_all_requests=True)
+            self.assertTrue(ret["success"])
+
+            for future in as_completed(futures):
+                self.assertEqual(
+                    future.result()["meta_info"]["finish_reason"]["type"], "abort"
+                )
+
+        updated_model_path = self.get_model_info()
+        print(f"[Server Mode] updated_model_path: {updated_model_path}")
+        self.assertEqual(updated_model_path, new_model_path)
+        self.assertNotEqual(updated_model_path, origin_model_path)
+
+
 ###############################################################################
 # Parameterized Tests for update_weights_from_disk
 # Test coverage is determined based on the value of is_in_ci: