Add more api routes (completion, health, etc) to the router (#2146)
This commit is contained in:
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sglang-router"
|
name = "sglang-router"
|
||||||
version = "0.0.5"
|
version = "0.0.6"
|
||||||
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
|
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
|
||||||
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
|
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
|
|||||||
@@ -97,14 +97,27 @@ pub enum PolicyConfig {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_text_from_request(body: &Bytes) -> String {
|
fn get_text_from_request(body: &Bytes, route: &str) -> String {
|
||||||
// 1. convert body to json
|
// convert body to json
|
||||||
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
|
let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
|
||||||
// 2. get the text field
|
|
||||||
|
if route == "generate" {
|
||||||
|
// get the "text" field
|
||||||
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
|
||||||
return text.to_string();
|
return text.to_string();
|
||||||
|
} else if route == "v1/chat/completions" {
|
||||||
|
// get the messages field as raw text
|
||||||
|
if let Some(messages) = json.get("messages") {
|
||||||
|
// Convert messages back to a string, preserving all JSON formatting
|
||||||
|
return serde_json::to_string(messages).unwrap_or_default();
|
||||||
|
}
|
||||||
|
} else if route == "v1/completions" {
|
||||||
|
let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
|
||||||
|
return prompt.to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return "".to_string();
|
||||||
|
}
|
||||||
impl Router {
|
impl Router {
|
||||||
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
|
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
|
||||||
match policy_config {
|
match policy_config {
|
||||||
@@ -187,8 +200,11 @@ impl Router {
|
|||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
body: Bytes,
|
body: Bytes,
|
||||||
|
route: &str,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
let text = get_text_from_request(&body);
|
let text = get_text_from_request(&body, route);
|
||||||
|
// For Debug
|
||||||
|
// println!("text: {:?}, route: {:?}", text, route);
|
||||||
|
|
||||||
let worker_url = match self {
|
let worker_url = match self {
|
||||||
Router::RoundRobin {
|
Router::RoundRobin {
|
||||||
@@ -236,13 +252,14 @@ impl Router {
|
|||||||
if matched_rate > *cache_threshold {
|
if matched_rate > *cache_threshold {
|
||||||
matched_worker.to_string()
|
matched_worker.to_string()
|
||||||
} else {
|
} else {
|
||||||
let m_map: HashMap<String, usize> = tree
|
// For Debug
|
||||||
.tenant_char_count
|
// let m_map: HashMap<String, usize> = tree
|
||||||
.iter()
|
// .tenant_char_count
|
||||||
.map(|entry| (entry.key().clone(), *entry.value()))
|
// .iter()
|
||||||
.collect();
|
// .map(|entry| (entry.key().clone(), *entry.value()))
|
||||||
|
// .collect();
|
||||||
|
|
||||||
println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
|
// println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
|
||||||
|
|
||||||
tree.get_smallest_tenant()
|
tree.get_smallest_tenant()
|
||||||
}
|
}
|
||||||
@@ -276,7 +293,7 @@ impl Router {
|
|||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
let res = match client
|
let res = match client
|
||||||
.post(format!("{}/generate", worker_url.clone()))
|
.post(format!("{}/{}", worker_url.clone(), route))
|
||||||
.header(
|
.header(
|
||||||
"Content-Type",
|
"Content-Type",
|
||||||
req.headers()
|
req.headers()
|
||||||
|
|||||||
@@ -33,7 +33,10 @@ async fn forward_request(
|
|||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
// print the status
|
// print the status
|
||||||
println!("Worker URL: {}, Status: {}", worker_url, status);
|
println!(
|
||||||
|
"Forwarding Request Worker URL: {}, Route: {}, Status: {}",
|
||||||
|
worker_url, route, status
|
||||||
|
);
|
||||||
match res.bytes().await {
|
match res.bytes().await {
|
||||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||||
Err(_) => HttpResponse::InternalServerError().finish(),
|
Err(_) => HttpResponse::InternalServerError().finish(),
|
||||||
@@ -43,8 +46,38 @@ async fn forward_request(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[get("/health")]
|
||||||
|
async fn health(data: web::Data<AppState>) -> impl Responder {
|
||||||
|
let worker_url = match data.router.get_first() {
|
||||||
|
Some(url) => url,
|
||||||
|
None => return HttpResponse::InternalServerError().finish(),
|
||||||
|
};
|
||||||
|
|
||||||
|
forward_request(&data.client, worker_url, "/health".to_string()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/health_generate")]
|
||||||
|
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
|
||||||
|
let worker_url = match data.router.get_first() {
|
||||||
|
Some(url) => url,
|
||||||
|
None => return HttpResponse::InternalServerError().finish(),
|
||||||
|
};
|
||||||
|
|
||||||
|
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/get_server_args")]
|
||||||
|
async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
|
||||||
|
let worker_url = match data.router.get_first() {
|
||||||
|
Some(url) => url,
|
||||||
|
None => return HttpResponse::InternalServerError().finish(),
|
||||||
|
};
|
||||||
|
|
||||||
|
forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
|
||||||
|
}
|
||||||
|
|
||||||
#[get("/v1/models")]
|
#[get("/v1/models")]
|
||||||
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
|
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
|
||||||
let worker_url = match data.router.get_first() {
|
let worker_url = match data.router.get_first() {
|
||||||
Some(url) => url,
|
Some(url) => url,
|
||||||
None => return HttpResponse::InternalServerError().finish(),
|
None => return HttpResponse::InternalServerError().finish(),
|
||||||
@@ -65,7 +98,31 @@ async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
|
|||||||
|
|
||||||
#[post("/generate")]
|
#[post("/generate")]
|
||||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router.dispatch(&data.client, req, body).await
|
data.router
|
||||||
|
.dispatch(&data.client, req, body, "generate")
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/v1/chat/completions")]
|
||||||
|
async fn v1_chat_completions(
|
||||||
|
req: HttpRequest,
|
||||||
|
body: Bytes,
|
||||||
|
data: web::Data<AppState>,
|
||||||
|
) -> impl Responder {
|
||||||
|
data.router
|
||||||
|
.dispatch(&data.client, req, body, "v1/chat/completions")
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/v1/completions")]
|
||||||
|
async fn v1_completions(
|
||||||
|
req: HttpRequest,
|
||||||
|
body: Bytes,
|
||||||
|
data: web::Data<AppState>,
|
||||||
|
) -> impl Responder {
|
||||||
|
data.router
|
||||||
|
.dispatch(&data.client, req, body, "v1/completions")
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn startup(
|
pub async fn startup(
|
||||||
@@ -90,8 +147,13 @@ pub async fn startup(
|
|||||||
App::new()
|
App::new()
|
||||||
.app_data(app_state.clone())
|
.app_data(app_state.clone())
|
||||||
.service(generate)
|
.service(generate)
|
||||||
.service(v1_model)
|
.service(v1_chat_completions)
|
||||||
|
.service(v1_completions)
|
||||||
|
.service(v1_models)
|
||||||
.service(get_model_info)
|
.service(get_model_info)
|
||||||
|
.service(health)
|
||||||
|
.service(health_generate)
|
||||||
|
.service(get_server_args)
|
||||||
})
|
})
|
||||||
.bind((host, port))?
|
.bind((host, port))?
|
||||||
.run()
|
.run()
|
||||||
|
|||||||
Reference in New Issue
Block a user