[router] Forward all request headers from router to workers (#3070)
This commit is contained in:
@@ -22,6 +22,7 @@ def popen_launch_router(
|
||||
timeout: float,
|
||||
policy: str = "cache_aware",
|
||||
max_payload_size: int = None,
|
||||
api_key: str = None,
|
||||
):
|
||||
"""
|
||||
Launch the router server process.
|
||||
@@ -33,6 +34,7 @@ def popen_launch_router(
|
||||
timeout: Server launch timeout
|
||||
policy: Router policy, one of "cache_aware", "round_robin", "random"
|
||||
max_payload_size: Maximum payload size in bytes
|
||||
api_key: API key for the router
|
||||
"""
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
@@ -55,6 +57,9 @@ def popen_launch_router(
|
||||
policy,
|
||||
]
|
||||
|
||||
if api_key is not None:
|
||||
command.extend(["--api-key", api_key])
|
||||
|
||||
if max_payload_size is not None:
|
||||
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
||||
|
||||
@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase):
|
||||
f"1.2MB payload should fail with 413 but got status {response.status_code}",
|
||||
)
|
||||
|
||||
def test_5_api_key(self):
    """Check that the router enforces the API key it was launched with.

    Launches the router with ``api_key="correct_api_key"`` and sends three
    ``/generate`` requests:

    1. no ``Authorization`` header -> expects 401
    2. a wrong bearer token        -> expects 401
    3. the correct bearer token    -> expects 200
    """
    print("Running test_5_api_key...")

    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        api_key="correct_api_key",
    )

    # Test case 1: request without api key should fail
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            "Request without api key should fail with 401",
        )

    # Test case 2: request with invalid api key should fail
    with requests.Session() as session:
        # Use the session opened above (the original called the module-level
        # requests.post here, leaving the session unused).
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
            headers={"Authorization": "Bearer 123"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            "Request with invalid api key should fail with 401",
        )

    # Test case 3: request with correct api key should succeed
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is ", "temperature": 0},
            headers={"Authorization": "Bearer correct_api_key"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code, 200, "Request with correct api key should succeed"
        )
|
||||
|
||||
|
||||
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user