diff --git a/rust/pyproject.toml b/rust/pyproject.toml
index 9277bd6af..7365e24dc 100644
--- a/rust/pyproject.toml
+++ b/rust/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sglang-router"
-version = "0.0.5"
+version = "0.0.6"
 description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
 authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
 requires-python = ">=3.8"
diff --git a/rust/src/router.rs b/rust/src/router.rs
index 64738cc57..f1960bca5 100644
--- a/rust/src/router.rs
+++ b/rust/src/router.rs
@@ -97,14 +97,31 @@ pub enum PolicyConfig {
     },
 }
 
-fn get_text_from_request(body: &Bytes) -> String {
-    // 1. convert body to json
-    let json = serde_json::from_slice::<serde_json::Value>(body).unwrap();
-    // 2. get the text field
-    let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
-    return text.to_string();
-}
+/// Extracts the text that `dispatch` routes on from a JSON request body.
+///
+/// Reads `text` for `generate`, the re-serialized `messages` for `v1/chat/completions`, and
+/// `prompt` for `v1/completions`. Unknown routes, absent fields, and a non-string `text`/`prompt` yield "".
+fn get_text_from_request(body: &Bytes, route: &str) -> String {
+    // parse the body; malformed JSON becomes Null (routes on "") instead of panicking
+    let json = serde_json::from_slice::<serde_json::Value>(body).unwrap_or_default();
 
+    if route == "generate" {
+        // get the "text" field
+        let text = json.get("text").and_then(|t| t.as_str()).unwrap_or("");
+        return text.to_string();
+    } else if route == "v1/chat/completions" {
+        // get the messages field as raw text
+        if let Some(messages) = json.get("messages") {
+            // Convert messages back to a string, preserving all JSON formatting
+            return serde_json::to_string(messages).unwrap_or_default();
+        }
+    } else if route == "v1/completions" {
+        let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or("");
+        return prompt.to_string();
+    }
+
+    return "".to_string();
+}
 impl Router {
     pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
         match policy_config {
@@ -187,8 +204,11 @@ impl Router {
         client: &reqwest::Client,
         req: HttpRequest,
         body: Bytes,
+        route: &str,
     ) -> HttpResponse {
-        let text = get_text_from_request(&body);
+        let text = get_text_from_request(&body, route);
+        // For Debug
+        // println!("text: {:?}, route: {:?}", text, route);
 
         let worker_url = match self {
             Router::RoundRobin {
@@ -236,13 +256,14 @@ impl Router {
                         if matched_rate > *cache_threshold {
                             matched_worker.to_string()
                         } else {
-                            let m_map: HashMap<String, usize> = tree
-                                .tenant_char_count
-                                .iter()
-                                .map(|entry| (entry.key().clone(), *entry.value()))
-                                .collect();
+                            // For Debug
+                            // let m_map: HashMap<String, usize> = tree
+                            //     .tenant_char_count
+                            //     .iter()
+                            //     .map(|entry| (entry.key().clone(), *entry.value()))
+                            //     .collect();
 
-                            println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
+                            // println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
 
                             tree.get_smallest_tenant()
                         }
@@ -276,7 +297,7 @@ impl Router {
             .unwrap_or(false);
 
         let res = match client
-            .post(format!("{}/generate", worker_url.clone()))
+            .post(format!("{}/{}", worker_url.clone(), route))
             .header(
                 "Content-Type",
                 req.headers()
diff --git a/rust/src/server.rs b/rust/src/server.rs
index 51df65f97..93dd9e0b9 100644
--- a/rust/src/server.rs
+++ b/rust/src/server.rs
@@ -33,7 +33,10 @@ async fn forward_request(
                 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
 
             // print the status
-            println!("Worker URL: {}, Status: {}", worker_url, status);
+            println!(
+                "Forwarding Request Worker URL: {}, Route: {}, Status: {}",
+                worker_url, route, status
+            );
             match res.bytes().await {
                 Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                 Err(_) => HttpResponse::InternalServerError().finish(),
@@ -43,8 +46,38 @@ async fn forward_request(
     }
 }
 
+#[get("/health")]
+async fn health(data: web::Data<AppState>) -> impl Responder {
+    let worker_url = match data.router.get_first() {
+        Some(url) => url,
+        None => return HttpResponse::InternalServerError().finish(),
+    };
+
+    forward_request(&data.client, worker_url, "/health".to_string()).await
+}
+
+#[get("/health_generate")]
+async fn health_generate(data: web::Data<AppState>) -> impl Responder {
+    let worker_url = match data.router.get_first() {
+        Some(url) => url,
+        None => return HttpResponse::InternalServerError().finish(),
+    };
+
+    forward_request(&data.client, worker_url, "/health_generate".to_string()).await
+}
+
+#[get("/get_server_args")]
+async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
+    let worker_url = match data.router.get_first() {
+        Some(url) => url,
+        None => return HttpResponse::InternalServerError().finish(),
+    };
+
+    forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
+}
+
 #[get("/v1/models")]
-async fn v1_model(data: web::Data<AppState>) -> impl Responder {
+async fn v1_models(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
@@ -65,7 +98,31 @@ async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
 
 #[post("/generate")]
 async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
-    data.router.dispatch(&data.client, req, body).await
+    data.router
+        .dispatch(&data.client, req, body, "generate")
+        .await
+}
+
+#[post("/v1/chat/completions")]
+async fn v1_chat_completions(
+    req: HttpRequest,
+    body: Bytes,
+    data: web::Data<AppState>,
+) -> impl Responder {
+    data.router
+        .dispatch(&data.client, req, body, "v1/chat/completions")
+        .await
+}
+
+#[post("/v1/completions")]
+async fn v1_completions(
+    req: HttpRequest,
+    body: Bytes,
+    data: web::Data<AppState>,
+) -> impl Responder {
+    data.router
+        .dispatch(&data.client, req, body, "v1/completions")
+        .await
 }
 
 pub async fn startup(
@@ -90,8 +147,13 @@ pub async fn startup(
         App::new()
             .app_data(app_state.clone())
             .service(generate)
-            .service(v1_model)
+            .service(v1_chat_completions)
+            .service(v1_completions)
+            .service(v1_models)
             .service(get_model_info)
+            .service(health)
+            .service(health_generate)
+            .service(get_server_args)
     })
     .bind((host, port))?
     .run()