Support updating weights at once by stopping all requests (#6698)
Signed-off-by: Tianyu Zhou <albert.zty@antgroup.com> Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
This commit is contained in:
@@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
|||||||
async def abort_request(obj: AbortReq, request: Request):
|
async def abort_request(obj: AbortReq, request: Request):
|
||||||
"""Abort a request."""
|
"""Abort a request."""
|
||||||
try:
|
try:
|
||||||
_global_state.tokenizer_manager.abort_request(rid=obj.rid)
|
_global_state.tokenizer_manager.abort_request(
|
||||||
|
rid=obj.rid, abort_all=obj.abort_all
|
||||||
|
)
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|||||||
@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
|
|||||||
index: int
|
index: int
|
||||||
text: str
|
text: str
|
||||||
logprobs: Optional[LogProbs] = None
|
logprobs: Optional[LogProbs] = None
|
||||||
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
|
||||||
matched_stop: Union[None, int, str] = None
|
matched_stop: Union[None, int, str] = None
|
||||||
hidden_states: Optional[object] = None
|
hidden_states: Optional[object] = None
|
||||||
|
|
||||||
@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
|
|||||||
delta: DeltaMessage
|
delta: DeltaMessage
|
||||||
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
|
||||||
finish_reason: Optional[
|
finish_reason: Optional[
|
||||||
Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
|
Literal[
|
||||||
|
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
||||||
|
]
|
||||||
] = None
|
] = None
|
||||||
matched_stop: Union[None, int, str] = None
|
matched_stop: Union[None, int, str] = None
|
||||||
|
|
||||||
|
|||||||
@@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput:
|
|||||||
model_path: str
|
model_path: str
|
||||||
# The format to load the weights
|
# The format to load the weights
|
||||||
load_format: Optional[str] = None
|
load_format: Optional[str] = None
|
||||||
|
# Whether to abort all requests before updating weights
|
||||||
|
abort_all_requests: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput:
|
|||||||
group_name: str = "weight_update_group"
|
group_name: str = "weight_update_group"
|
||||||
# Whether to flush the cache after updating weights
|
# Whether to flush the cache after updating weights
|
||||||
flush_cache: bool = True
|
flush_cache: bool = True
|
||||||
|
# Whether to abort all requests before updating weights
|
||||||
|
abort_all_requests: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput:
|
|||||||
load_format: Optional[str] = None
|
load_format: Optional[str] = None
|
||||||
# Whether to flush the cache after updating weights
|
# Whether to flush the cache after updating weights
|
||||||
flush_cache: bool = True
|
flush_cache: bool = True
|
||||||
|
# Whether to abort all requests before updating weights
|
||||||
|
abort_all_requests: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -858,7 +864,9 @@ class SlowDownReqOutput:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AbortReq:
|
class AbortReq:
|
||||||
# The request id
|
# The request id
|
||||||
rid: str
|
rid: str = ""
|
||||||
|
# Whether to abort all requests
|
||||||
|
abort_all: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -2211,7 +2211,7 @@ class Scheduler(
|
|||||||
# Delete requests in the waiting queue
|
# Delete requests in the waiting queue
|
||||||
to_del = []
|
to_del = []
|
||||||
for i, req in enumerate(self.waiting_queue):
|
for i, req in enumerate(self.waiting_queue):
|
||||||
if req.rid.startswith(recv_req.rid):
|
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
|
||||||
to_del.append(i)
|
to_del.append(i)
|
||||||
|
|
||||||
# Sort in reverse order to avoid index issues when deleting
|
# Sort in reverse order to avoid index issues when deleting
|
||||||
@@ -2228,7 +2228,7 @@ class Scheduler(
|
|||||||
# Abort method 2: call `set_finish_with_abort`
|
# Abort method 2: call `set_finish_with_abort`
|
||||||
# The request will still run one prefill forward pass.
|
# The request will still run one prefill forward pass.
|
||||||
# In this case, we change the input_ids to be only one token to make this prefill cheap.
|
# In this case, we change the input_ids to be only one token to make this prefill cheap.
|
||||||
if req.rid.startswith(recv_req.rid):
|
if recv_req.abort_all or req.rid.startswith(recv_req.rid):
|
||||||
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
logger.debug(f"Abort grammar queue request. {req.rid=}")
|
||||||
if req.grammar:
|
if req.grammar:
|
||||||
req.grammar.cancel()
|
req.grammar.cancel()
|
||||||
@@ -2241,7 +2241,9 @@ class Scheduler(
|
|||||||
reqs = self.running_batch.reqs + self.cur_batch.reqs
|
reqs = self.running_batch.reqs + self.cur_batch.reqs
|
||||||
|
|
||||||
for req in reqs:
|
for req in reqs:
|
||||||
if req.rid.startswith(recv_req.rid) and not req.finished():
|
if not req.finished() and (
|
||||||
|
recv_req.abort_all or req.rid.startswith(recv_req.rid)
|
||||||
|
):
|
||||||
# Abort method 3: set `to_abort=True`
|
# Abort method 3: set `to_abort=True`
|
||||||
# The request will still run one decode forward pass.
|
# The request will still run one decode forward pass.
|
||||||
# Then we reuse all existing code to clean up the KV cache allocation.
|
# Then we reuse all existing code to clean up the KV cache allocation.
|
||||||
|
|||||||
@@ -846,10 +846,10 @@ class TokenizerManager:
|
|||||||
async def flush_cache(self) -> FlushCacheReqOutput:
|
async def flush_cache(self) -> FlushCacheReqOutput:
|
||||||
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
|
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
|
||||||
|
|
||||||
def abort_request(self, rid: str):
|
def abort_request(self, rid: str = "", abort_all: bool = False):
|
||||||
if rid not in self.rid_to_state:
|
if not abort_all and rid not in self.rid_to_state:
|
||||||
return
|
return
|
||||||
req = AbortReq(rid)
|
req = AbortReq(rid, abort_all)
|
||||||
self.send_to_scheduler.send_pyobj(req)
|
self.send_to_scheduler.send_pyobj(req)
|
||||||
|
|
||||||
if self.enable_metrics:
|
if self.enable_metrics:
|
||||||
@@ -914,6 +914,9 @@ class TokenizerManager:
|
|||||||
obj.load_format = self.server_args.load_format
|
obj.load_format = self.server_args.load_format
|
||||||
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
logger.info("Start update_weights. Load format=%s", obj.load_format)
|
||||||
|
|
||||||
|
if obj.abort_all_requests:
|
||||||
|
self.abort_request(abort_all=True)
|
||||||
|
|
||||||
if True: # Keep this redundant check to simplify some internal code sync
|
if True: # Keep this redundant check to simplify some internal code sync
|
||||||
# Hold the lock if it is not async. This means that weight sync
|
# Hold the lock if it is not async. This means that weight sync
|
||||||
# cannot run while requests are in progress.
|
# cannot run while requests are in progress.
|
||||||
@@ -969,6 +972,9 @@ class TokenizerManager:
|
|||||||
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||||
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
|
||||||
|
|
||||||
|
if obj.abort_all_requests:
|
||||||
|
self.abort_request(abort_all=True)
|
||||||
|
|
||||||
# This means that weight sync
|
# This means that weight sync
|
||||||
# cannot run while requests are in progress.
|
# cannot run while requests are in progress.
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.model_update_lock.writer_lock:
|
||||||
@@ -985,6 +991,9 @@ class TokenizerManager:
|
|||||||
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||||
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
|
), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
|
||||||
|
|
||||||
|
if obj.abort_all_requests:
|
||||||
|
self.abort_request(abort_all=True)
|
||||||
|
|
||||||
# This means that weight sync
|
# This means that weight sync
|
||||||
# cannot run while requests are in progress.
|
# cannot run while requests are in progress.
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.model_update_lock.writer_lock:
|
||||||
@@ -1619,7 +1628,23 @@ class TokenizerManager:
|
|||||||
self.crash_dump_request_list.popleft()
|
self.crash_dump_request_list.popleft()
|
||||||
|
|
||||||
def _handle_abort_req(self, recv_obj):
|
def _handle_abort_req(self, recv_obj):
|
||||||
self.rid_to_state.pop(recv_obj.rid, None)
|
state = self.rid_to_state[recv_obj.rid]
|
||||||
|
state.finished = True
|
||||||
|
state.out_list.append(
|
||||||
|
{
|
||||||
|
"text": "",
|
||||||
|
"meta_info": {
|
||||||
|
"id": recv_obj.rid,
|
||||||
|
"finish_reason": {
|
||||||
|
"type": "abort",
|
||||||
|
"message": "Abort before prefill",
|
||||||
|
},
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
state.event.set()
|
||||||
|
|
||||||
def _handle_open_session_req_output(self, recv_obj):
|
def _handle_open_session_req_output(self, recv_obj):
|
||||||
self.session_futures[recv_obj.session_id].set_result(
|
self.session_futures[recv_obj.session_id].set_result(
|
||||||
|
|||||||
@@ -1,11 +1,20 @@
|
|||||||
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from sglang.test.test_utils import CustomTestCase, run_and_check_memory_leak
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
run_and_check_memory_leak,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAbort(CustomTestCase):
|
class TestAbort(CustomTestCase):
|
||||||
@@ -50,5 +59,56 @@ class TestAbort(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAbortAll(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--max-running-requests", 8],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def _run_decode(self):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": 16000,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def test_abort_all(self):
|
||||||
|
num_requests = 32
|
||||||
|
with ThreadPoolExecutor(num_requests) as executor:
|
||||||
|
futures = [executor.submit(self._run_decode) for _ in range(num_requests)]
|
||||||
|
|
||||||
|
# ensure the decode has been started
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
requests.post(
|
||||||
|
self.base_url + "/abort_request",
|
||||||
|
json={
|
||||||
|
"abort_all": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for future in as_completed(futures):
|
||||||
|
self.assertEqual(
|
||||||
|
future.result()["meta_info"]["finish_reason"]["type"], "abort"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -153,6 +155,82 @@ class TestServerUpdateWeightsFromDisk(CustomTestCase):
|
|||||||
self.assertEqual(origin_response[:32], updated_response[:32])
|
self.assertEqual(origin_response[:32], updated_response[:32])
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=["--max-running-requests", 8],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def run_decode(self, max_new_tokens=32):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate",
|
||||||
|
json={
|
||||||
|
"text": "The capital of France is",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0,
|
||||||
|
"max_new_tokens": max_new_tokens,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def get_model_info(self):
|
||||||
|
response = requests.get(self.base_url + "/get_model_info")
|
||||||
|
model_path = response.json()["model_path"]
|
||||||
|
print(json.dumps(response.json()))
|
||||||
|
return model_path
|
||||||
|
|
||||||
|
def run_update_weights(self, model_path, abort_all_requests=False):
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/update_weights_from_disk",
|
||||||
|
json={
|
||||||
|
"model_path": model_path,
|
||||||
|
"abort_all_requests": abort_all_requests,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ret = response.json()
|
||||||
|
print(json.dumps(ret))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def test_update_weights_abort_all_requests(self):
|
||||||
|
origin_model_path = self.get_model_info()
|
||||||
|
print(f"[Server Mode] origin_model_path: {origin_model_path}")
|
||||||
|
|
||||||
|
num_requests = 32
|
||||||
|
with ThreadPoolExecutor(num_requests) as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(self.run_decode, 16000) for _ in range(num_requests)
|
||||||
|
]
|
||||||
|
|
||||||
|
# ensure the decode has been started
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
|
||||||
|
ret = self.run_update_weights(new_model_path, abort_all_requests=True)
|
||||||
|
self.assertTrue(ret["success"])
|
||||||
|
|
||||||
|
for future in as_completed(futures):
|
||||||
|
self.assertEqual(
|
||||||
|
future.result()["meta_info"]["finish_reason"]["type"], "abort"
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_model_path = self.get_model_info()
|
||||||
|
print(f"[Server Mode] updated_model_path: {updated_model_path}")
|
||||||
|
self.assertEqual(updated_model_path, new_model_path)
|
||||||
|
self.assertNotEqual(updated_model_path, origin_model_path)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Parameterized Tests for update_weights_from_disk
|
# Parameterized Tests for update_weights_from_disk
|
||||||
# Test coverage is determined based on the value of is_in_ci:
|
# Test coverage is determined based on the value of is_in_ci:
|
||||||
|
|||||||
Reference in New Issue
Block a user