[router] migrate router from actix to axum (#8479)
This commit is contained in:
@@ -1,43 +1,27 @@
|
||||
mod common;
|
||||
|
||||
use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App};
|
||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||
use sglang_router_rs::server::{
|
||||
add_worker, generate, v1_chat_completions, v1_completions, AppState,
|
||||
};
|
||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Test context for request type testing
|
||||
struct RequestTestContext {
|
||||
/// Test context that manages mock workers
|
||||
struct TestContext {
|
||||
workers: Vec<MockWorker>,
|
||||
app_state: web::Data<AppState>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
}
|
||||
|
||||
impl RequestTestContext {
|
||||
impl TestContext {
|
||||
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
||||
let mut workers = Vec::new();
|
||||
let mut worker_urls = Vec::new();
|
||||
|
||||
// Start mock workers
|
||||
for config in worker_configs {
|
||||
let mut worker = MockWorker::new(config);
|
||||
let url = worker.start().await.unwrap();
|
||||
worker_urls.push(url);
|
||||
workers.push(worker);
|
||||
}
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||
|
||||
// Create router config
|
||||
let config = RouterConfig {
|
||||
let mut config = RouterConfig {
|
||||
mode: RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3006,
|
||||
port: 3003,
|
||||
max_payload_size: 256 * 1024 * 1024,
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
@@ -49,528 +33,348 @@ impl RequestTestContext {
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
cors_allowed_origins: vec![],
|
||||
};
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
|
||||
.build()
|
||||
.unwrap();
|
||||
let mut workers = Vec::new();
|
||||
let mut worker_urls = Vec::new();
|
||||
|
||||
let app_state = AppState::new(config, client).unwrap();
|
||||
let app_state = web::Data::new(app_state);
|
||||
|
||||
// Add workers via HTTP API
|
||||
let app =
|
||||
actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker))
|
||||
.await;
|
||||
|
||||
for url in &worker_urls {
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri(&format!("/add_worker?url={}", url))
|
||||
.to_request();
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert!(resp.status().is_success());
|
||||
for worker_config in worker_configs {
|
||||
let mut worker = MockWorker::new(worker_config);
|
||||
let url = worker.start().await.unwrap();
|
||||
worker_urls.push(url);
|
||||
workers.push(worker);
|
||||
}
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
if !workers.is_empty() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||
}
|
||||
|
||||
Self { workers, app_state }
|
||||
}
|
||||
config.mode = RoutingMode::Regular { worker_urls };
|
||||
|
||||
async fn create_app(
|
||||
&self,
|
||||
) -> impl actix_web::dev::Service<
|
||||
actix_http::Request,
|
||||
Response = actix_web::dev::ServiceResponse,
|
||||
Error = actix_web::Error,
|
||||
> {
|
||||
actix_test::init_service(
|
||||
App::new()
|
||||
.app_data(self.app_state.clone())
|
||||
.service(generate)
|
||||
.service(v1_chat_completions)
|
||||
.service(v1_completions),
|
||||
)
|
||||
.await
|
||||
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let router = Arc::from(router);
|
||||
|
||||
if !workers.is_empty() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
}
|
||||
|
||||
Self { workers, router }
|
||||
}
|
||||
|
||||
async fn shutdown(mut self) {
|
||||
// Small delay to ensure any pending operations complete
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
for worker in &mut self.workers {
|
||||
worker.stop().await;
|
||||
}
|
||||
|
||||
// Another small delay to ensure cleanup completes
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
async fn make_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
body: serde_json::Value,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let client = Client::new();
|
||||
|
||||
// Get any worker URL for testing
|
||||
let worker_urls = self.router.get_worker_urls();
|
||||
if worker_urls.is_empty() {
|
||||
return Err("No available workers".to_string());
|
||||
}
|
||||
|
||||
let worker_url = &worker_urls[0];
|
||||
|
||||
let response = client
|
||||
.post(&format!("{}{}", worker_url, endpoint))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Request failed with status: {}", response.status()));
|
||||
}
|
||||
|
||||
response
|
||||
.json::<serde_json::Value>()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse response: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod generate_input_format_tests {
|
||||
mod request_format_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_with_text_input() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21001,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
#[tokio::test]
|
||||
async fn test_generate_request_formats() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 19001,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Standard text input
|
||||
let payload = json!({
|
||||
"text": "Hello world",
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/generate")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = actix_test::read_body_json(resp).await;
|
||||
assert!(body.get("text").is_some());
|
||||
|
||||
ctx.shutdown().await;
|
||||
// Test 1: Basic text request
|
||||
let payload = json!({
|
||||
"text": "Hello, world!",
|
||||
"stream": false
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_with_prompt_input() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21002,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Prompt input (alternative to text)
|
||||
let payload = json!({
|
||||
"prompt": "Once upon a time",
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/generate")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_with_input_ids() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21003,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// Input IDs (tokenized input)
|
||||
let payload = json!({
|
||||
"input_ids": [1, 2, 3, 4, 5],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/generate")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_with_all_parameters() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21004,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// All generation parameters
|
||||
let payload = json!({
|
||||
"text": "Complete this",
|
||||
// Test 2: Request with sampling parameters
|
||||
let payload = json!({
|
||||
"text": "Tell me a story",
|
||||
"sampling_params": {
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 50,
|
||||
"max_new_tokens": 100,
|
||||
"min_new_tokens": 10,
|
||||
"frequency_penalty": 0.5,
|
||||
"presence_penalty": 0.3,
|
||||
"repetition_penalty": 1.1,
|
||||
"stop": [".", "!", "?"],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/generate")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod chat_completion_format_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_with_system_message() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21010,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/chat/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
// Note: Function calling and tools tests are commented out because
|
||||
// they require special handling in the mock worker that's not implemented yet.
|
||||
// In production, these would be forwarded to the actual model.
|
||||
|
||||
// #[test]
|
||||
// fn test_chat_with_function_calling() {
|
||||
// // Test would go here when mock worker supports function calling
|
||||
// }
|
||||
|
||||
// #[test]
|
||||
// fn test_chat_with_tools() {
|
||||
// // Test would go here when mock worker supports tools
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_chat_with_response_format() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21013,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Return JSON"}
|
||||
],
|
||||
"response_format": {
|
||||
"type": "json_object"
|
||||
}
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/chat/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod completion_format_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_completion_with_single_prompt() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21020,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Once upon a time",
|
||||
"max_tokens": 50
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body: serde_json::Value = actix_test::read_body_json(resp).await;
|
||||
assert!(body.get("choices").is_some());
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_with_batch_prompts() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21021,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": ["First prompt", "Second prompt", "Third prompt"],
|
||||
"max_tokens": 30
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_with_echo() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21022,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Echo this prompt",
|
||||
"echo": true,
|
||||
"max_tokens": 20
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_with_logprobs() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21023,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Calculate probability",
|
||||
"logprobs": 5,
|
||||
"max_tokens": 10
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_with_suffix() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21024,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Insert text here: ",
|
||||
"suffix": " and continue from here.",
|
||||
"max_tokens": 20
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/v1/completions")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod stop_sequence_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_stop_sequences_array() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21030,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"text": "Generate until stop",
|
||||
"stop": [".", "!", "?", "\n"],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/generate")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stop_sequences_string() {
|
||||
System::new().block_on(async {
|
||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
||||
port: 21031,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"text": "Generate until stop",
|
||||
"stop": "\n\n",
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = actix_test::TestRequest::post()
|
||||
.uri("/generate")
|
||||
.set_json(&payload)
|
||||
.to_request();
|
||||
|
||||
let resp = actix_test::call_service(&app, req).await;
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
ctx.shutdown().await;
|
||||
"top_p": 0.9
|
||||
},
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test 3: Request with input_ids
|
||||
let payload = json!({
|
||||
"input_ids": [1, 2, 3, 4, 5],
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": 50
|
||||
},
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_chat_completions_formats() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 19002,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
// Test 1: Basic chat completion
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/v1/chat/completions", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let response = result.unwrap();
|
||||
assert!(response.get("choices").is_some());
|
||||
assert!(response.get("id").is_some());
|
||||
assert_eq!(
|
||||
response.get("object").and_then(|v| v.as_str()),
|
||||
Some("chat.completion")
|
||||
);
|
||||
|
||||
// Test 2: Chat completion with parameters
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Tell me a joke"}
|
||||
],
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 150,
|
||||
"top_p": 0.95,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/v1/chat/completions", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_completions_formats() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 19003,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
// Test 1: Basic completion
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Once upon a time",
|
||||
"max_tokens": 50,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/v1/completions", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let response = result.unwrap();
|
||||
assert!(response.get("choices").is_some());
|
||||
assert_eq!(
|
||||
response.get("object").and_then(|v| v.as_str()),
|
||||
Some("text_completion")
|
||||
);
|
||||
|
||||
// Test 2: Completion with array prompt
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": ["First prompt", "Second prompt"],
|
||||
"temperature": 0.5,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/v1/completions", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test 3: Completion with logprobs
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "The capital of France is",
|
||||
"max_tokens": 10,
|
||||
"logprobs": 5,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/v1/completions", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_requests() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 19004,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
// Test batch text generation
|
||||
let payload = json!({
|
||||
"text": ["First text", "Second text", "Third text"],
|
||||
"sampling_params": {
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 50
|
||||
},
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test batch with input_ids
|
||||
let payload = json!({
|
||||
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_parameters() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 19005,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
// Test with return_logprob
|
||||
let payload = json!({
|
||||
"text": "Test",
|
||||
"return_logprob": true,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test with json_schema
|
||||
let payload = json!({
|
||||
"text": "Generate JSON",
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"json_schema": "$$ANY$$"
|
||||
},
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Test with ignore_eos
|
||||
let payload = json!({
|
||||
"text": "Continue forever",
|
||||
"sampling_params": {
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 100,
|
||||
"ignore_eos": true
|
||||
},
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 19006,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
// Test with empty body - should still work with mock worker
|
||||
let payload = json!({});
|
||||
|
||||
let result = ctx.make_request("/generate", payload).await;
|
||||
// Mock worker accepts empty body
|
||||
assert!(result.is_ok());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user