[router]: Add Embedding routing logic (#10129)
Signed-off-by: Jintao Zhang <zhangjintao9020@gmail.com> Co-authored-by: Waël Boukhobza <wawa_wael@live.fr>
This commit is contained in:
40
.github/workflows/pr-test-pd-router.yml
vendored
40
.github/workflows/pr-test-pd-router.yml
vendored
@@ -155,33 +155,35 @@ jobs:
|
||||
id: start_servers
|
||||
run: |
|
||||
echo "Starting disaggregation servers..."
|
||||
bash scripts/ci/ci_start_disaggregation_servers.sh &
|
||||
READY_FILE=".disagg_ready"
|
||||
rm -f "$READY_FILE"
|
||||
DISAGG_READY_FILE="$READY_FILE" bash scripts/ci/ci_start_disaggregation_servers.sh &
|
||||
SERVER_PID=$!
|
||||
echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT
|
||||
|
||||
# Wait for all 8 servers to be healthy (script already does this)
|
||||
wait_count=0
|
||||
while [ $wait_count -lt 30 ]; do
|
||||
if ps -p $SERVER_PID > /dev/null; then
|
||||
# Check if the startup script printed success message
|
||||
sleep 2
|
||||
wait_count=$((wait_count + 1))
|
||||
else
|
||||
# Script exited - check if it was successful
|
||||
wait $SERVER_PID
|
||||
exit_code=$?
|
||||
if [ $exit_code -eq 0 ]; then
|
||||
echo "✓ All disaggregation servers are healthy"
|
||||
break
|
||||
else
|
||||
echo "Error: Server startup failed with code $exit_code"
|
||||
exit 1
|
||||
fi
|
||||
# Wait until script signals readiness (8/8 healthy) or timeout
|
||||
TIMEOUT=300
|
||||
ELAPSED=0
|
||||
while [ $ELAPSED -lt $TIMEOUT ]; do
|
||||
if [ -f "$READY_FILE" ]; then
|
||||
echo "✓ All disaggregation servers are healthy (signal detected)"
|
||||
break
|
||||
fi
|
||||
if ! ps -p $SERVER_PID > /dev/null; then
|
||||
echo "Error: server bootstrap script exited prematurely"
|
||||
exit 1
|
||||
fi
|
||||
sleep 5
|
||||
ELAPSED=$((ELAPSED + 5))
|
||||
done
|
||||
if [ $ELAPSED -ge $TIMEOUT ]; then
|
||||
echo "❌ Timeout waiting for disaggregation servers to be healthy"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✓ Servers started (PID: $SERVER_PID)"
|
||||
|
||||
|
||||
- name: Test all policies sequentially
|
||||
timeout-minutes: 30
|
||||
run: |
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Optional: set DISAGG_READY_FILE to a filepath; when all servers are healthy, the script will
|
||||
# create this file as a readiness signal (useful for CI to proceed to next steps).
|
||||
DISAGG_READY_FILE="${DISAGG_READY_FILE:-}"
|
||||
|
||||
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
@@ -81,6 +86,13 @@ while true; do
|
||||
|
||||
if [ $HEALTHY_COUNT -eq 8 ]; then
|
||||
echo "✅ All 8 servers are healthy!"
|
||||
# Emit readiness signal file if requested
|
||||
if [ -n "$DISAGG_READY_FILE" ]; then
|
||||
echo "Creating readiness flag: $DISAGG_READY_FILE"
|
||||
# Ensure parent dir exists; ignore errors
|
||||
mkdir -p "$(dirname "$DISAGG_READY_FILE")" 2>/dev/null || true
|
||||
touch "$DISAGG_READY_FILE"
|
||||
fi
|
||||
break
|
||||
else
|
||||
sleep 10 # Wait 10 seconds before next check
|
||||
|
||||
@@ -715,6 +715,29 @@ def e2e_router_only_rr():
|
||||
_terminate(proc)
|
||||
|
||||
|
||||
# Session-scoped fixture that supplies the embedding model name used by the E2E tests.
@pytest.fixture(scope="session")
|
||||
def e2e_embedding_model() -> str:
|
||||
"""Embedding model to use for E2E tests.
|
||||
|
||||
Defaults to an E5 Mistral model, can be overridden via E2E_EMBEDDING_MODEL env var.
|
||||
"""
|
||||
# `os` is imported locally inside the fixture body.
import os
|
||||
|
||||
# Read the override from the environment; fall back to the default model id.
return os.getenv("E2E_EMBEDDING_MODEL", "intfloat/e5-mistral-7b-instruct")
|
||||
|
||||
|
||||
# Fixture (no explicit scope given) that starts one embedding worker process
# and exposes its process handle and base URL to the test.
@pytest.fixture
|
||||
def e2e_primary_embedding_worker(e2e_embedding_model: str):
|
||||
"""Launch a single embedding worker using the specified model."""
|
||||
# Pick a port and build the worker's base URL on the loopback interface.
port = _find_available_port()
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
# Start the worker for the requested model at that URL.
proc = _popen_launch_worker(e2e_embedding_model, base_url)
|
||||
try:
|
||||
# Hand the process and URL to the test as attributes `proc` and `url`.
yield SimpleNamespace(proc=proc, url=base_url)
|
||||
finally:
|
||||
# Always terminate the worker, even if the test using the fixture fails.
_terminate(proc)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def e2e_primary_worker(e2e_model: str):
|
||||
port = _find_available_port()
|
||||
|
||||
38
sgl-router/py_test/e2e/test_e2e_embeddings.py
Normal file
38
sgl-router/py_test/e2e/test_e2e_embeddings.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
# End-to-end check: register an embedding worker with a router-only instance,
# POST a two-input request to /v1/embeddings through the router, and verify
# that two non-empty embedding vectors come back.
@pytest.mark.e2e
|
||||
def test_embeddings_basic(
|
||||
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
|
||||
):
|
||||
# Base URLs of the router and of the embedding worker provided by the fixtures.
base = e2e_router_only_rr.url
|
||||
worker_url = e2e_primary_embedding_worker.url
|
||||
|
||||
# Attach embedding worker to router-only instance
|
||||
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
|
||||
# Fail the test immediately if the worker could not be added.
r.raise_for_status()
|
||||
|
||||
# Simple embedding request with two inputs
|
||||
payload = {
|
||||
"model": e2e_embedding_model,
|
||||
"input": [
|
||||
"the quick brown fox",
|
||||
"jumps over the lazy dog",
|
||||
],
|
||||
}
|
||||
r = requests.post(f"{base}/v1/embeddings", json=payload, timeout=120)
|
||||
|
||||
# Include the response body in the failure message to ease debugging.
assert r.status_code == 200, f"unexpected status: {r.status_code} {r.text}"
|
||||
|
||||
data = r.json()
|
||||
# The response must carry a "data" list with one entry per input string.
assert "data" in data and isinstance(data["data"], list)
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
# Validate shape of embedding objects
|
||||
for item in data["data"]:
|
||||
assert "embedding" in item and isinstance(item["embedding"], list)
|
||||
# Ensure non-empty vectors
|
||||
assert len(item["embedding"]) > 0
|
||||
@@ -143,6 +143,18 @@ pub fn init_metrics() {
|
||||
"Generate request duration"
|
||||
);
|
||||
|
||||
// Embedding request specific metrics
|
||||
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
|
||||
describe_histogram!(
|
||||
"sgl_router_embeddings_duration_seconds",
|
||||
"Embedding request duration"
|
||||
);
|
||||
describe_counter!(
|
||||
"sgl_router_embeddings_errors_total",
|
||||
"Embedding request errors"
|
||||
);
|
||||
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
|
||||
|
||||
// Running requests gauge for cache-aware policy
|
||||
describe_gauge!(
|
||||
"sgl_router_running_requests",
|
||||
@@ -440,6 +452,27 @@ impl RouterMetrics {
|
||||
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
// Embeddings metrics
|
||||
/// Increments the `sgl_router_embeddings_total` counter by one.
pub fn record_embeddings_request() {
|
||||
counter!("sgl_router_embeddings_total").increment(1);
|
||||
}
|
||||
|
||||
/// Records `duration` into the `sgl_router_embeddings_duration_seconds` histogram.
///
/// The value is recorded as fractional seconds.
pub fn record_embeddings_duration(duration: Duration) {
|
||||
histogram!("sgl_router_embeddings_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
/// Increments the `sgl_router_embeddings_errors_total` counter by one.
///
/// `error_type` is attached to the counter as the value of the `error_type` label.
pub fn record_embeddings_error(error_type: &str) {
|
||||
counter!(
|
||||
"sgl_router_embeddings_errors_total",
|
||||
// The label value is an owned copy of the caller's string slice.
"error_type" => error_type.to_string()
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
/// Sets the `sgl_router_embeddings_queue_size` gauge to `size`.
pub fn set_embeddings_queue_size(size: usize) {
|
||||
// The gauge is set with a float, so the count is cast from `usize` to `f64`.
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
|
||||
}
|
||||
|
||||
// Running requests for cache-aware policy
|
||||
pub fn set_running_requests(worker: &str, count: usize) {
|
||||
gauge!("sgl_router_running_requests",
|
||||
|
||||
@@ -3,6 +3,7 @@ use axum::{
|
||||
response::IntoResponse, response::Response,
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
@@ -13,6 +14,7 @@ use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
|
||||
|
||||
pub use crate::core::token_bucket::TokenBucket;
|
||||
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::server::AppState;
|
||||
|
||||
/// Generate OpenAI-compatible request ID based on endpoint
|
||||
@@ -441,6 +443,11 @@ pub async fn concurrency_limit_middleware(
|
||||
request: Request<axum::body::Body>,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
// Static counter for embeddings queue size
|
||||
static EMBEDDINGS_QUEUE_SIZE: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
// Identify if this is an embeddings request based on path
|
||||
let is_embeddings = request.uri().path().contains("/v1/embeddings");
|
||||
let token_bucket = app_state.context.rate_limiter.clone();
|
||||
|
||||
// Try to acquire token immediately
|
||||
@@ -468,10 +475,23 @@ pub async fn concurrency_limit_middleware(
|
||||
// Try to send to queue
|
||||
match queue_tx.try_send(queued) {
|
||||
Ok(_) => {
|
||||
// On successful enqueue, update embeddings queue gauge if applicable
|
||||
if is_embeddings {
|
||||
let new_val = EMBEDDINGS_QUEUE_SIZE.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
|
||||
// Wait for token from queue processor
|
||||
match permit_rx.await {
|
||||
Ok(Ok(())) => {
|
||||
debug!("Acquired token from queue");
|
||||
// Dequeue for embeddings
|
||||
if is_embeddings {
|
||||
let new_val =
|
||||
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Return the token to the bucket
|
||||
@@ -481,10 +501,22 @@ pub async fn concurrency_limit_middleware(
|
||||
}
|
||||
Ok(Err(status)) => {
|
||||
warn!("Queue returned error status: {}", status);
|
||||
// Dequeue for embeddings on error
|
||||
if is_embeddings {
|
||||
let new_val =
|
||||
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
status.into_response()
|
||||
}
|
||||
Err(_) => {
|
||||
error!("Queue response channel closed");
|
||||
// Dequeue for embeddings on channel error
|
||||
if is_embeddings {
|
||||
let new_val =
|
||||
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
RouterMetrics::set_embeddings_queue_size(new_val as usize);
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,10 @@ use std::collections::HashMap;
|
||||
// 6. **SGLANG SPEC - RERANK API**
|
||||
// - Request/Response structures
|
||||
//
|
||||
// 7. **COMMON**
|
||||
// 7. **OPENAI SPEC - Embeddings API**
|
||||
// - Request structures
|
||||
//
|
||||
// 8. **COMMON**
|
||||
// - GenerationRequest trait
|
||||
// - StringOrArray & LoRAPath types
|
||||
// - Helper functions
|
||||
@@ -2013,6 +2016,61 @@ impl RerankResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Embeddings API =
|
||||
// ==================================================================
|
||||
|
||||
/// Embeddings request compatible with OpenAI API
|
||||
/// We intentionally keep fields flexible to pass through to workers.
|
||||
///
/// `model` and `input` are always present; every other field is optional and
/// is left out of the serialized JSON when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
/// ID of the model to use
|
||||
pub model: String,
|
||||
|
||||
/// Input can be a string, array of strings, tokens, or batch inputs
|
||||
///
/// Held as an untyped JSON value, so this struct does not constrain its shape.
pub input: serde_json::Value,
|
||||
|
||||
/// Optional encoding format (e.g., "float", "base64")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<String>,
|
||||
|
||||
/// Optional user identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// Optional number of dimensions for the embedding
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<u32>,
|
||||
|
||||
/// SGLang extension: request id for tracking
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rid: Option<String>,
|
||||
}
|
||||
|
||||
/// `GenerationRequest` implementation for embedding requests: never streaming,
/// always has a model, and exposes a best-effort text view of `input`.
impl GenerationRequest for EmbeddingRequest {
|
||||
/// Always returns `false`.
fn is_stream(&self) -> bool {
|
||||
// Embeddings are non-streaming
|
||||
false
|
||||
}
|
||||
|
||||
/// Returns the request's `model` field; never `None` for this type.
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
/// Returns the text found in `input`:
/// - a JSON string is returned as-is;
/// - a JSON array yields its top-level string elements joined by a single
///   space (non-string elements, including nested arrays, are dropped);
/// - any other JSON value yields an empty string.
fn extract_text_for_routing(&self) -> String {
|
||||
// Best effort: extract text content for routing decisions
|
||||
match &self.input {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(arr) => arr
|
||||
.iter()
|
||||
// `as_str` is `None` for anything that is not a JSON string, so those are skipped.
.filter_map(|v| v.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = COMMON =
|
||||
// ==================================================================
|
||||
@@ -2715,4 +2773,102 @@ mod tests {
|
||||
assert_eq!(deserialized.results.len(), 2);
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = EMBEDDINGS REQUEST TESTS =
|
||||
// ==================================================================
|
||||
|
||||
/// Round-trips a fully populated `EmbeddingRequest` with a string `input`
/// through JSON and checks that every field survives unchanged.
#[test]
|
||||
fn test_embedding_request_serialization_string_input() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "test-emb".to_string(),
|
||||
input: serde_json::Value::String("hello".to_string()),
|
||||
encoding_format: Some("float".to_string()),
|
||||
user: Some("user-1".to_string()),
|
||||
dimensions: Some(128),
|
||||
rid: Some("rid-123".to_string()),
|
||||
};
|
||||
|
||||
// Serialize to a JSON string, then parse it back into a new value.
let serialized = serde_json::to_string(&req).unwrap();
|
||||
let deserialized: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.model, req.model);
|
||||
assert_eq!(deserialized.input, req.input);
|
||||
assert_eq!(deserialized.encoding_format, req.encoding_format);
|
||||
assert_eq!(deserialized.user, req.user);
|
||||
assert_eq!(deserialized.dimensions, req.dimensions);
|
||||
assert_eq!(deserialized.rid, req.rid);
|
||||
}
|
||||
|
||||
/// Round-trips an `EmbeddingRequest` whose `input` is an array of strings and
/// whose optional fields are all `None`; checks `model` and `input` survive.
#[test]
|
||||
fn test_embedding_request_serialization_array_input() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "test-emb".to_string(),
|
||||
input: serde_json::json!(["a", "b", "c"]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
|
||||
// Deserialization must succeed even though the optional fields are omitted on output.
let serialized = serde_json::to_string(&req).unwrap();
|
||||
let de: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(de.model, req.model);
|
||||
assert_eq!(de.input, req.input);
|
||||
}
|
||||
|
||||
/// Checks the `GenerationRequest` methods for a string `input`: not streaming,
/// model is reported, and the routing text equals the input string.
#[test]
|
||||
fn test_embedding_generation_request_trait_string() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::Value::String("hello".to_string()),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
assert!(!req.is_stream());
|
||||
assert_eq!(req.get_model(), Some("emb-model"));
|
||||
assert_eq!(req.extract_text_for_routing(), "hello");
|
||||
}
|
||||
|
||||
/// Checks that an array-of-strings `input` is joined with single spaces for routing.
#[test]
|
||||
fn test_embedding_generation_request_trait_array() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::json!(["hello", "world"]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
assert_eq!(req.extract_text_for_routing(), "hello world");
|
||||
}
|
||||
|
||||
/// Checks that an `input` that is neither a string nor an array (here a JSON
/// object) produces an empty routing text.
#[test]
|
||||
fn test_embedding_generation_request_trait_non_text() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::json!({"tokens": [1, 2, 3]}),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
assert_eq!(req.extract_text_for_routing(), "");
|
||||
}
|
||||
|
||||
/// Checks that, for a mixed array `input`, nested arrays, numbers and objects
/// are ignored and only the top-level string contributes to the routing text.
#[test]
|
||||
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
// Only top-level string elements are extracted
|
||||
assert_eq!(req.extract_text_for_routing(), "a");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -309,7 +309,12 @@ impl RouterTrait for GrpcPDRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
/// Embeddings are not supported by this router: all arguments are ignored and
/// the response is always `501 Not Implemented` with no body text.
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::EmbeddingRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
|
||||
@@ -242,7 +242,12 @@ impl RouterTrait for GrpcRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
/// Embeddings are not supported by this router: all arguments are ignored and
/// the response is always `501 Not Implemented` with no body text.
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::EmbeddingRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
|
||||
@@ -395,7 +395,12 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::EmbeddingRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Embeddings endpoint not implemented for OpenAI backend",
|
||||
|
||||
@@ -1938,8 +1938,17 @@ impl RouterTrait for PDRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
todo!()
|
||||
/// Embeddings are not supported by the PD router: all arguments are ignored
/// and the response is always `501 Not Implemented` with an explanatory message.
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::EmbeddingRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Embeddings endpoint not implemented for PD router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(
|
||||
|
||||
@@ -6,8 +6,8 @@ use crate::core::{
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
|
||||
RerankResponse, RerankResult, ResponsesRequest,
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
|
||||
RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
|
||||
};
|
||||
use crate::routers::header_utils;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
@@ -1430,8 +1430,28 @@ impl RouterTrait for Router {
|
||||
self.route_post_empty_request(headers, &endpoint).await
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
todo!()
|
||||
/// Routes an embeddings request via `route_typed_request` on the
/// `/v1/embeddings` path and records embeddings-specific metrics from the
/// resulting status before returning the response unchanged.
///
/// - Success status: increments the embeddings request counter and records
///   the elapsed time since the call started.
/// - Any other status: increments the embeddings error counter with the
///   label `http_<status code>`.
///
/// NOTE(review): on a non-success status neither the request counter nor the
/// duration histogram is updated — confirm that the "total" counter is meant
/// to count successful requests only.
async fn route_embeddings(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &EmbeddingRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Record embeddings-specific metrics in addition to general request metrics
|
||||
// Timer starts before routing so the measured duration covers the whole routed call.
let start = Instant::now();
|
||||
let res = self
|
||||
.route_typed_request(headers, body, "/v1/embeddings", model_id)
|
||||
.await;
|
||||
|
||||
// Embedding specific metrics
|
||||
if res.status().is_success() {
|
||||
RouterMetrics::record_embeddings_request();
|
||||
RouterMetrics::record_embeddings_duration(start.elapsed());
|
||||
} else {
|
||||
// Error label is derived from the numeric HTTP status, e.g. "http_500".
let error_type = format!("http_{}", res.status().as_u16());
|
||||
RouterMetrics::record_embeddings_error(&error_type);
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
async fn route_rerank(
|
||||
|
||||
@@ -10,7 +10,8 @@ use axum::{
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest,
|
||||
};
|
||||
|
||||
pub mod factory;
|
||||
@@ -123,7 +124,13 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
/// Route embedding requests (OpenAI-compatible /v1/embeddings)
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &EmbeddingRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest,
|
||||
};
|
||||
use crate::protocols::worker_spec::{
|
||||
ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
|
||||
@@ -665,22 +666,6 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"responses api not yet implemented in inference gateway mode",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"responses api not yet implemented in inference gateway mode",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
@@ -701,17 +686,51 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Route embeddings request
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response {
|
||||
// Try to select a router based on headers
|
||||
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
|
||||
if let Some(router) = router {
|
||||
router.route_embeddings(headers, body).await
|
||||
router.get_response(headers, response_id).await
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No router available for embeddings request",
|
||||
format!("No router available to get response '{}'", response_id),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Forwards a cancel-response call to the router chosen by
/// `select_router_for_request` (called with the headers and no model).
///
/// Returns `404 Not Found` with a message naming `response_id` when no
/// router is selected.
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
if let Some(router) = router {
|
||||
router.cancel_response(headers, response_id).await
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("No router available to cancel response '{}'", response_id),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Route embeddings request
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &EmbeddingRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers and model
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
router
|
||||
.route_embeddings(headers, body, Some(&body.model))
|
||||
.await
|
||||
} else {
|
||||
// Return 404 when the specified model is not found
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
|
||||
V1RerankReqInput,
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest, V1RerankReqInput,
|
||||
};
|
||||
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
@@ -208,6 +208,17 @@ async fn v1_responses(
|
||||
.await
|
||||
}
|
||||
|
||||
/// HTTP handler for embeddings requests.
///
/// Extracts the shared application state, the request headers and the JSON
/// body (deserialized as `EmbeddingRequest`), then delegates to the router's
/// `route_embeddings`, passing `None` as the model id.
async fn v1_embeddings(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<EmbeddingRequest>,
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_embeddings(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn v1_responses_get(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(response_id): Path<String>,
|
||||
@@ -465,6 +476,7 @@ pub fn build_app(
|
||||
.route("/rerank", post(rerank))
|
||||
.route("/v1/rerank", post(v1_rerank))
|
||||
.route("/v1/responses", post(v1_responses))
|
||||
.route("/v1/embeddings", post(v1_embeddings))
|
||||
.route("/v1/responses/{response_id}", get(v1_responses_get))
|
||||
.route(
|
||||
"/v1/responses/{response_id}/cancel",
|
||||
|
||||
@@ -1090,10 +1090,14 @@ mod responses_endpoint_tests {
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// First create a response to obtain an id
|
||||
let resp_id = "test-get-resp-id-123";
|
||||
let payload = json!({
|
||||
"input": "Hello Responses API",
|
||||
"model": "mock-model",
|
||||
"stream": false
|
||||
"stream": false,
|
||||
"store": true,
|
||||
"background": true,
|
||||
"request_id": resp_id
|
||||
});
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
@@ -1103,11 +1107,6 @@ mod responses_endpoint_tests {
|
||||
.unwrap();
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
let resp_id = body_json["id"].as_str().unwrap().to_string();
|
||||
|
||||
// Retrieve the response
|
||||
let req = Request::builder()
|
||||
@@ -1140,10 +1139,14 @@ mod responses_endpoint_tests {
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
// First create a response to obtain an id
|
||||
let resp_id = "test-cancel-resp-id-456";
|
||||
let payload = json!({
|
||||
"input": "Hello Responses API",
|
||||
"model": "mock-model",
|
||||
"stream": false
|
||||
"stream": false,
|
||||
"store": true,
|
||||
"background": true,
|
||||
"request_id": resp_id
|
||||
});
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
@@ -1153,11 +1156,6 @@ mod responses_endpoint_tests {
|
||||
.unwrap();
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
let resp_id = body_json["id"].as_str().unwrap().to_string();
|
||||
|
||||
// Cancel the response
|
||||
let req = Request::builder()
|
||||
|
||||
@@ -20,7 +20,12 @@ import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
|
||||
from sglang.test.test_utils import (
|
||||
CustomTestCase,
|
||||
get_similarities,
|
||||
is_in_amd_ci,
|
||||
is_in_ci,
|
||||
)
|
||||
|
||||
MODELS = [
|
||||
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
|
||||
@@ -74,11 +79,13 @@ class TestEmbeddingModels(CustomTestCase):
|
||||
) as hf_runner:
|
||||
hf_outputs = hf_runner.forward(truncated_prompts)
|
||||
|
||||
attention_backend = "triton" if is_in_amd_ci() else None
|
||||
with SRTRunner(
|
||||
model_path,
|
||||
tp_size=tp_size,
|
||||
torch_dtype=torch_dtype,
|
||||
model_type="embedding",
|
||||
attention_backend=attention_backend,
|
||||
) as srt_runner:
|
||||
srt_outputs = srt_runner.forward(truncated_prompts)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user