""" Integration test for abort_request functionality with a SGLang server. Run with: python -m unittest sglang.test.srt.entrypoints.http_server.test_abort_request -v """ import threading import time import unittest from typing import Optional import requests from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_server, ) class TestAbortRequest(CustomTestCase): """Integration test class for abort request functionality.""" model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST base_url = DEFAULT_URL_FOR_TEST @classmethod def setUpClass(cls): """Launch the server.""" cls.process = popen_launch_server( cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=["--disable-cuda-graph"], ) cls.completion_url = f"{cls.base_url}/generate" cls.abort_url = f"{cls.base_url}/abort_request" cls.health_url = f"{cls.base_url}/health" print(f"Server started at {cls.base_url}") @classmethod def tearDownClass(cls): """Clean up the server.""" kill_process_tree(cls.process.pid) def _send_completion_request( self, text: str, request_id: str, max_tokens: int = 50, temperature: float = 0.8, stream: bool = True, ) -> requests.Response: """Send a completion request to the server.""" payload = { "text": text, "sampling_params": { "max_new_tokens": max_tokens, "temperature": temperature, }, "stream": stream, "rid": request_id, } response = requests.post( self.completion_url, json=payload, headers={"Content-Type": "application/json"}, timeout=30, stream=stream, ) return response def _send_abort_request(self, request_id: str) -> requests.Response: """Send an abort request.""" payload = {"rid": request_id} return requests.post(self.abort_url, json=payload, timeout=10) def _check_server_health(self) -> bool: """Check if server is healthy.""" try: response = requests.get(self.health_url, timeout=5) return response.status_code == 200 except: return False def test_abort_during_non_streaming_generation(self): """Test aborting a non-streaming request during generation.""" self.assertTrue(self._check_server_health(), "Server should be healthy") request_id = "test_abort_non_streaming" completion_result = {} def run_completion(): response = self._send_completion_request( "Write a detailed essay about artificial intelligence", max_tokens=500, temperature=1, request_id=request_id, stream=False, ) if response.status_code == 200: result = response.json() completion_result["text"] = result.get("text", "") completion_result["finish_reason"] = result.get("meta_info", {}).get( "finish_reason" ) completion_thread = threading.Thread(target=run_completion) completion_thread.start() time.sleep(0.1) abort_response = self._send_abort_request(request_id) completion_thread.join() self.assertEqual(abort_response.status_code, 200) self.assertIsNotNone(completion_result, "Should have completion result") if completion_result: finish_reason_obj = completion_result.get("finish_reason") self.assertIsNotNone(finish_reason_obj, "Should have finish_reason") if finish_reason_obj: self.assertEqual( finish_reason_obj.get("type"), "abort", "Should be aborted" ) def test_batch_requests_with_selective_abort(self): """Test multiple concurrent requests with selective abort of one request.""" self.assertTrue(self._check_server_health(), "Server should be healthy") request_ids = ["batch_test_0", "batch_test_1", "batch_test_2"] abort_target_id = "batch_test_1" completion_results = {} threads = [] def run_completion(req_id, prompt): response = self._send_completion_request( f"Write a story about {prompt}", max_tokens=100, temperature=0.8, request_id=req_id, stream=False, ) if response.status_code == 200: result = response.json() completion_results[req_id] = { "text": result.get("text", ""), "finish_reason": result.get("meta_info", {}).get("finish_reason"), } # Start all requests prompts = ["a knight's adventure", "a space discovery", "a chef's restaurant"] for i, req_id in enumerate(request_ids): thread = threading.Thread(target=run_completion, args=(req_id, prompts[i])) threads.append(thread) thread.start() # Abort one request time.sleep(0.1) abort_response = self._send_abort_request(abort_target_id) # Wait for completion for thread in threads: thread.join(timeout=30) # Verify results self.assertEqual(abort_response.status_code, 200) # Check aborted request aborted_result = completion_results.get(abort_target_id) self.assertIsNotNone( aborted_result, f"Aborted request {abort_target_id} should have result" ) if aborted_result: aborted_finish_reason = aborted_result.get("finish_reason") self.assertIsNotNone( aborted_finish_reason, "Aborted request should have finish_reason" ) if aborted_finish_reason: self.assertEqual(aborted_finish_reason.get("type"), "abort") # Check other requests completed normally normal_completions = 0 for req_id in request_ids: if req_id != abort_target_id and req_id in completion_results: result = completion_results[req_id] if result: finish_reason = result.get("finish_reason") if finish_reason and finish_reason.get("type") == "length": normal_completions += 1 self.assertEqual( normal_completions, 2, "Other 2 requests should complete normally" ) if __name__ == "__main__": unittest.main(verbosity=2, warnings="ignore")