[router] Forward all request headers from router to workers (#3070)
This commit is contained in:
@@ -1,5 +1,14 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Check if sudo is available
|
||||||
|
if command -v sudo >/dev/null 2>&1; then
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y lsof
|
||||||
|
else
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y lsof
|
||||||
|
fi
|
||||||
|
|
||||||
# Show current GPU status
|
# Show current GPU status
|
||||||
nvidia-smi
|
nvidia-smi
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ def popen_launch_router(
|
|||||||
timeout: float,
|
timeout: float,
|
||||||
policy: str = "cache_aware",
|
policy: str = "cache_aware",
|
||||||
max_payload_size: int = None,
|
max_payload_size: int = None,
|
||||||
|
api_key: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Launch the router server process.
|
Launch the router server process.
|
||||||
@@ -33,6 +34,7 @@ def popen_launch_router(
|
|||||||
timeout: Server launch timeout
|
timeout: Server launch timeout
|
||||||
policy: Router policy, one of "cache_aware", "round_robin", "random"
|
policy: Router policy, one of "cache_aware", "round_robin", "random"
|
||||||
max_payload_size: Maximum payload size in bytes
|
max_payload_size: Maximum payload size in bytes
|
||||||
|
api_key: API key for the router
|
||||||
"""
|
"""
|
||||||
_, host, port = base_url.split(":")
|
_, host, port = base_url.split(":")
|
||||||
host = host[2:]
|
host = host[2:]
|
||||||
@@ -55,6 +57,9 @@ def popen_launch_router(
|
|||||||
policy,
|
policy,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if api_key is not None:
|
||||||
|
command.extend(["--api-key", api_key])
|
||||||
|
|
||||||
if max_payload_size is not None:
|
if max_payload_size is not None:
|
||||||
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
||||||
|
|
||||||
@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase):
|
|||||||
f"1.2MB payload should fail with 413 but got status {response.status_code}",
|
f"1.2MB payload should fail with 413 but got status {response.status_code}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_5_api_key(self):
    """Check that the router enforces its API key on ``/generate``.

    Launches the router with ``api_key="correct_api_key"`` and then sends
    three requests:

    1. without an ``Authorization`` header -> expects HTTP 401
    2. with a wrong bearer token           -> expects HTTP 401
    3. with the correct bearer token       -> expects HTTP 200
    """
    print("Running test_5_api_key...")

    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        api_key="correct_api_key",
    )

    # Test case 1: request without api key should fail
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            "Request without api key should fail with 401",
        )

    # Test case 2: request with invalid api key should fail
    with requests.Session() as session:
        # Go through the session (as the other cases do) rather than the
        # module-level requests.post, so the opened session is actually used.
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
            headers={"Authorization": "Bearer 123"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            "Request with invalid api key should fail with 401",
        )

    # Test case 3: request with correct api key should succeed
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is ", "temperature": 0},
            headers={"Authorization": "Bearer correct_api_key"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code, 200, "Request with correct api key should succeed"
        )
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -12,6 +12,18 @@ use std::thread;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio;
|
use tokio;
|
||||||
|
|
||||||
|
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||||
|
req.headers()
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(name, value)| {
|
||||||
|
value
|
||||||
|
.to_str()
|
||||||
|
.ok()
|
||||||
|
.map(|v| (name.to_string(), v.to_string()))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Router {
|
pub enum Router {
|
||||||
RoundRobin {
|
RoundRobin {
|
||||||
@@ -303,8 +315,18 @@ impl Router {
|
|||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
worker_url: &str,
|
worker_url: &str,
|
||||||
route: &str,
|
route: &str,
|
||||||
|
req: &HttpRequest,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
match client.get(format!("{}{}", worker_url, route)).send().await {
|
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
||||||
|
|
||||||
|
// Copy all headers from original request except for /health because it does not need authorization
|
||||||
|
if route != "/health" {
|
||||||
|
for (name, value) in copy_request_headers(req) {
|
||||||
|
request_builder = request_builder.header(name, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match request_builder.send().await {
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
@@ -322,7 +344,12 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
|
pub async fn route_to_first(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
route: &str,
|
||||||
|
req: &HttpRequest,
|
||||||
|
) -> HttpResponse {
|
||||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||||
let mut total_retries = 0;
|
let mut total_retries = 0;
|
||||||
@@ -338,10 +365,17 @@ impl Router {
|
|||||||
info!("Retrying request after {} failed attempts", total_retries);
|
info!("Retrying request after {} failed attempts", total_retries);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = self.send_request(client, &worker_url, route).await;
|
let response = self.send_request(client, &worker_url, route, req).await;
|
||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
return response;
|
return response;
|
||||||
|
} else {
|
||||||
|
// if the worker is healthy, it means the request is bad, so return the error response
|
||||||
|
let health_response =
|
||||||
|
self.send_request(client, &worker_url, "/health", req).await;
|
||||||
|
if health_response.status().is_success() {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
@@ -496,19 +530,16 @@ impl Router {
|
|||||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
let res = match client
|
let mut request_builder = client
|
||||||
.post(format!("{}{}", worker_url, route))
|
.post(format!("{}{}", worker_url, route))
|
||||||
.header(
|
.body(body.to_vec());
|
||||||
"Content-Type",
|
|
||||||
req.headers()
|
// Copy all headers from original request
|
||||||
.get("Content-Type")
|
for (name, value) in copy_request_headers(req) {
|
||||||
.and_then(|h| h.to_str().ok())
|
request_builder = request_builder.header(name, value);
|
||||||
.unwrap_or("application/json"),
|
}
|
||||||
)
|
|
||||||
.body(body.to_vec())
|
let res = match request_builder.send().await {
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
Err(_) => return HttpResponse::InternalServerError().finish(),
|
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||||||
};
|
};
|
||||||
@@ -596,6 +627,13 @@ impl Router {
|
|||||||
|
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
return response;
|
return response;
|
||||||
|
} else {
|
||||||
|
// if the worker is healthy, it means the request is bad, so return the error response
|
||||||
|
let health_response =
|
||||||
|
self.send_request(client, &worker_url, "/health", req).await;
|
||||||
|
if health_response.status().is_success() {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
|
|||||||
@@ -26,33 +26,37 @@ impl AppState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[get("/health")]
|
#[get("/health")]
|
||||||
async fn health(data: web::Data<AppState>) -> impl Responder {
|
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router.route_to_first(&data.client, "/health").await
|
data.router
|
||||||
|
.route_to_first(&data.client, "/health", &req)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/health_generate")]
|
#[get("/health_generate")]
|
||||||
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
|
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
data.router
|
||||||
.route_to_first(&data.client, "/health_generate")
|
.route_to_first(&data.client, "/health_generate", &req)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_server_info")]
|
#[get("/get_server_info")]
|
||||||
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
|
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
data.router
|
||||||
.route_to_first(&data.client, "/get_server_info")
|
.route_to_first(&data.client, "/get_server_info", &req)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/v1/models")]
|
#[get("/v1/models")]
|
||||||
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
|
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router.route_to_first(&data.client, "/v1/models").await
|
data.router
|
||||||
|
.route_to_first(&data.client, "/v1/models", &req)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_model_info")]
|
#[get("/get_model_info")]
|
||||||
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
|
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
data.router
|
||||||
.route_to_first(&data.client, "/get_model_info")
|
.route_to_first(&data.client, "/get_model_info", &req)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user