diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index a2676038d..0f258a9d9 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -662,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request): async def abort_request(obj: AbortReq, request: Request): """Abort a request.""" try: - _global_state.tokenizer_manager.abort_request(rid=obj.rid) + _global_state.tokenizer_manager.abort_request( + rid=obj.rid, abort_all=obj.abort_all + ) return Response(status_code=200) except Exception as e: return _create_error_response(e) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 192869c31..552774537 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None + finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None matched_stop: Union[None, int, str] = None hidden_states: Optional[object] = None @@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel): delta: DeltaMessage logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None finish_reason: Optional[ - Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + Literal[ + "stop", "length", "tool_calls", "content_filter", "function_call", "abort" + ] ] = None matched_stop: Union[None, int, str] = None diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 6c5b6e196..b42e44214 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -740,6 +740,8 @@ class UpdateWeightFromDiskReqInput: model_path: str # The format to load the weights load_format: Optional[str] = None + # Whether to abort all requests before updating weights + abort_all_requests: bool = False @dataclass @@ -759,6 +761,8 @@ class UpdateWeightsFromDistributedReqInput: group_name: str = "weight_update_group" # Whether to flush the cache after updating weights flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False @dataclass @@ -780,6 +784,8 @@ class UpdateWeightsFromTensorReqInput: load_format: Optional[str] = None # Whether to flush the cache after updating weights flush_cache: bool = True + # Whether to abort all requests before updating weights + abort_all_requests: bool = False @dataclass @@ -858,7 +864,9 @@ class SlowDownReqOutput: @dataclass class AbortReq: # The request id - rid: str + rid: str = "" + # Whether to abort all requests + abort_all: bool = False @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8a82c1fea..f7bd19f7a 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2211,7 +2211,7 @@ class Scheduler( # Delete requests in the waiting queue to_del = [] for i, req in enumerate(self.waiting_queue): - if req.rid.startswith(recv_req.rid): + if recv_req.abort_all or req.rid.startswith(recv_req.rid): to_del.append(i) # Sort in reverse order to avoid index issues when deleting @@ -2228,7 +2228,7 @@ class Scheduler( # Abort method 2: call `set_finish_with_abort` # The request will still run one prefill forward pass. # In this case, we change the input_ids to be only one token to make this prefill cheap. - if req.rid.startswith(recv_req.rid): + if recv_req.abort_all or req.rid.startswith(recv_req.rid): logger.debug(f"Abort grammar queue request. {req.rid=}") if req.grammar: req.grammar.cancel() @@ -2241,7 +2241,9 @@ class Scheduler( reqs = self.running_batch.reqs + self.cur_batch.reqs for req in reqs: - if req.rid.startswith(recv_req.rid) and not req.finished(): + if not req.finished() and ( + recv_req.abort_all or req.rid.startswith(recv_req.rid) + ): # Abort method 3: set `to_abort=True` # The request will still run one decode forward pass. # Then we reuse all existing code to clean up the KV cache allocation. diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8e1eb758c..b16bb8a59 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -846,10 +846,10 @@ class TokenizerManager: async def flush_cache(self) -> FlushCacheReqOutput: return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] - def abort_request(self, rid: str): - if rid not in self.rid_to_state: + def abort_request(self, rid: str = "", abort_all: bool = False): + if not abort_all and rid not in self.rid_to_state: return - req = AbortReq(rid) + req = AbortReq(rid, abort_all) self.send_to_scheduler.send_pyobj(req) if self.enable_metrics: @@ -914,6 +914,9 @@ class TokenizerManager: obj.load_format = self.server_args.load_format logger.info("Start update_weights. Load format=%s", obj.load_format) + if obj.abort_all_requests: + self.abort_request(abort_all=True) + if True: # Keep this redundant check to simplify some internal code sync # Hold the lock if it is not async. This means that weight sync # cannot run while requests are in progress. @@ -969,6 +972,9 @@ class TokenizerManager: self.server_args.dp_size == 1 or self.server_args.enable_dp_attention ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed" + if obj.abort_all_requests: + self.abort_request(abort_all=True) + # This means that weight sync # cannot run while requests are in progress. async with self.model_update_lock.writer_lock: @@ -985,6 +991,9 @@ class TokenizerManager: self.server_args.dp_size == 1 or self.server_args.enable_dp_attention ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor" + if obj.abort_all_requests: + self.abort_request(abort_all=True) + # This means that weight sync # cannot run while requests are in progress. async with self.model_update_lock.writer_lock: @@ -1619,7 +1628,23 @@ class TokenizerManager: self.crash_dump_request_list.popleft() def _handle_abort_req(self, recv_obj): - self.rid_to_state.pop(recv_obj.rid, None) + state = self.rid_to_state[recv_obj.rid] + state.finished = True + state.out_list.append( + { + "text": "", + "meta_info": { + "id": recv_obj.rid, + "finish_reason": { + "type": "abort", + "message": "Abort before prefill", + }, + "prompt_tokens": 0, + "completion_tokens": 0, + }, + } + ) + state.event.set() def _handle_open_session_req_output(self, recv_obj): self.session_futures[recv_obj.session_id].set_result( diff --git a/test/srt/test_abort.py b/test/srt/test_abort.py index d2ab4d034..591c21674 100644 --- a/test/srt/test_abort.py +++ b/test/srt/test_abort.py @@ -1,11 +1,20 @@ +import json import multiprocessing import time import unittest -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed import requests -from sglang.test.test_utils import CustomTestCase, run_and_check_memory_leak +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, + run_and_check_memory_leak, +) class TestAbort(CustomTestCase): @@ -50,5 +59,56 @@ class TestAbort(CustomTestCase): ) +class TestAbortAll(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--max-running-requests", 8], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def _run_decode(self): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16000, + "ignore_eos": True, + }, + }, + ) + return response.json() + + def test_abort_all(self): + num_requests = 32 + with ThreadPoolExecutor(num_requests) as executor: + futures = [executor.submit(self._run_decode) for _ in range(num_requests)] + + # ensure the decode has been started + time.sleep(2) + + requests.post( + self.base_url + "/abort_request", + json={ + "abort_all": True, + }, + ) + + for future in as_completed(futures): + self.assertEqual( + future.result()["meta_info"]["finish_reason"]["type"], "abort" + ) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_update_weights_from_disk.py b/test/srt/test_update_weights_from_disk.py index 11b7e678a..02b283efd 100644 --- a/test/srt/test_update_weights_from_disk.py +++ b/test/srt/test_update_weights_from_disk.py @@ -1,6 +1,8 @@ import json import random +import time import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed import requests @@ -153,6 +155,82 @@ class TestServerUpdateWeightsFromDisk(CustomTestCase): self.assertEqual(origin_response[:32], updated_response[:32]) +class TestServerUpdateWeightsFromDiskAbortAllRequests(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--max-running-requests", 8], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode(self, max_new_tokens=32): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + "ignore_eos": True, + }, + }, + ) + return response.json() + + def get_model_info(self): + response = requests.get(self.base_url + "/get_model_info") + model_path = response.json()["model_path"] + print(json.dumps(response.json())) + return model_path + + def run_update_weights(self, model_path, abort_all_requests=False): + response = requests.post( + self.base_url + "/update_weights_from_disk", + json={ + "model_path": model_path, + "abort_all_requests": abort_all_requests, + }, + ) + ret = response.json() + print(json.dumps(ret)) + return ret + + def test_update_weights_abort_all_requests(self): + origin_model_path = self.get_model_info() + print(f"[Server Mode] origin_model_path: {origin_model_path}") + + num_requests = 32 + with ThreadPoolExecutor(num_requests) as executor: + futures = [ + executor.submit(self.run_decode, 16000) for _ in range(num_requests) + ] + + # ensure the decode has been started + time.sleep(2) + + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "") + ret = self.run_update_weights(new_model_path, abort_all_requests=True) + self.assertTrue(ret["success"]) + + for future in as_completed(futures): + self.assertEqual( + future.result()["meta_info"]["finish_reason"]["type"], "abort" + ) + + updated_model_path = self.get_model_info() + print(f"[Server Mode] updated_model_path: {updated_model_path}") + self.assertEqual(updated_model_path, new_model_path) + self.assertNotEqual(updated_model_path, origin_model_path) + + ############################################################################### # Parameterized Tests for update_weights_from_disk # Test coverage is determined based on the value of is_in_ci: