diff --git a/.github/workflows/pr-test-pd-router.yml b/.github/workflows/pr-test-pd-router.yml index 2a1bde1b4..4c02e835a 100644 --- a/.github/workflows/pr-test-pd-router.yml +++ b/.github/workflows/pr-test-pd-router.yml @@ -155,33 +155,35 @@ jobs: id: start_servers run: | echo "Starting disaggregation servers..." - bash scripts/ci/ci_start_disaggregation_servers.sh & + READY_FILE=".disagg_ready" + rm -f "$READY_FILE" + DISAGG_READY_FILE="$READY_FILE" bash scripts/ci/ci_start_disaggregation_servers.sh & SERVER_PID=$! echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT - # Wait for all 8 servers to be healthy (script already does this) - wait_count=0 - while [ $wait_count -lt 30 ]; do - if ps -p $SERVER_PID > /dev/null; then - # Check if the startup script printed success message - sleep 2 - wait_count=$((wait_count + 1)) - else - # Script exited - check if it was successful - wait $SERVER_PID - exit_code=$? - if [ $exit_code -eq 0 ]; then - echo "✓ All disaggregation servers are healthy" - break - else - echo "Error: Server startup failed with code $exit_code" - exit 1 - fi + # Wait until script signals readiness (8/8 healthy) or timeout + TIMEOUT=300 + ELAPSED=0 + while [ $ELAPSED -lt $TIMEOUT ]; do + if [ -f "$READY_FILE" ]; then + echo "✓ All disaggregation servers are healthy (signal detected)" + break fi + if ! 
ps -p $SERVER_PID > /dev/null; then + echo "Error: server bootstrap script exited prematurely" + exit 1 + fi + sleep 5 + ELAPSED=$((ELAPSED + 5)) done + if [ $ELAPSED -ge $TIMEOUT ]; then + echo "❌ Timeout waiting for disaggregation servers to be healthy" + exit 1 + fi echo "✓ Servers started (PID: $SERVER_PID)" + - name: Test all policies sequentially timeout-minutes: 30 run: | diff --git a/scripts/ci/ci_start_disaggregation_servers.sh b/scripts/ci/ci_start_disaggregation_servers.sh index 56490bb06..bbfdac9d2 100755 --- a/scripts/ci/ci_start_disaggregation_servers.sh +++ b/scripts/ci/ci_start_disaggregation_servers.sh @@ -1,4 +1,9 @@ #!/bin/bash +set -euo pipefail + +# Optional: set DISAGG_READY_FILE to a filepath; when all servers are healthy, the script will +# create this file as a readiness signal (useful for CI to proceed to next steps). +DISAGG_READY_FILE="${DISAGG_READY_FILE:-}" MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct" @@ -81,6 +86,13 @@ while true; do if [ $HEALTHY_COUNT -eq 8 ]; then echo "✅ All 8 servers are healthy!" + # Emit readiness signal file if requested + if [ -n "$DISAGG_READY_FILE" ]; then + echo "Creating readiness flag: $DISAGG_READY_FILE" + # Ensure parent dir exists; ignore errors + mkdir -p "$(dirname "$DISAGG_READY_FILE")" 2>/dev/null || true + touch "$DISAGG_READY_FILE" + fi break else sleep 10 # Wait 10 seconds before next check diff --git a/sgl-router/py_test/e2e/conftest.py b/sgl-router/py_test/e2e/conftest.py index 3acec82b2..c170ec09f 100644 --- a/sgl-router/py_test/e2e/conftest.py +++ b/sgl-router/py_test/e2e/conftest.py @@ -715,6 +715,29 @@ def e2e_router_only_rr(): _terminate(proc) +@pytest.fixture(scope="session") +def e2e_embedding_model() -> str: + """Embedding model to use for E2E tests. + + Defaults to an E5 Mistral model, can be overridden via E2E_EMBEDDING_MODEL env var. 
+ """ + import os + + return os.getenv("E2E_EMBEDDING_MODEL", "intfloat/e5-mistral-7b-instruct") + + +@pytest.fixture +def e2e_primary_embedding_worker(e2e_embedding_model: str): + """Launch a single embedding worker using the specified model.""" + port = _find_available_port() + base_url = f"http://127.0.0.1:{port}" + proc = _popen_launch_worker(e2e_embedding_model, base_url) + try: + yield SimpleNamespace(proc=proc, url=base_url) + finally: + _terminate(proc) + + @pytest.fixture(scope="session") def e2e_primary_worker(e2e_model: str): port = _find_available_port() diff --git a/sgl-router/py_test/e2e/test_e2e_embeddings.py b/sgl-router/py_test/e2e/test_e2e_embeddings.py new file mode 100644 index 000000000..538d4df6f --- /dev/null +++ b/sgl-router/py_test/e2e/test_e2e_embeddings.py @@ -0,0 +1,38 @@ +from types import SimpleNamespace + +import pytest +import requests + + +@pytest.mark.e2e +def test_embeddings_basic( + e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model +): + base = e2e_router_only_rr.url + worker_url = e2e_primary_embedding_worker.url + + # Attach embedding worker to router-only instance + r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180) + r.raise_for_status() + + # Simple embedding request with two inputs + payload = { + "model": e2e_embedding_model, + "input": [ + "the quick brown fox", + "jumps over the lazy dog", + ], + } + r = requests.post(f"{base}/v1/embeddings", json=payload, timeout=120) + + assert r.status_code == 200, f"unexpected status: {r.status_code} {r.text}" + + data = r.json() + assert "data" in data and isinstance(data["data"], list) + assert len(data["data"]) == 2 + + # Validate shape of embedding objects + for item in data["data"]: + assert "embedding" in item and isinstance(item["embedding"], list) + # Ensure non-empty vectors + assert len(item["embedding"]) > 0 diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index afcccb549..7235370fe 100644 --- 
a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -143,6 +143,18 @@ pub fn init_metrics() { "Generate request duration" ); + // Embedding request specific metrics + describe_counter!("sgl_router_embeddings_total", "Total embedding requests"); + describe_histogram!( + "sgl_router_embeddings_duration_seconds", + "Embedding request duration" + ); + describe_counter!( + "sgl_router_embeddings_errors_total", + "Embedding request errors" + ); + describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size"); + // Running requests gauge for cache-aware policy describe_gauge!( "sgl_router_running_requests", @@ -440,6 +452,27 @@ impl RouterMetrics { histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64()); } + // Embeddings metrics + pub fn record_embeddings_request() { + counter!("sgl_router_embeddings_total").increment(1); + } + + pub fn record_embeddings_duration(duration: Duration) { + histogram!("sgl_router_embeddings_duration_seconds").record(duration.as_secs_f64()); + } + + pub fn record_embeddings_error(error_type: &str) { + counter!( + "sgl_router_embeddings_errors_total", + "error_type" => error_type.to_string() + ) + .increment(1); + } + + pub fn set_embeddings_queue_size(size: usize) { + gauge!("sgl_router_embeddings_queue_size").set(size as f64); + } + // Running requests for cache-aware policy pub fn set_running_requests(worker: &str, count: usize) { gauge!("sgl_router_running_requests", diff --git a/sgl-router/src/middleware.rs b/sgl-router/src/middleware.rs index abe137572..cadff6878 100644 --- a/sgl-router/src/middleware.rs +++ b/sgl-router/src/middleware.rs @@ -3,6 +3,7 @@ use axum::{ response::IntoResponse, response::Response, }; use rand::Rng; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; use std::time::Instant; @@ -13,6 +14,7 @@ use tracing::{debug, error, field::Empty, info, info_span, warn, Span}; pub use crate::core::token_bucket::TokenBucket; +use 
crate::metrics::RouterMetrics; use crate::server::AppState; /// Generate OpenAI-compatible request ID based on endpoint @@ -441,6 +443,11 @@ pub async fn concurrency_limit_middleware( request: Request, next: Next, ) -> Response { + // Static counter for embeddings queue size + static EMBEDDINGS_QUEUE_SIZE: AtomicU64 = AtomicU64::new(0); + + // Identify if this is an embeddings request based on path + let is_embeddings = request.uri().path().contains("/v1/embeddings"); let token_bucket = app_state.context.rate_limiter.clone(); // Try to acquire token immediately @@ -468,10 +475,23 @@ pub async fn concurrency_limit_middleware( // Try to send to queue match queue_tx.try_send(queued) { Ok(_) => { + // On successful enqueue, update embeddings queue gauge if applicable + if is_embeddings { + let new_val = EMBEDDINGS_QUEUE_SIZE.fetch_add(1, Ordering::Relaxed) + 1; + RouterMetrics::set_embeddings_queue_size(new_val as usize); + } + // Wait for token from queue processor match permit_rx.await { Ok(Ok(())) => { debug!("Acquired token from queue"); + // Dequeue for embeddings + if is_embeddings { + let new_val = + EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1; + RouterMetrics::set_embeddings_queue_size(new_val as usize); + } + let response = next.run(request).await; // Return the token to the bucket @@ -481,10 +501,22 @@ pub async fn concurrency_limit_middleware( } Ok(Err(status)) => { warn!("Queue returned error status: {}", status); + // Dequeue for embeddings on error + if is_embeddings { + let new_val = + EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1; + RouterMetrics::set_embeddings_queue_size(new_val as usize); + } status.into_response() } Err(_) => { error!("Queue response channel closed"); + // Dequeue for embeddings on channel error + if is_embeddings { + let new_val = + EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1; + RouterMetrics::set_embeddings_queue_size(new_val as usize); + } StatusCode::INTERNAL_SERVER_ERROR.into_response() } 
}
diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs
index 583829747..4760626b5 100644
--- a/sgl-router/src/protocols/spec.rs
+++ b/sgl-router/src/protocols/spec.rs
@@ -41,7 +41,10 @@ use std::collections::HashMap;
 // 6. **SGLANG SPEC - RERANK API**
 //    - Request/Response structures
 //
-// 7. **COMMON**
+// 7. **OPENAI SPEC - Embeddings API**
+//    - Request structures
+//
+// 8. **COMMON**
 //    - GenerationRequest trait
 //    - StringOrArray & LoRAPath types
 //    - Helper functions
@@ -2013,6 +2016,61 @@ impl RerankResponse {
     }
 }
 
+// ==================================================================
+// =                 OPENAI SPEC - Embeddings API                   =
+// ==================================================================
+
+/// Embeddings request compatible with OpenAI API
+/// We intentionally keep fields flexible to pass through to workers.
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct EmbeddingRequest {
+    /// ID of the model to use
+    pub model: String,
+
+    /// Input can be a string, array of strings, tokens, or batch inputs
+    pub input: serde_json::Value,
+
+    /// Optional encoding format (e.g., "float", "base64")
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub encoding_format: Option<String>,
+
+    /// Optional user identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option<String>,
+
+    /// Optional number of dimensions for the embedding
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub dimensions: Option<u32>,
+
+    /// SGLang extension: request id for tracking
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub rid: Option<String>,
+}
+
+impl GenerationRequest for EmbeddingRequest {
+    fn is_stream(&self) -> bool {
+        // Embeddings are non-streaming
+        false
+    }
+
+    fn get_model(&self) -> Option<&str> {
+        Some(&self.model)
+    }
+
+    fn extract_text_for_routing(&self) -> String {
+        // Best effort: extract text content for routing decisions
+        match &self.input {
+            serde_json::Value::String(s) => s.clone(),
+            serde_json::Value::Array(arr) => arr
+                .iter()
+                .filter_map(|v| v.as_str())
+                .collect::<Vec<_>>()
+                .join(" "),
+            _ => String::new(),
+        }
+    }
+}
+
 // ==================================================================
 // =                             COMMON                             =
 // ==================================================================
@@ -2715,4 +2773,102 @@ mod tests {
         assert_eq!(deserialized.results.len(), 2);
         assert_eq!(deserialized.model, response.model);
     }
+
+    // ==================================================================
+    // =                  EMBEDDINGS REQUEST TESTS                      =
+    // ==================================================================
+
+    #[test]
+    fn test_embedding_request_serialization_string_input() {
+        let req = EmbeddingRequest {
+            model: "test-emb".to_string(),
+            input: serde_json::Value::String("hello".to_string()),
+            encoding_format: Some("float".to_string()),
+            user: Some("user-1".to_string()),
+            dimensions: Some(128),
+            rid: Some("rid-123".to_string()),
+        };
+
+        let serialized = serde_json::to_string(&req).unwrap();
+        let deserialized: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
+
+        assert_eq!(deserialized.model, req.model);
+        assert_eq!(deserialized.input, req.input);
+        assert_eq!(deserialized.encoding_format, req.encoding_format);
+        assert_eq!(deserialized.user, req.user);
+        assert_eq!(deserialized.dimensions, req.dimensions);
+        assert_eq!(deserialized.rid, req.rid);
+    }
+
+    #[test]
+    fn test_embedding_request_serialization_array_input() {
+        let req = EmbeddingRequest {
+            model: "test-emb".to_string(),
+            input: serde_json::json!(["a", "b", "c"]),
+            encoding_format: None,
+            user: None,
+            dimensions: None,
+            rid: None,
+        };
+
+        let serialized = serde_json::to_string(&req).unwrap();
+        let de: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
+        assert_eq!(de.model, req.model);
+        assert_eq!(de.input, req.input);
+    }
+
+    #[test]
+    fn test_embedding_generation_request_trait_string() {
+        let req = EmbeddingRequest {
+            model: "emb-model".to_string(),
+            input:
serde_json::Value::String("hello".to_string()), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + assert!(!req.is_stream()); + assert_eq!(req.get_model(), Some("emb-model")); + assert_eq!(req.extract_text_for_routing(), "hello"); + } + + #[test] + fn test_embedding_generation_request_trait_array() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: serde_json::json!(["hello", "world"]), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + assert_eq!(req.extract_text_for_routing(), "hello world"); + } + + #[test] + fn test_embedding_generation_request_trait_non_text() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: serde_json::json!({"tokens": [1, 2, 3]}), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + assert_eq!(req.extract_text_for_routing(), ""); + } + + #[test] + fn test_embedding_generation_request_trait_mixed_array_ignores_nested() { + let req = EmbeddingRequest { + model: "emb-model".to_string(), + input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]), + encoding_format: None, + user: None, + dimensions: None, + rid: None, + }; + // Only top-level string elements are extracted + assert_eq!(req.extract_text_for_routing(), "a"); + } } diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 89254c831..4bd9d024d 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -309,7 +309,12 @@ impl RouterTrait for GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + async fn route_embeddings( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::EmbeddingRequest, + _model_id: Option<&str>, + ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/grpc/router.rs 
b/sgl-router/src/routers/grpc/router.rs index f631c1358..ff38e3469 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -242,7 +242,12 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } - async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + async fn route_embeddings( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::EmbeddingRequest, + _model_id: Option<&str>, + ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index bfe5a1d7b..63392d124 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -395,7 +395,12 @@ impl super::super::RouterTrait for OpenAIRouter { } } - async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + async fn route_embeddings( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::EmbeddingRequest, + _model_id: Option<&str>, + ) -> Response { ( StatusCode::NOT_IMPLEMENTED, "Embeddings endpoint not implemented for OpenAI backend", diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 7b23c298f..a31186177 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -1938,8 +1938,17 @@ impl RouterTrait for PDRouter { .into_response() } - async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { - todo!() + async fn route_embeddings( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::EmbeddingRequest, + _model_id: Option<&str>, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Embeddings endpoint not implemented for PD router", + ) + .into_response() } async fn route_rerank( diff --git a/sgl-router/src/routers/http/router.rs 
b/sgl-router/src/routers/http/router.rs index 6c1eb2554..5e9100a54 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -6,8 +6,8 @@ use crate::core::{ use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest, - RerankResponse, RerankResult, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest, + RerankRequest, RerankResponse, RerankResult, ResponsesRequest, }; use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; @@ -1430,8 +1430,28 @@ impl RouterTrait for Router { self.route_post_empty_request(headers, &endpoint).await } - async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { - todo!() + async fn route_embeddings( + &self, + headers: Option<&HeaderMap>, + body: &EmbeddingRequest, + model_id: Option<&str>, + ) -> Response { + // Record embeddings-specific metrics in addition to general request metrics + let start = Instant::now(); + let res = self + .route_typed_request(headers, body, "/v1/embeddings", model_id) + .await; + + // Embedding specific metrics + if res.status().is_success() { + RouterMetrics::record_embeddings_request(); + RouterMetrics::record_embeddings_duration(start.elapsed()); + } else { + let error_type = format!("http_{}", res.status().as_u16()); + RouterMetrics::record_embeddings_error(&error_type); + } + + res } async fn route_rerank( diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index d47951a40..ea64c12e1 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -10,7 +10,8 @@ use axum::{ use std::fmt::Debug; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, + ChatCompletionRequest, 
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponsesRequest, }; pub mod factory; @@ -123,7 +124,13 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { .into_response() } - async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; + /// Route embedding requests (OpenAI-compatible /v1/embeddings) + async fn route_embeddings( + &self, + headers: Option<&HeaderMap>, + body: &EmbeddingRequest, + model_id: Option<&str>, + ) -> Response; async fn route_rerank( &self, diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index e6a2053be..fe8e3844b 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -7,7 +7,8 @@ use crate::config::RouterConfig; use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry}; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponsesRequest, }; use crate::protocols::worker_spec::{ ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo, @@ -665,22 +666,6 @@ impl RouterTrait for RouterManager { .into_response() } - async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "responses api not yet implemented in inference gateway mode", - ) - .into_response() - } - - async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "responses api not yet implemented in inference gateway mode", - ) - .into_response() - } - async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { ( StatusCode::NOT_IMPLEMENTED, @@ -701,17 +686,51 @@ impl RouterTrait for RouterManager { .into_response() } - /// Route 
embeddings request - async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response { - // Try to select a router based on headers + async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { let router = self.select_router_for_request(headers, None); - if let Some(router) = router { - router.route_embeddings(headers, body).await + router.get_response(headers, response_id).await } else { ( StatusCode::NOT_FOUND, - "No router available for embeddings request", + format!("No router available to get response '{}'", response_id), + ) + .into_response() + } + } + + async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + let router = self.select_router_for_request(headers, None); + if let Some(router) = router { + router.cancel_response(headers, response_id).await + } else { + ( + StatusCode::NOT_FOUND, + format!("No router available to cancel response '{}'", response_id), + ) + .into_response() + } + } + + /// Route embeddings request + async fn route_embeddings( + &self, + headers: Option<&HeaderMap>, + body: &EmbeddingRequest, + _model_id: Option<&str>, + ) -> Response { + // Select router based on headers and model + let router = self.select_router_for_request(headers, Some(&body.model)); + + if let Some(router) = router { + router + .route_embeddings(headers, body, Some(&body.model)) + .await + } else { + // Return 404 when the specified model is not found + ( + StatusCode::NOT_FOUND, + format!("Model '{}' not found or no router available", body.model), ) .into_response() } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index b6eeafafe..512defdde 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -5,8 +5,8 @@ use crate::metrics::{self, PrometheusConfig}; use crate::middleware::TokenBucket; use crate::policies::PolicyRegistry; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, 
    ResponsesRequest,
-    V1RerankReqInput,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
+    ResponsesRequest, V1RerankReqInput,
 };
 use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
 use crate::reasoning_parser::ParserFactory;
@@ -208,6 +208,17 @@ async fn v1_responses(
         .await
 }
 
+async fn v1_embeddings(
+    State(state): State<Arc<AppState>>,
+    headers: http::HeaderMap,
+    Json(body): Json<EmbeddingRequest>,
+) -> Response {
+    state
+        .router
+        .route_embeddings(Some(&headers), &body, None)
+        .await
+}
+
 async fn v1_responses_get(
     State(state): State<Arc<AppState>>,
     Path(response_id): Path<String>,
@@ -465,6 +476,7 @@ pub fn build_app(
         .route("/rerank", post(rerank))
         .route("/v1/rerank", post(v1_rerank))
         .route("/v1/responses", post(v1_responses))
+        .route("/v1/embeddings", post(v1_embeddings))
         .route("/v1/responses/{response_id}", get(v1_responses_get))
         .route(
             "/v1/responses/{response_id}/cancel",
diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index b7a346338..14b9f2f99 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -1090,10 +1090,14 @@ mod responses_endpoint_tests { let app = ctx.create_app().await; // First create a response to obtain an id + let resp_id = "test-get-resp-id-123"; let payload = json!({ "input": "Hello Responses API", "model": "mock-model", - "stream": false + "stream": false, + "store": true, + "background": true, + "request_id": resp_id }); let req = Request::builder() .method("POST") .unwrap(); let resp = app.clone().oneshot(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let body = axum::body::to_bytes(resp.into_body(), usize::MAX) - .await - .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - let resp_id = body_json["id"].as_str().unwrap().to_string(); // Retrieve the response let req = Request::builder()
@@ -1140,10 +1139,14 @@ mod responses_endpoint_tests { let app = ctx.create_app().await; // First create a response to obtain an id + let resp_id = "test-cancel-resp-id-456"; let payload = json!({ "input": "Hello Responses API", "model": "mock-model", - "stream": false + "stream": false, + "store": true, + "background": true, + "request_id": resp_id }); let req = Request::builder() .method("POST") @@ -1153,11 +1156,6 @@ mod responses_endpoint_tests { .unwrap(); let resp = app.clone().oneshot(req).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); - let body = axum::body::to_bytes(resp.into_body(), usize::MAX) - .await - .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - let resp_id = body_json["id"].as_str().unwrap().to_string(); // Cancel the response let req = Request::builder() diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index b56e952d7..c9dc86f1a 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -20,7 +20,12 @@ import torch from transformers import AutoConfig, AutoTokenizer from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner -from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci +from sglang.test.test_utils import ( + CustomTestCase, + get_similarities, + is_in_amd_ci, + is_in_ci, +) MODELS = [ ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5), @@ -74,11 +79,13 @@ class TestEmbeddingModels(CustomTestCase): ) as hf_runner: hf_outputs = hf_runner.forward(truncated_prompts) + attention_backend = "triton" if is_in_amd_ci() else None with SRTRunner( model_path, tp_size=tp_size, torch_dtype=torch_dtype, model_type="embedding", + attention_backend=attention_backend, ) as srt_runner: srt_outputs = srt_runner.forward(truncated_prompts)