"""
Integration test for abort_request functionality with a SGLang server.

Run with:
    python -m unittest sglang.test.srt.entrypoints.http_server.test_abort_request -v
"""

import threading
import time
import unittest
from typing import Optional

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestAbortRequest(CustomTestCase):
    """Integration test class for abort request functionality."""

    model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST

    @classmethod
    def setUpClass(cls):
        """Launch the server once for every test in this class."""
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--disable-cuda-graph"],
        )

        cls.completion_url = f"{cls.base_url}/generate"
        cls.abort_url = f"{cls.base_url}/abort_request"
        cls.health_url = f"{cls.base_url}/health"

        print(f"Server started at {cls.base_url}")

    @classmethod
    def tearDownClass(cls):
        """Clean up the server."""
        kill_process_tree(cls.process.pid)

    def _send_completion_request(
        self,
        text: str,
        request_id: str,
        max_tokens: int = 50,
        temperature: float = 0.8,
        stream: bool = True,
    ) -> requests.Response:
        """Send a /generate request tagged with ``request_id`` (the "rid")
        so it can later be targeted by /abort_request."""
        payload = {
            "text": text,
            "sampling_params": {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
            },
            "stream": stream,
            "rid": request_id,
        }

        return requests.post(
            self.completion_url,
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=30,
            stream=stream,
        )

    def _send_abort_request(self, request_id: str) -> requests.Response:
        """Ask the server to abort the in-flight request with this rid."""
        payload = {"rid": request_id}
        return requests.post(self.abort_url, json=payload, timeout=10)

    def _check_server_health(self) -> bool:
        """Return True iff /health answers with HTTP 200."""
        try:
            response = requests.get(self.health_url, timeout=5)
            return response.status_code == 200
        # Catch only network-level errors (connection refused, timeout, ...):
        # those just mean "not healthy". A bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit, which must propagate.
        except requests.RequestException:
            return False

    def test_abort_during_non_streaming_generation(self):
        """Abort a non-streaming request mid-generation; its finish_reason
        must come back as type "abort"."""
        self.assertTrue(self._check_server_health(), "Server should be healthy")

        request_id = "test_abort_non_streaming"
        completion_result = {}

        def run_completion():
            # Worker thread: run the (blocking) non-streaming generation and
            # stash the outcome for the main thread to assert on.
            response = self._send_completion_request(
                "Write a detailed essay about artificial intelligence",
                max_tokens=500,
                temperature=1,
                request_id=request_id,
                stream=False,
            )

            if response.status_code == 200:
                result = response.json()
                completion_result["text"] = result.get("text", "")
                completion_result["finish_reason"] = result.get("meta_info", {}).get(
                    "finish_reason"
                )

        completion_thread = threading.Thread(target=run_completion)
        completion_thread.start()
        # Give the request a moment to reach the scheduler before aborting.
        time.sleep(0.1)

        abort_response = self._send_abort_request(request_id)
        completion_thread.join()

        self.assertEqual(abort_response.status_code, 200)
        # A dict is never None, so `assertIsNotNone` here would be vacuous and
        # the test would pass even if the HTTP call failed. Assert the worker
        # actually recorded a (non-empty) result instead.
        self.assertTrue(completion_result, "Should have completion result")
        finish_reason_obj = completion_result.get("finish_reason")
        self.assertIsNotNone(finish_reason_obj, "Should have finish_reason")
        self.assertEqual(
            finish_reason_obj.get("type"), "abort", "Should be aborted"
        )

    def test_batch_requests_with_selective_abort(self):
        """Run three concurrent requests and abort exactly one of them; the
        other two must complete normally (finish_reason type "length")."""
        self.assertTrue(self._check_server_health(), "Server should be healthy")

        request_ids = ["batch_test_0", "batch_test_1", "batch_test_2"]
        abort_target_id = "batch_test_1"
        completion_results = {}
        threads = []

        def run_completion(req_id, prompt):
            response = self._send_completion_request(
                f"Write a story about {prompt}",
                max_tokens=100,
                temperature=0.8,
                request_id=req_id,
                stream=False,
            )

            if response.status_code == 200:
                result = response.json()
                completion_results[req_id] = {
                    "text": result.get("text", ""),
                    "finish_reason": result.get("meta_info", {}).get("finish_reason"),
                }

        # Start all requests concurrently.
        prompts = ["a knight's adventure", "a space discovery", "a chef's restaurant"]
        for req_id, prompt in zip(request_ids, prompts):
            thread = threading.Thread(target=run_completion, args=(req_id, prompt))
            threads.append(thread)
            thread.start()

        # Abort one request once the batch has had a moment to get scheduled.
        time.sleep(0.1)
        abort_response = self._send_abort_request(abort_target_id)

        # Wait for all workers; a join timeout must fail the test rather than
        # be silently ignored (a still-alive thread means a hung request).
        for thread in threads:
            thread.join(timeout=30)
            self.assertFalse(thread.is_alive(), "Completion thread timed out")

        self.assertEqual(abort_response.status_code, 200)

        # The aborted request must still return a result, flagged as aborted.
        aborted_result = completion_results.get(abort_target_id)
        self.assertIsNotNone(
            aborted_result, f"Aborted request {abort_target_id} should have result"
        )
        aborted_finish_reason = aborted_result.get("finish_reason")
        self.assertIsNotNone(
            aborted_finish_reason, "Aborted request should have finish_reason"
        )
        self.assertEqual(aborted_finish_reason.get("type"), "abort")

        # The two untouched requests should run to the max_new_tokens limit.
        normal_completions = 0
        for req_id in request_ids:
            if req_id == abort_target_id:
                continue
            result = completion_results.get(req_id)
            if result:
                finish_reason = result.get("finish_reason")
                if finish_reason and finish_reason.get("type") == "length":
                    normal_completions += 1

        self.assertEqual(
            normal_completions, 2, "Other 2 requests should complete normally"
        )


if __name__ == "__main__":
    unittest.main(verbosity=2, warnings="ignore")