diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml
index b7e294d28..aae285347 100644
--- a/.github/workflows/pr-test-rust.yml
+++ b/.github/workflows/pr-test-rust.yml
@@ -4,11 +4,11 @@ on:
   push:
     branches: [ main ]
     paths:
-      - "rust/*"
+      - "rust/**"
   pull_request:
     branches: [ main ]
     paths:
-      - "rust/*"
+      - "rust/**"
   workflow_dispatch:
 
 concurrency:
diff --git a/rust/src/router.rs b/rust/src/router.rs
index 9d42cc13f..29db6e37c 100644
--- a/rust/src/router.rs
+++ b/rust/src/router.rs
@@ -1,93 +1,133 @@
-// src/router.rs
-
+use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
+use actix_web::{HttpRequest, HttpResponse};
+use bytes::Bytes;
+use futures_util::TryStreamExt;
 use std::fmt::Debug;
 
-/// Generic Router trait that can be implemented with different policies
-pub trait Router: Send + Sync + Debug {
-    /// Select a worker URL based on the implementation's policy
-    /// Returns None if no worker is available
-    fn select(&self) -> Option<String>;
-
-    // get first worker
-    fn get_first(&self) -> Option<String>;
-}
-
-// Round Robin Router
+/// Routes requests to one of a fixed list of worker URLs using a selection policy.
 #[derive(Debug)]
-pub struct RoundRobinRouter {
-    worker_urls: Vec<String>,
-    current_index: std::sync::atomic::AtomicUsize, // AtomicUsize is a thread-safe integer
+pub enum Router {
+    /// Cycles through the workers in order; `current_index` is the next slot to hand out.
+    RoundRobin {
+        worker_urls: Vec<String>,
+        current_index: std::sync::atomic::AtomicUsize,
+    },
+    /// Picks a worker at random for every request.
+    Random {
+        worker_urls: Vec<String>,
+    },
 }
 
-impl RoundRobinRouter {
-    pub fn new(worker_urls: Vec<String>) -> Self {
-        Self {
-            worker_urls,
-            current_index: std::sync::atomic::AtomicUsize::new(0),
+impl Router {
+    /// Builds a router for `policy` (`"random"` or `"round_robin"`, case-insensitive).
+    ///
+    /// # Panics
+    ///
+    /// Panics if `policy` names an unknown routing policy.
+    pub fn new(worker_urls: Vec<String>, policy: String) -> Self {
+        match policy.to_lowercase().as_str() {
+            "random" => Router::Random { worker_urls },
+            "round_robin" => Router::RoundRobin {
+                worker_urls,
+                current_index: std::sync::atomic::AtomicUsize::new(0),
+            },
+            _ => panic!(
+                "Unknown routing policy: {}. The available policies are 'random' and 'round_robin'",
+                policy
+            ),
+        }
+    }
+
+    /// Returns a clone of the first worker URL, or `None` when no workers are configured.
+    pub fn get_first(&self) -> Option<String> {
+        match self {
+            Router::RoundRobin { worker_urls, .. } | Router::Random { worker_urls } => {
+                worker_urls.first().cloned()
+            }
+        }
+    }
+
+    /// Forwards a `/generate` request body to the worker chosen by this router's policy.
+    ///
+    /// The upstream status code is relayed. When the JSON body sets `"stream": true`, the
+    /// upstream bytes are streamed back as `text/event-stream`; otherwise the whole upstream
+    /// body is buffered and returned. Responds with 500 when no worker is configured, the
+    /// upstream request fails, or the buffered body cannot be read.
+    pub async fn dispatch(
+        &self,
+        client: &reqwest::Client,
+        req: HttpRequest,
+        body: Bytes,
+    ) -> HttpResponse {
+        let worker_url = match self {
+            Router::RoundRobin {
+                worker_urls,
+                current_index,
+            } => {
+                // An empty worker list would make the modulo below divide by zero.
+                if worker_urls.is_empty() {
+                    return HttpResponse::InternalServerError().finish();
+                }
+
+                // The closure always returns `Some`, so `fetch_update` always yields
+                // `Ok(previous)`. Index with that returned value instead of re-loading the
+                // atomic, which another request may have advanced in between.
+                let index = current_index
+                    .fetch_update(
+                        std::sync::atomic::Ordering::SeqCst,
+                        std::sync::atomic::Ordering::SeqCst,
+                        |x| Some((x + 1) % worker_urls.len()),
+                    )
+                    .expect("fetch_update closure always returns Some");
+
+                &worker_urls[index]
+            }
+            Router::Random { worker_urls } => {
+                // Same guard as the round-robin arm: `% 0` would panic.
+                if worker_urls.is_empty() {
+                    return HttpResponse::InternalServerError().finish();
+                }
+
+                &worker_urls[rand::random::<usize>() % worker_urls.len()]
+            }
+        };
+
+        // Check if client requested streaming
+        let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
+            .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
+            .unwrap_or(false);
+
+        let res = match client
+            .post(format!("{}/generate", worker_url))
+            .header(
+                "Content-Type",
+                req.headers()
+                    .get("Content-Type")
+                    .and_then(|h| h.to_str().ok())
+                    .unwrap_or("application/json"),
+            )
+            .body(body.to_vec())
+            .send()
+            .await
+        {
+            Ok(res) => res,
+            Err(_) => return HttpResponse::InternalServerError().finish(),
+        };
+
+        let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
+            .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
+
+        if !is_stream {
+            match res.bytes().await {
+                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
+                Err(_) => HttpResponse::InternalServerError().finish(),
+            }
+        } else {
+            HttpResponse::build(status)
+                .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
+                .streaming(res.bytes_stream().map_err(|_| {
+                    actix_web::error::ErrorInternalServerError("Failed to read stream")
+                }))
         }
     }
 }
-
-impl Router for RoundRobinRouter {
-    fn select(&self) -> Option<String> {
-        if self.worker_urls.is_empty() {
-            return None;
-        }
-        // Use relaxed because operation order doesn't matter in round robin
-        let index = self
-            .current_index
-            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
-            % self.worker_urls.len();
-        Some(self.worker_urls[index].clone())
-    }
-
-    fn get_first(&self) -> Option<String> {
-        if self.worker_urls.is_empty() {
-            return None;
-        }
-        Some(self.worker_urls[0].clone())
-    }
-}
-
-// Random Router
-#[derive(Debug)]
-pub struct RandomRouter {
-    worker_urls: Vec<String>,
-}
-
-impl RandomRouter {
-    pub fn new(worker_urls: Vec<String>) -> Self {
-        Self { worker_urls }
-    }
-}
-
-impl Router for RandomRouter {
-    fn select(&self) -> Option<String> {
-        use rand::seq::SliceRandom;
-
-        if self.worker_urls.is_empty() {
-            return None;
-        }
-
-        self.worker_urls.choose(&mut rand::thread_rng()).cloned()
-    }
-
-    fn get_first(&self) -> Option<String> {
-        if self.worker_urls.is_empty() {
-            return None;
-        }
-        Some(self.worker_urls[0].clone())
-    }
-}
-
-// create a router based on routing policy
-pub fn create_router(worker_urls: Vec<String>, policy: String) -> Box<dyn Router> {
-    match policy.to_lowercase().as_str() {
-        "random" => Box::new(RandomRouter::new(worker_urls)),
-        "round_robin" => Box::new(RoundRobinRouter::new(worker_urls)),
-        _ => panic!(
-            "Unknown routing policy: {}. The available policies are 'random' and 'round_robin'",
-            policy
-        ),
-    }
-}
diff --git a/rust/src/server.rs b/rust/src/server.rs
index 1c6c515b4..fec7fae74 100644
--- a/rust/src/server.rs
+++ b/rust/src/server.rs
@@ -1,38 +1,28 @@
-use crate::router::create_router;
 use crate::router::Router;
-use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
 use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
 use bytes::Bytes;
-use futures_util::StreamExt;
 
 #[derive(Debug)]
 pub struct AppState {
-    router: Box<dyn Router>,
+    router: Router,
     client: reqwest::Client,
 }
 
 impl AppState {
     pub fn new(worker_urls: Vec<String>, policy: String, client: reqwest::Client) -> Self {
         // Create router based on policy
-        let router = create_router(worker_urls, policy);
+        let router = Router::new(worker_urls, policy);
 
         Self { router, client }
     }
 }
 
-#[get("/v1/models")]
-async fn v1_model(data: web::Data<AppState>) -> impl Responder {
-    let worker_url = match data.router.get_first() {
-        Some(url) => url,
-        None => return HttpResponse::InternalServerError().finish(),
-    };
-    // Use the shared client
-    match data
-        .client
-        .get(format!("{}/v1/models", worker_url))
-        .send()
-        .await
-    {
+async fn forward_request(
+    client: &reqwest::Client,
+    worker_url: String,
+    route: String,
+) -> HttpResponse {
+    match client.get(format!("{}{}", worker_url, route)).send().await {
         Ok(res) => {
             let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                 .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
@@ -48,85 +38,31 @@ async fn v1_model(data: web::Data<AppState>) -> impl Responder {
     }
 }
 
+#[get("/v1/models")]
+async fn v1_model(data: web::Data<AppState>) -> impl Responder {
+    // TODO: extract forward_to_route
+    let worker_url = match data.router.get_first() {
+        Some(url) => url,
+        None => return HttpResponse::InternalServerError().finish(),
+    };
+
+    forward_request(&data.client, worker_url, "/v1/models".to_string()).await
+}
+
 #[get("/get_model_info")]
 async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
     };
-    // Use the shared client
-    match data
-        .client
-        .get(format!("{}/get_model_info", worker_url))
-        .send()
-        .await
-    {
-        Ok(res) => {
-            let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
-                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
-            // print the status
-            println!("Worker URL: {}, Status: {}", worker_url, status);
-            match res.bytes().await {
-                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
-                Err(_) => HttpResponse::InternalServerError().finish(),
-            }
-        }
-        Err(_) => HttpResponse::InternalServerError().finish(),
-    }
+    forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
 }
 
 // no deser and ser, just forward and return
 #[post("/generate")]
 async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
-    // create a router struct
-    // TODO: use router abstraction for different policy
-    let worker_url = match data.router.select() {
-        Some(url) => url,
-        None => return HttpResponse::InternalServerError().finish(),
-    };
-
-    // Check if client requested streaming
-    let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
-        .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
-        .unwrap_or(false);
-
-    let res = match data
-        .client
-        .post(format!("{}/generate", worker_url))
-        .header(
-            "Content-Type",
-            req.headers()
-                .get("Content-Type")
-                .and_then(|h| h.to_str().ok())
-                .unwrap_or("application/json"),
-        )
-        .body(body.to_vec())
-        .send()
-        .await
-    {
-        Ok(res) => res,
-        Err(_) => return HttpResponse::InternalServerError().finish(),
-    };
-
-    let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
-        .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
-
-    if !is_stream {
-        match res.bytes().await {
-            Ok(body) => HttpResponse::build(status).body(body.to_vec()),
-            Err(_) => HttpResponse::InternalServerError().finish(),
-        }
-    } else {
-        HttpResponse::build(status)
-            .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
-            .streaming(res.bytes_stream().map(|b| match b {
-                Ok(b) => Ok::<_, actix_web::Error>(b),
-                Err(_) => Err(actix_web::error::ErrorInternalServerError(
-                    "Failed to read stream",
-                )),
-            }))
-    }
+    data.router.dispatch(&data.client, req, body).await
 }
 
 pub async fn startup(