Add more API routes (completion, health, etc.) to the router (#2146)

This commit is contained in:
Byron Hsu
2024-11-23 15:10:26 -08:00
committed by GitHub
parent 52f58fc42a
commit bbb81c2457
3 changed files with 98 additions and 19 deletions

View File

@@ -97,14 +97,27 @@ pub enum PolicyConfig {
},
}
/// Extracts the prompt text from a JSON request body for the given route.
///
/// The field that is read depends on `route`:
///
/// * `"generate"` — the string value of the `"text"` field.
/// * `"v1/chat/completions"` — the `"messages"` value re-serialized to a
///   JSON string.
/// * `"v1/completions"` — the string value of the `"prompt"` field.
///
/// An empty string is returned when the route is not one of the above, when
/// the expected field is missing or (for `"text"` / `"prompt"`) is not a JSON
/// string, or when `body` is not valid JSON. This function never panics on
/// malformed input.
fn get_text_from_request(body: &Bytes, route: &str) -> String {
    // The body comes straight from the incoming request, so it may not be
    // valid JSON. Fall back to the same empty text used for a missing field
    // instead of panicking.
    let json: serde_json::Value = match serde_json::from_slice(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Reads a top-level field as a string; a missing or non-string field
    // yields "".
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route {
        "generate" => string_field("text"),
        // Convert messages back to a string, preserving the JSON structure.
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
match policy_config {
@@ -187,8 +200,11 @@ impl Router {
client: &reqwest::Client,
req: HttpRequest,
body: Bytes,
route: &str,
) -> HttpResponse {
let text = get_text_from_request(&body);
let text = get_text_from_request(&body, route);
// For Debug
// println!("text: {:?}, route: {:?}", text, route);
let worker_url = match self {
Router::RoundRobin {
@@ -236,13 +252,14 @@ impl Router {
if matched_rate > *cache_threshold {
matched_worker.to_string()
} else {
let m_map: HashMap<String, usize> = tree
.tenant_char_count
.iter()
.map(|entry| (entry.key().clone(), *entry.value()))
.collect();
// For Debug
// let m_map: HashMap<String, usize> = tree
// .tenant_char_count
// .iter()
// .map(|entry| (entry.key().clone(), *entry.value()))
// .collect();
println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
// println!("map: {:?}, mmap: {:?}", tree.get_tenant_char_count(), m_map);
tree.get_smallest_tenant()
}
@@ -276,7 +293,7 @@ impl Router {
.unwrap_or(false);
let res = match client
.post(format!("{}/generate", worker_url.clone()))
.post(format!("{}/{}", worker_url.clone(), route))
.header(
"Content-Type",
req.headers()