[router] Forward all request headers from router to workers (#3070)

This commit is contained in:
Byron Hsu
2025-01-23 20:30:31 -08:00
committed by GitHub
parent 7bad7e75bf
commit 9a0cc2e90e
4 changed files with 132 additions and 25 deletions

View File

@@ -22,6 +22,7 @@ def popen_launch_router(
timeout: float,
policy: str = "cache_aware",
max_payload_size: int = None,
api_key: str = None,
):
"""
Launch the router server process.
@@ -33,6 +34,7 @@ def popen_launch_router(
timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random"
max_payload_size: Maximum payload size in bytes
api_key: API key for the router
"""
_, host, port = base_url.split(":")
host = host[2:]
@@ -55,6 +57,9 @@ def popen_launch_router(
policy,
]
if api_key is not None:
command.extend(["--api-key", api_key])
if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])
@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase):
f"1.2MB payload should fail with 413 but got status {response.status_code}",
)
def test_5_api_key(self):
    """Check that the router enforces the API key it was launched with.

    The router is started with ``api_key="correct_api_key"`` and the
    ``/generate`` endpoint is then hit three times:

    1. without an ``Authorization`` header -> expects HTTP 401
    2. with a wrong bearer token           -> expects HTTP 401
    3. with the correct bearer token       -> expects HTTP 200
    """
    print("Running test_5_api_key...")
    # NOTE(review): the process handle is kept on self; presumably the
    # class's tearDown terminates it -- confirm against the rest of the file.
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        api_key="correct_api_key",
    )

    # Test case 1: request without api key should fail
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            "Request without api key should fail with 401",
        )

    # Test case 2: request with invalid api key should fail
    with requests.Session() as session:
        # Go through the session (not module-level requests.post) so the
        # context-managed session is actually used, as in the other cases.
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
            headers={"Authorization": "Bearer 123"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            "Request with invalid api key should fail with 401",
        )

    # Test case 3: request with correct api key should succeed
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is ", "temperature": 0},
            headers={"Authorization": "Bearer correct_api_key"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code, 200, "Request with correct api key should succeed"
        )
# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()