[router] refactor generate to use new pipeline arch (#11323)
This commit is contained in:
@@ -2066,39 +2066,64 @@ impl GenerationRequest for GenerateRequest {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(generate): Define GenerateResponse and GenerateChoice structs
|
||||
//
|
||||
// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964)
|
||||
//
|
||||
// #[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
// pub struct GenerateResponse {
|
||||
// pub id: String,
|
||||
// pub object: String, // "text.completion"
|
||||
// pub created: u64,
|
||||
// pub model: String,
|
||||
// pub choices: Vec<GenerateChoice>,
|
||||
// #[serde(skip_serializing_if = "Option::is_none")]
|
||||
// pub usage: Option<Usage>,
|
||||
// #[serde(skip_serializing_if = "Option::is_none")]
|
||||
// pub system_fingerprint: Option<String>,
|
||||
// }
|
||||
//
|
||||
// #[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
// pub struct GenerateChoice {
|
||||
// pub index: u32,
|
||||
// pub text: String,
|
||||
// #[serde(skip_serializing_if = "Option::is_none")]
|
||||
// pub output_ids: Option<Vec<u32>>,
|
||||
// #[serde(skip_serializing_if = "Option::is_none")]
|
||||
// pub finish_reason: Option<String>,
|
||||
// #[serde(skip_serializing_if = "Option::is_none")]
|
||||
// pub logprobs: Option<TopLogprobs>,
|
||||
// #[serde(skip_serializing_if = "Option::is_none")]
|
||||
// pub matched_stop: Option<Value>,
|
||||
// }
|
||||
//
|
||||
// Note: Verify if similar structs already exist elsewhere before implementing.
|
||||
// May need streaming variant (GenerateStreamResponse) as well.
|
||||
// ============================================================================
|
||||
// SGLang Generate Response Types
|
||||
// ============================================================================
|
||||
|
||||
/// SGLang generate response (single completion or array for n>1)
|
||||
///
|
||||
/// Format for n=1:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "text": "...",
|
||||
/// "output_ids": [...],
|
||||
/// "meta_info": { ... }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Format for n>1:
|
||||
/// ```json
|
||||
/// [
|
||||
/// {"text": "...", "output_ids": [...], "meta_info": {...}},
|
||||
/// {"text": "...", "output_ids": [...], "meta_info": {...}}
|
||||
/// ]
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GenerateResponse {
|
||||
pub text: String,
|
||||
pub output_ids: Vec<u32>,
|
||||
pub meta_info: GenerateMetaInfo,
|
||||
}
|
||||
|
||||
/// Metadata for a single generate completion
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GenerateMetaInfo {
|
||||
pub id: String,
|
||||
pub finish_reason: GenerateFinishReason,
|
||||
pub prompt_tokens: u32,
|
||||
pub weight_version: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
|
||||
pub completion_tokens: u32,
|
||||
pub cached_tokens: u32,
|
||||
pub e2e_latency: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub matched_stop: Option<Value>,
|
||||
}
|
||||
|
||||
/// Finish reason for generate endpoint
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum GenerateFinishReason {
|
||||
Length {
|
||||
length: u32,
|
||||
},
|
||||
Stop,
|
||||
#[serde(untagged)]
|
||||
Other(Value),
|
||||
}
|
||||
|
||||
// Constants for rerank API
|
||||
pub const DEFAULT_MODEL_NAME: &str = "default";
|
||||
|
||||
@@ -12,7 +12,9 @@ use serde_json::Value;
|
||||
|
||||
use crate::core::Worker;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse,
|
||||
};
|
||||
use crate::reasoning_parser::ReasoningParserFactory;
|
||||
use crate::tokenizer::stop::StopSequenceDecoder;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
@@ -226,14 +228,6 @@ impl RequestContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get chat request
|
||||
pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> {
|
||||
match &self.input.request_type {
|
||||
RequestType::Chat(req) => Some(req.as_ref()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get generate request (panics if not generate)
|
||||
pub fn generate_request(&self) -> &GenerateRequest {
|
||||
match &self.input.request_type {
|
||||
@@ -242,14 +236,6 @@ impl RequestContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to get generate request
|
||||
pub fn try_generate_request(&self) -> Option<&GenerateRequest> {
|
||||
match &self.input.request_type {
|
||||
RequestType::Generate(req) => Some(req.as_ref()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if request is streaming
|
||||
pub fn is_streaming(&self) -> bool {
|
||||
match &self.input.request_type {
|
||||
@@ -257,16 +243,6 @@ impl RequestContext {
|
||||
RequestType::Generate(req) => req.stream,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if request is chat
|
||||
pub fn is_chat(&self) -> bool {
|
||||
matches!(&self.input.request_type, RequestType::Chat(_))
|
||||
}
|
||||
|
||||
/// Check if request is generate
|
||||
pub fn is_generate(&self) -> bool {
|
||||
matches!(&self.input.request_type, RequestType::Generate(_))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -394,5 +370,6 @@ pub enum ExecutionResult {
|
||||
/// Final processed response
|
||||
pub enum FinalResponse {
|
||||
Chat(ChatCompletionResponse),
|
||||
Generate(Box<GenerateRequest>),
|
||||
/// Generate response is a Vec of GenerateResponse (n=1 returns single item, n>1 returns multiple)
|
||||
Generate(Vec<GenerateResponse>),
|
||||
}
|
||||
|
||||
@@ -1,40 +1,27 @@
|
||||
// PD (Prefill-Decode) gRPC Router Implementation
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::grpc_client::proto;
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds,
|
||||
RerankRequest, ResponsesGetParams, ResponsesRequest,
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesGetParams, ResponsesRequest,
|
||||
};
|
||||
use crate::reasoning_parser::ReasoningParserFactory;
|
||||
use crate::routers::http::pd_types::generate_room_id;
|
||||
use crate::routers::{grpc, RouterTrait};
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::AppContext;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tokenizer::SequenceDecoderOutput;
|
||||
use crate::tool_parser::ToolParserFactory;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{header, HeaderMap, HeaderValue, StatusCode},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use grpc::utils;
|
||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc::unbounded_channel;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio_stream::Stream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, error};
|
||||
use uuid::Uuid;
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
@@ -50,9 +37,7 @@ pub struct GrpcPDRouter {
|
||||
retry_config: RetryConfig,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
configured_tool_parser: Option<String>,
|
||||
// Pipeline for non-streaming requests
|
||||
pipeline: super::pipeline::ChatCompletionPipeline,
|
||||
// Shared components for pipeline
|
||||
shared_components: Arc<super::context::SharedComponents>,
|
||||
}
|
||||
|
||||
@@ -129,93 +114,10 @@ impl GrpcPDRouter {
|
||||
})
|
||||
}
|
||||
|
||||
/// Select a prefill-decode worker pair using load balancing policies
|
||||
async fn select_pd_pair(
|
||||
&self,
|
||||
request_text: Option<&str>,
|
||||
model_id: Option<&str>,
|
||||
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
|
||||
let effective_model_id = if !self.dp_aware { None } else { model_id };
|
||||
|
||||
debug!(
|
||||
"Selecting PD pair: dp_aware={}, model_id={:?}, effective_model_id={:?}",
|
||||
self.dp_aware, model_id, effective_model_id
|
||||
);
|
||||
|
||||
// Get prefill workers
|
||||
let prefill_workers = if let Some(model) = effective_model_id {
|
||||
self.worker_registry
|
||||
.get_by_model_fast(model)
|
||||
.into_iter()
|
||||
.filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }))
|
||||
.collect()
|
||||
} else {
|
||||
self.worker_registry.get_workers_filtered(
|
||||
None,
|
||||
Some(WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
}),
|
||||
Some(ConnectionMode::Grpc { port: None }),
|
||||
true, // only healthy workers
|
||||
)
|
||||
};
|
||||
|
||||
// Get decode workers
|
||||
let decode_workers = if let Some(model) = effective_model_id {
|
||||
self.worker_registry
|
||||
.get_by_model_fast(model)
|
||||
.into_iter()
|
||||
.filter(|w| matches!(w.worker_type(), WorkerType::Decode))
|
||||
.collect()
|
||||
} else {
|
||||
self.worker_registry.get_workers_filtered(
|
||||
None,
|
||||
Some(WorkerType::Decode),
|
||||
Some(ConnectionMode::Grpc { port: None }),
|
||||
true, // only healthy workers
|
||||
)
|
||||
};
|
||||
|
||||
if prefill_workers.is_empty() {
|
||||
return Err("No healthy prefill workers available".to_string());
|
||||
}
|
||||
if decode_workers.is_empty() {
|
||||
return Err("No healthy decode workers available".to_string());
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Found {} prefill workers and {} decode workers",
|
||||
prefill_workers.len(),
|
||||
decode_workers.len()
|
||||
);
|
||||
|
||||
let prefill_policy = self.policy_registry.get_prefill_policy();
|
||||
let decode_policy = self.policy_registry.get_decode_policy();
|
||||
|
||||
let prefill_idx = prefill_policy
|
||||
.select_worker(&prefill_workers, request_text)
|
||||
.ok_or_else(|| "Failed to select prefill worker".to_string())?;
|
||||
|
||||
let decode_idx = decode_policy
|
||||
.select_worker(&decode_workers, request_text)
|
||||
.ok_or_else(|| "Failed to select decode worker".to_string())?;
|
||||
|
||||
let prefill = prefill_workers[prefill_idx].clone();
|
||||
let decode = decode_workers[decode_idx].clone();
|
||||
|
||||
debug!(
|
||||
"Selected PD pair: prefill={}, decode={}",
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
|
||||
/// Main route_generate implementation with PD dual dispatch
|
||||
async fn route_generate_impl(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
@@ -224,125 +126,15 @@ impl GrpcPDRouter {
|
||||
model_id
|
||||
);
|
||||
|
||||
// Step 1: Resolve input (text or input_ids)
|
||||
let (original_text, token_ids) = match self.resolve_generate_input(body) {
|
||||
Ok(res) => res,
|
||||
Err(msg) => {
|
||||
return utils::bad_request_error(msg);
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Resolved input with {} tokens", token_ids.len());
|
||||
|
||||
// Step 2: Select prefill-decode worker pair
|
||||
let (prefill_worker, decode_worker) = match self
|
||||
.select_pd_pair(original_text.as_deref(), model_id)
|
||||
.await
|
||||
{
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
return utils::service_unavailable_error(e);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Selected PD pair: prefill={}, decode={}",
|
||||
prefill_worker.url(),
|
||||
decode_worker.url()
|
||||
);
|
||||
|
||||
// Step 3: Get gRPC clients for both workers
|
||||
let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await {
|
||||
Ok(client) => client,
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await {
|
||||
Ok(client) => client,
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
// Step 4: Build the gRPC request
|
||||
let request_id = body
|
||||
.rid
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
|
||||
|
||||
let mut request = match prefill_client.build_plain_generate_request(
|
||||
request_id.clone(),
|
||||
body,
|
||||
original_text.clone(),
|
||||
token_ids,
|
||||
) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
return utils::bad_request_error(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Step 5: Inject bootstrap metadata
|
||||
if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) {
|
||||
return utils::internal_error_message(e);
|
||||
}
|
||||
|
||||
// Step 6: Get weight version for response metadata
|
||||
let weight_version = decode_worker
|
||||
.metadata()
|
||||
.labels
|
||||
.get("weight_version")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
|
||||
// Step 7: Handle streaming vs non-streaming
|
||||
if body.stream {
|
||||
self.handle_streaming_generate(
|
||||
prefill_client,
|
||||
decode_client,
|
||||
request,
|
||||
body,
|
||||
request_id,
|
||||
weight_version,
|
||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||
self.pipeline
|
||||
.execute_generate(
|
||||
body.clone(),
|
||||
headers.cloned(),
|
||||
model_id.map(|s| s.to_string()),
|
||||
self.shared_components.clone(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
self.handle_non_streaming_generate(
|
||||
prefill_client,
|
||||
decode_client,
|
||||
request,
|
||||
body,
|
||||
request_id,
|
||||
weight_version,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Inject bootstrap metadata into a protobuf GenerateRequest
|
||||
fn inject_bootstrap_metadata(
|
||||
request: &mut proto::GenerateRequest,
|
||||
prefill_worker: &dyn Worker,
|
||||
) -> Result<(), String> {
|
||||
let hostname = prefill_worker.bootstrap_host();
|
||||
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
|
||||
|
||||
let room_id = generate_room_id();
|
||||
|
||||
// Create DisaggregatedParams
|
||||
let disagg_params = proto::DisaggregatedParams {
|
||||
bootstrap_host: hostname.to_string(),
|
||||
bootstrap_port: bootstrap_port as i32,
|
||||
bootstrap_room: room_id as i32,
|
||||
};
|
||||
|
||||
// Inject metadata
|
||||
request.disaggregated_params = Some(disagg_params);
|
||||
|
||||
debug!(
|
||||
"Injected bootstrap metadata: host={}, port={}, room={}",
|
||||
hostname, bootstrap_port, room_id
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Main route_chat implementation with PD dual dispatch
|
||||
@@ -367,405 +159,6 @@ impl GrpcPDRouter {
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Resolve the generate input into optional original text and token IDs
|
||||
fn resolve_generate_input(
|
||||
&self,
|
||||
request: &GenerateRequest,
|
||||
) -> Result<(Option<String>, Vec<u32>), String> {
|
||||
if let Some(text) = &request.text {
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(text)
|
||||
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
||||
return Ok((Some(text.to_string()), encoding.token_ids().to_vec()));
|
||||
}
|
||||
|
||||
// Handle input_ids - validate and convert
|
||||
if let Some(input_ids) = &request.input_ids {
|
||||
return match input_ids {
|
||||
InputIds::Single(ids) => ids
|
||||
.iter()
|
||||
.map(|&id| u32::try_from(id))
|
||||
.collect::<Result<Vec<u32>, _>>()
|
||||
.map(|converted| (None, converted))
|
||||
.map_err(|_| "input_ids must be non-negative".to_string()),
|
||||
InputIds::Batch(_) => {
|
||||
Err("Batch input_ids are not supported in PD mode".to_string())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Err("Either `text` or `input_ids` must be provided".to_string())
|
||||
}
|
||||
|
||||
/// Submit request and handle streaming response for generate endpoint (PD mode)
|
||||
async fn handle_streaming_generate(
|
||||
&self,
|
||||
mut prefill_client: SglangSchedulerClient,
|
||||
mut decode_client: SglangSchedulerClient,
|
||||
request: proto::GenerateRequest,
|
||||
original_request: &GenerateRequest,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
) -> Response {
|
||||
// Create channel for SSE streaming
|
||||
let (tx, rx) = unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
|
||||
|
||||
// Send requests in parallel to both prefill and decode workers
|
||||
debug!("Starting concurrent streaming generate requests to prefill and decode workers");
|
||||
let prefill_request = request.clone();
|
||||
let decode_request = request;
|
||||
|
||||
let (prefill_result, decode_result) = tokio::join!(
|
||||
prefill_client.generate(prefill_request),
|
||||
decode_client.generate(decode_request)
|
||||
);
|
||||
|
||||
// Get prefill stream (for input_logprobs if needed)
|
||||
let prefill_stream = match prefill_result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return utils::internal_error_message(format!(
|
||||
"Prefill worker failed to start: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Get decode stream - this is what we'll process for output
|
||||
let decode_stream = match decode_result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return utils::internal_error_message(format!(
|
||||
"Decode worker failed to start: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Spawn processing task for both streams
|
||||
let tokenizer = self.tokenizer.clone();
|
||||
let return_logprob = original_request.return_logprob;
|
||||
tokio::spawn(async move {
|
||||
let result = Self::process_generate_streaming(
|
||||
tokenizer,
|
||||
prefill_stream,
|
||||
decode_stream,
|
||||
request_id,
|
||||
weight_version,
|
||||
return_logprob,
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
let error_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::json!({
|
||||
"error": {
|
||||
"message": e,
|
||||
"type": "internal_error"
|
||||
}
|
||||
})
|
||||
);
|
||||
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
|
||||
}
|
||||
|
||||
// Send DONE marker
|
||||
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
|
||||
});
|
||||
|
||||
// Create response with SSE headers
|
||||
let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
|
||||
let mut response = Response::new(Body::from_stream(stream));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
response.headers_mut().insert(
|
||||
header::CONTENT_TYPE,
|
||||
HeaderValue::from_static("text/event-stream"),
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
|
||||
response
|
||||
.headers_mut()
|
||||
.insert("Connection", HeaderValue::from_static("keep-alive"));
|
||||
response
|
||||
}
|
||||
|
||||
/// Process generate streaming (simplified - no tool calls or reasoning)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn process_generate_streaming(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut prefill_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
|
||||
mut decode_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
include_logprobs: bool,
|
||||
tx: &UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Phase 1: Collect input_logprobs from prefill stream if requested
|
||||
// TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming
|
||||
if include_logprobs {
|
||||
while let Some(response) = prefill_stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
|
||||
match gen_response.response {
|
||||
Some(Complete(_complete)) => {
|
||||
// Input logprobs collected but not yet used in streaming
|
||||
break;
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
return Err(format!("Prefill error: {}", error.message));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Main streaming loop (decode stream)
|
||||
// Track state per index for n>1 case
|
||||
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
|
||||
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
|
||||
let mut current_index: u32 = 0;
|
||||
|
||||
while let Some(response) = decode_stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Decode stream error: {}", e))?;
|
||||
|
||||
match gen_response.response {
|
||||
Some(Chunk(chunk)) => {
|
||||
// Use our tracked index instead of chunk.index (PD backend bug workaround)
|
||||
let index = current_index;
|
||||
debug!(
|
||||
"Received chunk with backend_index={}, using_index={}, tokens={:?}",
|
||||
chunk.index, index, chunk.token_ids
|
||||
);
|
||||
|
||||
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
|
||||
*completion_tokens += chunk.token_ids.len() as u32;
|
||||
|
||||
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
|
||||
|
||||
let accumulated_text = accumulated_texts.entry(index).or_default();
|
||||
accumulated_text.push_str(&chunk_text);
|
||||
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
|
||||
let chunk_response = serde_json::json!({
|
||||
"text": accumulated_text.clone(),
|
||||
"output_ids": chunk.token_ids,
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": null,
|
||||
"prompt_tokens": chunk.prompt_tokens,
|
||||
"weight_version": weight_version,
|
||||
"completion_tokens": *completion_tokens,
|
||||
"cached_tokens": chunk.cached_tokens
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&chunk_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send chunk".to_string())?;
|
||||
}
|
||||
Some(Complete(complete)) => {
|
||||
let index = current_index;
|
||||
debug!(
|
||||
"Received Complete with backend_index={}, using_index={}, finish_reason={}",
|
||||
complete.index, index, complete.finish_reason
|
||||
);
|
||||
let accumulated_text =
|
||||
accumulated_texts.get(&index).cloned().unwrap_or_default();
|
||||
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
let e2e_latency = start_time.elapsed().as_secs_f64();
|
||||
|
||||
// Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks)
|
||||
let finish_response = serde_json::json!({
|
||||
"text": accumulated_text,
|
||||
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": complete.finish_reason,
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"weight_version": weight_version,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
"e2e_latency": e2e_latency
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&finish_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send finish chunk".to_string())?;
|
||||
|
||||
// Move to next completion
|
||||
current_index += 1;
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
return Err(error.message);
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Submit request and handle non-streaming response for generate endpoint (PD mode)
|
||||
async fn handle_non_streaming_generate(
|
||||
&self,
|
||||
mut prefill_client: SglangSchedulerClient,
|
||||
mut decode_client: SglangSchedulerClient,
|
||||
request: proto::GenerateRequest,
|
||||
original_request: &GenerateRequest,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
) -> Response {
|
||||
use std::time::Instant;
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Send requests in parallel
|
||||
debug!("Sending concurrent generate requests to prefill and decode workers");
|
||||
let prefill_request = request.clone();
|
||||
let decode_request = request;
|
||||
|
||||
let (prefill_result, decode_result) = tokio::join!(
|
||||
prefill_client.generate(prefill_request),
|
||||
decode_client.generate(decode_request)
|
||||
);
|
||||
|
||||
// Process prefill stream
|
||||
let prefill_stream = match prefill_result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to start prefill generation: {}", e);
|
||||
return utils::internal_error_message(format!(
|
||||
"Prefill worker failed to start: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let decode_stream = match decode_result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to start decode generation: {}", e);
|
||||
return utils::internal_error_message(format!(
|
||||
"Decode worker failed to start: {}",
|
||||
e
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Collect prefill responses
|
||||
// TODO add logprob for generate
|
||||
let _prefill_responses =
|
||||
match utils::collect_stream_responses(prefill_stream, "Prefill").await {
|
||||
Ok(responses) => responses,
|
||||
Err(error_response) => return error_response,
|
||||
};
|
||||
|
||||
// Collect decode responses
|
||||
let decode_responses = match utils::collect_stream_responses(decode_stream, "Decode").await
|
||||
{
|
||||
Ok(responses) => responses,
|
||||
Err(error_response) => return error_response,
|
||||
};
|
||||
|
||||
if decode_responses.is_empty() {
|
||||
return utils::internal_error_static("No completion received from decode worker");
|
||||
}
|
||||
|
||||
// Create stop decoder from sampling params
|
||||
let params = original_request.sampling_params.as_ref();
|
||||
let mut stop_decoder = utils::create_stop_decoder(
|
||||
&self.tokenizer,
|
||||
params.and_then(|p| p.stop.as_ref()),
|
||||
params.and_then(|p| p.stop_token_ids.as_ref()),
|
||||
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
|
||||
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
|
||||
);
|
||||
|
||||
// Process each completion
|
||||
let mut result_array = Vec::new();
|
||||
for mut complete in decode_responses {
|
||||
stop_decoder.reset();
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||
Ok(outputs) => outputs,
|
||||
Err(e) => {
|
||||
return utils::internal_error_message(format!(
|
||||
"Failed to process tokens: {}",
|
||||
e
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate text with early breaks
|
||||
let mut decoded_text = String::new();
|
||||
for output in outputs {
|
||||
match output {
|
||||
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
decoded_text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
SequenceDecoderOutput::Stopped => break,
|
||||
SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = complete.output_ids.clone();
|
||||
|
||||
// Build base meta_info
|
||||
let mut meta_info = serde_json::json!({
|
||||
"id": request_id.clone(),
|
||||
"finish_reason": complete.finish_reason.clone(),
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"weight_version": weight_version.clone(),
|
||||
"completion_tokens": complete.completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
"e2e_latency": start_time.elapsed().as_secs_f64(),
|
||||
});
|
||||
|
||||
let meta_obj = meta_info.as_object_mut().unwrap();
|
||||
|
||||
// Add matched_stop if present
|
||||
if let Some(matched) = complete.matched_stop.take() {
|
||||
use proto::generate_complete::MatchedStop;
|
||||
let matched_value = match matched {
|
||||
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
|
||||
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
|
||||
};
|
||||
meta_obj.insert("matched_stop".to_string(), matched_value);
|
||||
}
|
||||
|
||||
result_array.push(serde_json::json!({
|
||||
"text": decoded_text,
|
||||
"output_ids": output_ids,
|
||||
"meta_info": meta_info,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(result_array).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcPDRouter {
|
||||
|
||||
@@ -11,15 +11,20 @@ use super::context::*;
|
||||
use super::processing;
|
||||
use super::streaming;
|
||||
use super::utils;
|
||||
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
|
||||
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::grpc_client::proto;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage,
|
||||
ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
|
||||
GenerateResponse, InputIds, Usage,
|
||||
};
|
||||
use crate::tokenizer::stop::SequenceDecoderOutput;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use proto::generate_complete::MatchedStop;
|
||||
use proto::DisaggregatedParams;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
use uuid::Uuid;
|
||||
|
||||
// ============================================================================
|
||||
@@ -208,7 +213,7 @@ impl PreparationStage {
|
||||
|
||||
fn tokenize_single_text(
|
||||
&self,
|
||||
tokenizer: &Arc<dyn crate::tokenizer::traits::Tokenizer>,
|
||||
tokenizer: &Arc<dyn Tokenizer>,
|
||||
text: &str,
|
||||
) -> Result<(String, Vec<u32>), String> {
|
||||
let encoding = tokenizer
|
||||
@@ -302,7 +307,7 @@ impl WorkerSelectionStage {
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<Arc<dyn crate::core::Worker>> {
|
||||
) -> Option<Arc<dyn Worker>> {
|
||||
// Get workers for the specified model, filtered by connection mode
|
||||
let workers = self.worker_registry.get_workers_filtered(
|
||||
model_id,
|
||||
@@ -312,7 +317,7 @@ impl WorkerSelectionStage {
|
||||
);
|
||||
|
||||
// Filter by availability (health + circuit breaker)
|
||||
let available: Vec<Arc<dyn crate::core::Worker>> = workers
|
||||
let available: Vec<Arc<dyn Worker>> = workers
|
||||
.iter()
|
||||
.filter(|w| w.is_available())
|
||||
.cloned()
|
||||
@@ -337,7 +342,7 @@ impl WorkerSelectionStage {
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<(Arc<dyn crate::core::Worker>, Arc<dyn crate::core::Worker>)> {
|
||||
) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
|
||||
// Get prefill workers - use None for WorkerType filter to get all types,
|
||||
// then filter manually (since Prefill is a struct variant)
|
||||
let all_workers = self.worker_registry.get_workers_filtered(
|
||||
@@ -537,10 +542,8 @@ impl RequestBuildingStage {
|
||||
fn inject_bootstrap_metadata(
|
||||
&self,
|
||||
request: &mut proto::GenerateRequest,
|
||||
prefill_worker: &Arc<dyn crate::core::Worker>,
|
||||
prefill_worker: &Arc<dyn Worker>,
|
||||
) {
|
||||
use proto::DisaggregatedParams;
|
||||
|
||||
let hostname = prefill_worker.bootstrap_host();
|
||||
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
|
||||
|
||||
@@ -935,40 +938,183 @@ impl ResponseProcessingStage {
|
||||
|
||||
async fn process_generate_response(
|
||||
&self,
|
||||
_ctx: &mut RequestContext,
|
||||
ctx: &mut RequestContext,
|
||||
) -> Result<Option<Response>, Response> {
|
||||
// TODO(generate): Implement generate response processing
|
||||
//
|
||||
// Required implementation:
|
||||
// 1. Extract execution_result from ctx
|
||||
// 2. Check is_streaming flag
|
||||
// 3. For streaming:
|
||||
// - Add StreamingProcessor::process_streaming_generate() method
|
||||
// - Similar to process_streaming_response but WITHOUT tool/reasoning parsing
|
||||
// - Return Err(sse_response) for early exit
|
||||
// 4. For non-streaming:
|
||||
// - Collect stream responses using utils::collect_stream_responses()
|
||||
// - Process through stop decoder (sequential with reset for n>1, like chat)
|
||||
// - Build GenerateResponse struct (see TODO in protocols/spec.rs)
|
||||
// - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response))
|
||||
//
|
||||
// Reference implementation: router.rs:297-595
|
||||
// Key differences from chat:
|
||||
// - No tool parsing
|
||||
// - No reasoning parsing
|
||||
// - Different response format (GenerateResponse instead of ChatCompletionResponse)
|
||||
// - Still needs: stop decoder, logprobs, finish_reason, matched_stop
|
||||
Err((
|
||||
axum::http::StatusCode::NOT_IMPLEMENTED,
|
||||
axum::Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "Generate response processing not yet implemented in pipeline",
|
||||
"type": "not_implemented",
|
||||
"code": 501
|
||||
let start_time = Instant::now();
|
||||
let is_streaming = ctx.is_streaming();
|
||||
|
||||
// Extract execution result
|
||||
let execution_result = ctx
|
||||
.state
|
||||
.response
|
||||
.execution_result
|
||||
.take()
|
||||
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
|
||||
|
||||
if is_streaming {
|
||||
// Get dispatch metadata for consistent response fields
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
let generate_request = ctx.generate_request().clone();
|
||||
|
||||
// Streaming: Use StreamingProcessor and return SSE response (done)
|
||||
return Ok(Some(
|
||||
self.streaming_processor.clone().process_streaming_generate(
|
||||
execution_result,
|
||||
generate_request,
|
||||
dispatch.clone(),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Non-streaming: Collect all responses
|
||||
let request_logprobs = ctx.generate_request().return_logprob;
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { stream } => {
|
||||
utils::collect_stream_responses(stream, "Single").await?
|
||||
}
|
||||
ExecutionResult::Dual { prefill, decode } => {
|
||||
// Collect prefill for input_logprobs
|
||||
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
|
||||
|
||||
// Collect decode for actual output
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(*decode, "Decode").await?;
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
if let Some(prefill_input_logprobs) = prefill_responses
|
||||
.first()
|
||||
.and_then(|r| r.input_logprobs.clone())
|
||||
{
|
||||
for response in &mut decode_responses {
|
||||
response.input_logprobs = Some(prefill_input_logprobs.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response())
|
||||
|
||||
decode_responses
|
||||
}
|
||||
};
|
||||
|
||||
if all_responses.is_empty() {
|
||||
return Err(utils::internal_error_static("No responses from server"));
|
||||
}
|
||||
|
||||
// Get stop decoder for processing
|
||||
let stop_decoder = ctx
|
||||
.state
|
||||
.response
|
||||
.stop_decoder
|
||||
.as_mut()
|
||||
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
|
||||
|
||||
// Get dispatch metadata
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
// Process each completion (similar to router.rs:336-400)
|
||||
let mut result_array = Vec::new();
|
||||
for mut complete in all_responses {
|
||||
stop_decoder.reset();
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||
Ok(outputs) => outputs,
|
||||
Err(e) => {
|
||||
return Err(utils::internal_error_message(format!(
|
||||
"Failed to process tokens: {}",
|
||||
e
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate text with early breaks
|
||||
let mut decoded_text = String::new();
|
||||
for output in outputs {
|
||||
match output {
|
||||
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
decoded_text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
SequenceDecoderOutput::Stopped => break,
|
||||
SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
|
||||
|
||||
// Parse finish_reason from string to proper type
|
||||
let finish_reason =
|
||||
utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
|
||||
|
||||
// Handle matched_stop if present
|
||||
let matched_stop = complete.matched_stop.take().map(|matched| match matched {
|
||||
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
|
||||
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
|
||||
});
|
||||
|
||||
// Extract logprobs if requested (convert proto types to Generate format)
|
||||
let input_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.input_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_input_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let output_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.output_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_output_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build GenerateResponse struct
|
||||
let meta_info = GenerateMetaInfo {
|
||||
id: dispatch.request_id.clone(),
|
||||
finish_reason,
|
||||
prompt_tokens: complete.prompt_tokens as u32,
|
||||
weight_version: dispatch
|
||||
.weight_version
|
||||
.clone()
|
||||
.unwrap_or_else(|| "default".to_string()),
|
||||
input_token_logprobs,
|
||||
output_token_logprobs,
|
||||
completion_tokens: complete.completion_tokens as u32,
|
||||
cached_tokens: complete.cached_tokens as u32,
|
||||
e2e_latency: start_time.elapsed().as_secs_f64(),
|
||||
matched_stop,
|
||||
};
|
||||
|
||||
result_array.push(GenerateResponse {
|
||||
text: decoded_text,
|
||||
output_ids,
|
||||
meta_info,
|
||||
});
|
||||
}
|
||||
|
||||
// Store the final response
|
||||
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline {
|
||||
|
||||
// Extract final response
|
||||
match ctx.state.response.final_response {
|
||||
Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(),
|
||||
Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
|
||||
Some(FinalResponse::Chat(_)) => {
|
||||
utils::internal_error_static("Internal error: wrong response type")
|
||||
}
|
||||
|
||||
@@ -8,28 +8,21 @@ use axum::{
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::core::WorkerRegistry;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds,
|
||||
RerankRequest, ResponsesGetParams, ResponsesRequest,
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesGetParams, ResponsesRequest,
|
||||
};
|
||||
use crate::reasoning_parser::ReasoningParserFactory;
|
||||
use crate::routers::{grpc, RouterTrait};
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::AppContext;
|
||||
use crate::tokenizer::stop::SequenceDecoderOutput;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ToolParserFactory;
|
||||
use grpc::utils;
|
||||
use serde_json::json;
|
||||
use std::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// gRPC router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
@@ -45,9 +38,7 @@ pub struct GrpcRouter {
|
||||
retry_config: RetryConfig,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
configured_tool_parser: Option<String>,
|
||||
// Pipeline for non-streaming requests
|
||||
pipeline: super::pipeline::ChatCompletionPipeline,
|
||||
// Shared components for pipeline
|
||||
shared_components: Arc<super::context::SharedComponents>,
|
||||
}
|
||||
|
||||
@@ -149,420 +140,21 @@ impl GrpcRouter {
|
||||
/// Main route_generate implementation
|
||||
async fn route_generate_impl(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
debug!("Processing generate request for model: {:?}", model_id);
|
||||
|
||||
// Step 1: Resolve input (text, prompt, or input_ids)
|
||||
let (original_text, token_ids) = match self.resolve_generate_input(body) {
|
||||
Ok(res) => res,
|
||||
Err(msg) => {
|
||||
return utils::bad_request_error(msg);
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Resolved input with {} tokens", token_ids.len());
|
||||
|
||||
// Step 2: Select worker (fail fast if no workers available)
|
||||
let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) {
|
||||
Some(w) => w,
|
||||
None => {
|
||||
return utils::service_unavailable_error(format!(
|
||||
"No available workers for model: {:?}",
|
||||
model_id
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Selected worker: {}", worker.url());
|
||||
|
||||
// Step 3: Get gRPC client from worker
|
||||
let client = match utils::get_grpc_client_from_worker(&worker).await {
|
||||
Ok(client) => client,
|
||||
Err(response) => return response,
|
||||
};
|
||||
|
||||
// Step 4: Build the gRPC request
|
||||
let request_id = body
|
||||
.rid
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
|
||||
|
||||
let request = match client.build_plain_generate_request(
|
||||
request_id.clone(),
|
||||
body,
|
||||
original_text.clone(),
|
||||
token_ids,
|
||||
) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
return utils::bad_request_error(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Step 5: Get weight version for response metadata
|
||||
let weight_version = worker
|
||||
.metadata()
|
||||
.labels
|
||||
.get("weight_version")
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
|
||||
// Step 6: Handle streaming vs non-streaming
|
||||
if body.stream {
|
||||
self.handle_streaming_generate(client, request, body, request_id, weight_version)
|
||||
.await
|
||||
} else {
|
||||
self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// Select a worker for the request
|
||||
fn select_worker_for_request(
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<Arc<dyn Worker>> {
|
||||
// Get workers for the specified model, filtered by connection mode
|
||||
let workers = self.worker_registry.get_workers_filtered(
|
||||
model_id,
|
||||
Some(WorkerType::Regular),
|
||||
Some(ConnectionMode::Grpc { port: None }),
|
||||
false, // get all workers, we'll filter by is_available() next
|
||||
);
|
||||
|
||||
// Filter by availability (health + circuit breaker)
|
||||
let available: Vec<Arc<dyn Worker>> = workers
|
||||
.iter()
|
||||
.filter(|w| w.is_available())
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if available.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Get the appropriate policy for this model
|
||||
let policy = match model_id {
|
||||
Some(model) => self.policy_registry.get_policy_or_default(model),
|
||||
None => self.policy_registry.get_default_policy(),
|
||||
};
|
||||
|
||||
// Select worker using the policy
|
||||
let idx = policy.select_worker(&available, text)?;
|
||||
Some(available[idx].clone())
|
||||
}
|
||||
|
||||
/// Resolve the generate input into optional original text and token IDs
|
||||
fn resolve_generate_input(
|
||||
&self,
|
||||
request: &GenerateRequest,
|
||||
) -> Result<(Option<String>, Vec<u32>), String> {
|
||||
if let Some(text) = &request.text {
|
||||
return self
|
||||
.tokenize_single_text(text)
|
||||
.map(|(original, ids)| (Some(original), ids));
|
||||
}
|
||||
|
||||
// Handle input_ids - validate and convert
|
||||
if let Some(input_ids) = &request.input_ids {
|
||||
return match input_ids {
|
||||
InputIds::Single(ids) => ids
|
||||
.iter()
|
||||
.map(|&id| u32::try_from(id))
|
||||
.collect::<Result<Vec<u32>, _>>()
|
||||
.map(|converted| (None, converted))
|
||||
.map_err(|_| "input_ids must be non-negative".to_string()),
|
||||
InputIds::Batch(_) => {
|
||||
Err("Batch input_ids are not supported over gRPC generate yet".to_string())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Err("Either `text` or `input_ids` must be provided".to_string())
|
||||
}
|
||||
|
||||
fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec<u32>), String> {
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(text)
|
||||
.map_err(|e| format!("Tokenization failed: {}", e))?;
|
||||
Ok((text.to_string(), encoding.token_ids().to_vec()))
|
||||
}
|
||||
|
||||
/// Submit request and handle non-streaming response for the `/generate` endpoint
|
||||
async fn handle_non_streaming_generate(
|
||||
&self,
|
||||
mut client: SglangSchedulerClient,
|
||||
request: proto::GenerateRequest,
|
||||
original_request: &GenerateRequest,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
) -> Response {
|
||||
let start_time = Instant::now();
|
||||
|
||||
let stream = match client.generate(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
return utils::internal_error_message(format!("Failed to start generation: {}", e))
|
||||
}
|
||||
};
|
||||
|
||||
// Collect all responses using utils helper
|
||||
let responses = match utils::collect_stream_responses(stream, "Generate").await {
|
||||
Ok(responses) => responses,
|
||||
Err(error_response) => return error_response,
|
||||
};
|
||||
|
||||
if responses.is_empty() {
|
||||
return utils::internal_error_static("No completion received from scheduler");
|
||||
}
|
||||
|
||||
// Create stop decoder from sampling params
|
||||
let params = original_request.sampling_params.as_ref();
|
||||
let mut stop_decoder = utils::create_stop_decoder(
|
||||
&self.tokenizer,
|
||||
params.and_then(|p| p.stop.as_ref()),
|
||||
params.and_then(|p| p.stop_token_ids.as_ref()),
|
||||
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
|
||||
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
|
||||
);
|
||||
|
||||
// Process each completion
|
||||
let mut result_array = Vec::new();
|
||||
for mut complete in responses {
|
||||
stop_decoder.reset();
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||
Ok(outputs) => outputs,
|
||||
Err(e) => {
|
||||
return utils::internal_error_message(format!(
|
||||
"Failed to process tokens: {}",
|
||||
e
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate text with early breaks
|
||||
let mut decoded_text = String::new();
|
||||
for output in outputs {
|
||||
match output {
|
||||
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
decoded_text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
SequenceDecoderOutput::Stopped => break,
|
||||
SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||
let finish_reason = std::mem::take(&mut complete.finish_reason);
|
||||
|
||||
// Build base meta_info using json! macro
|
||||
let mut meta_info = json!({
|
||||
"id": request_id.clone(),
|
||||
"finish_reason": finish_reason,
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"weight_version": weight_version.clone(),
|
||||
"completion_tokens": complete.completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
"e2e_latency": start_time.elapsed().as_secs_f64(),
|
||||
});
|
||||
|
||||
let meta_obj = meta_info.as_object_mut().unwrap();
|
||||
|
||||
// Add matched_stop if present
|
||||
if let Some(matched) = complete.matched_stop.take() {
|
||||
use proto::generate_complete::MatchedStop;
|
||||
let matched_value = match matched {
|
||||
MatchedStop::MatchedTokenId(id) => json!(id),
|
||||
MatchedStop::MatchedStopStr(s) => json!(s),
|
||||
};
|
||||
meta_obj.insert("matched_stop".to_string(), matched_value);
|
||||
}
|
||||
|
||||
result_array.push(json!({
|
||||
"text": decoded_text,
|
||||
"output_ids": output_ids,
|
||||
"meta_info": meta_info,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(result_array).into_response()
|
||||
}
|
||||
|
||||
/// Submit request and handle streaming response for the `/generate` endpoint
|
||||
async fn handle_streaming_generate(
|
||||
&self,
|
||||
mut client: SglangSchedulerClient,
|
||||
request: proto::GenerateRequest,
|
||||
original_request: &GenerateRequest,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
) -> Response {
|
||||
let tokenizer = self.tokenizer.clone();
|
||||
let return_logprob = original_request.return_logprob;
|
||||
|
||||
// Create channel for SSE streaming
|
||||
let (tx, rx) =
|
||||
tokio::sync::mpsc::unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
|
||||
|
||||
// Start the stream
|
||||
let stream = match client.generate(request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(e) => {
|
||||
return utils::internal_error_message(format!("Failed to start generation: {}", e))
|
||||
}
|
||||
};
|
||||
|
||||
// Spawn async task to process stream
|
||||
tokio::spawn(async move {
|
||||
let result = Self::process_generate_streaming(
|
||||
tokenizer,
|
||||
stream,
|
||||
request_id,
|
||||
weight_version,
|
||||
return_logprob,
|
||||
&tx,
|
||||
// Use pipeline for ALL requests (streaming and non-streaming)
|
||||
self.pipeline
|
||||
.execute_generate(
|
||||
body.clone(),
|
||||
headers.cloned(),
|
||||
model_id.map(|s| s.to_string()),
|
||||
self.shared_components.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e);
|
||||
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
|
||||
}
|
||||
|
||||
// Send [DONE] marker
|
||||
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
|
||||
});
|
||||
|
||||
// Create SSE response stream
|
||||
let body_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "text/event-stream")
|
||||
.header("Cache-Control", "no-cache")
|
||||
.header("Connection", "keep-alive")
|
||||
.body(axum::body::Body::from_stream(body_stream))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Process streaming chunks for generate endpoint
|
||||
async fn process_generate_streaming(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut stream: impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
|
||||
+ Unpin,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
_include_logprobs: bool,
|
||||
tx: &tokio::sync::mpsc::UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||
use std::time::Instant;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Track state per index for n>1 case
|
||||
use std::collections::HashMap;
|
||||
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
|
||||
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
|
||||
|
||||
match gen_response.response {
|
||||
Some(Chunk(chunk)) => {
|
||||
let index = chunk.index;
|
||||
|
||||
// Update completion tokens for this index
|
||||
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
|
||||
*completion_tokens += chunk.token_ids.len() as u32;
|
||||
|
||||
// Decode tokens to text (skip_special_tokens=true to handle newlines correctly)
|
||||
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
|
||||
|
||||
// Accumulate text for this index
|
||||
let accumulated_text = accumulated_texts.entry(index).or_default();
|
||||
accumulated_text.push_str(&chunk_text);
|
||||
|
||||
// Generate unique ID per index
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
|
||||
// Build streaming response chunk (SGLang format)
|
||||
let chunk_response = serde_json::json!({
|
||||
"text": accumulated_text.clone(),
|
||||
"output_ids": chunk.token_ids,
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": null,
|
||||
"prompt_tokens": chunk.prompt_tokens,
|
||||
"weight_version": weight_version,
|
||||
"completion_tokens": *completion_tokens,
|
||||
"cached_tokens": chunk.cached_tokens
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&chunk_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send chunk".to_string())?;
|
||||
}
|
||||
Some(Complete(complete)) => {
|
||||
let index = complete.index;
|
||||
let accumulated_text =
|
||||
accumulated_texts.get(&index).cloned().unwrap_or_default();
|
||||
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
let e2e_latency = start_time.elapsed().as_secs_f64();
|
||||
|
||||
// Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks)
|
||||
let finish_response = serde_json::json!({
|
||||
"text": accumulated_text,
|
||||
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": complete.finish_reason,
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"weight_version": weight_version,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
"e2e_latency": e2e_latency
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&finish_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send finish chunk".to_string())?;
|
||||
|
||||
// Continue to process all completions if n>1
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
return Err(error.message);
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,15 +17,18 @@ use tokio_stream::StreamExt;
|
||||
use tonic::codec::Streaming;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::context;
|
||||
use super::utils;
|
||||
use crate::grpc_client::proto;
|
||||
use crate::protocols::spec::*;
|
||||
use crate::reasoning_parser::ReasoningParser;
|
||||
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ToolParser;
|
||||
|
||||
use super::context;
|
||||
use super::utils;
|
||||
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
|
||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Shared streaming processor for both single and dual dispatch modes
|
||||
#[derive(Clone)]
|
||||
@@ -65,7 +68,7 @@ impl StreamingProcessor {
|
||||
execution_result: context::ExecutionResult,
|
||||
chat_request: ChatCompletionRequest,
|
||||
dispatch: context::DispatchMetadata,
|
||||
) -> axum::response::Response {
|
||||
) -> Response {
|
||||
use bytes::Bytes;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@@ -194,7 +197,7 @@ impl StreamingProcessor {
|
||||
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
|
||||
|
||||
match gen_response.response {
|
||||
Some(proto::generate_response::Response::Chunk(chunk)) => {
|
||||
Some(Chunk(chunk)) => {
|
||||
let index = chunk.index;
|
||||
|
||||
// Get or create stop decoder for this index
|
||||
@@ -336,7 +339,7 @@ impl StreamingProcessor {
|
||||
.map_err(|_| "Failed to send content chunk".to_string())?;
|
||||
}
|
||||
}
|
||||
Some(proto::generate_response::Response::Complete(complete)) => {
|
||||
Some(Complete(complete)) => {
|
||||
let index = complete.index;
|
||||
|
||||
// Flush any remaining text for this index's stop_decoder
|
||||
@@ -385,19 +388,17 @@ impl StreamingProcessor {
|
||||
|
||||
// Extract matched_stop
|
||||
let matched_stop_value = match &complete.matched_stop {
|
||||
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
|
||||
Some(MatchedTokenId(token_id)) => {
|
||||
Some(Value::Number(serde_json::Number::from(*token_id)))
|
||||
}
|
||||
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
|
||||
Some(Value::String(stop_str.clone()))
|
||||
}
|
||||
Some(MatchedStopStr(stop_str)) => Some(Value::String(stop_str.clone())),
|
||||
None => None,
|
||||
};
|
||||
matched_stops.insert(index, matched_stop_value);
|
||||
|
||||
// Don't break - continue reading all Complete messages for n>1
|
||||
}
|
||||
Some(proto::generate_response::Response::Error(error)) => {
|
||||
Some(Error(error)) => {
|
||||
return Err(error.message);
|
||||
}
|
||||
None => continue,
|
||||
@@ -536,12 +537,12 @@ impl StreamingProcessor {
|
||||
while let Some(response) = prefill_stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
|
||||
match gen_response.response {
|
||||
Some(proto::generate_response::Response::Complete(_complete)) => {
|
||||
Some(Complete(_complete)) => {
|
||||
// Input logprobs collected but not yet used in streaming
|
||||
// (OpenAI spec doesn't require prompt logprobs in streaming responses)
|
||||
break;
|
||||
}
|
||||
Some(proto::generate_response::Response::Error(error)) => {
|
||||
Some(Error(error)) => {
|
||||
return Err(format!("Prefill error: {}", error.message));
|
||||
}
|
||||
_ => continue,
|
||||
@@ -554,23 +555,359 @@ impl StreamingProcessor {
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO(generate): Add streaming generate handler
|
||||
//
|
||||
// pub async fn process_streaming_generate(
|
||||
// self: Arc<Self>,
|
||||
// execution_result: context::ExecutionResult,
|
||||
// generate_request: GenerateRequest,
|
||||
// dispatch: context::DispatchMetadata,
|
||||
// ) -> axum::response::Response {
|
||||
// // Similar to process_streaming_response but:
|
||||
// // - No tool parsing
|
||||
// // - No reasoning parsing
|
||||
// // - Simpler chunk format (just text + finish_reason + logprobs)
|
||||
// // - Extract stop params from generate_request.sampling_params
|
||||
// // - Use same per-index stop decoder logic
|
||||
// // - Emit SSE chunks with format similar to chat but without delta.tool_calls
|
||||
// // Reference: router.rs:422-595
|
||||
// }
|
||||
/// Process streaming generate response and return SSE response
|
||||
///
|
||||
/// Simpler than chat - no tool/reasoning parsing, just text accumulation
|
||||
pub fn process_streaming_generate(
|
||||
self: Arc<Self>,
|
||||
execution_result: context::ExecutionResult,
|
||||
generate_request: GenerateRequest,
|
||||
dispatch: context::DispatchMetadata,
|
||||
) -> Response {
|
||||
let return_logprob = generate_request.return_logprob;
|
||||
|
||||
// Create SSE channel
|
||||
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
|
||||
|
||||
// Spawn background task based on execution mode
|
||||
match execution_result {
|
||||
context::ExecutionResult::Single { stream } => {
|
||||
let tokenizer = self.tokenizer.clone();
|
||||
let request_id = dispatch.request_id.clone();
|
||||
let weight_version = dispatch
|
||||
.weight_version
|
||||
.clone()
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
tokio::spawn(async move {
|
||||
let result = Self::process_generate_streaming(
|
||||
tokenizer,
|
||||
stream,
|
||||
request_id,
|
||||
weight_version,
|
||||
return_logprob,
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e);
|
||||
let _ = tx.send(Ok(Bytes::from(error_chunk)));
|
||||
}
|
||||
|
||||
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
|
||||
});
|
||||
}
|
||||
context::ExecutionResult::Dual { prefill, decode } => {
|
||||
// For PD mode, need to handle prefill stream for input_logprobs
|
||||
let tokenizer = self.tokenizer.clone();
|
||||
let request_id = dispatch.request_id.clone();
|
||||
let weight_version = dispatch
|
||||
.weight_version
|
||||
.clone()
|
||||
.unwrap_or_else(|| "default".to_string());
|
||||
tokio::spawn(async move {
|
||||
let result = Self::process_generate_streaming_dual(
|
||||
tokenizer,
|
||||
prefill,
|
||||
*decode,
|
||||
request_id,
|
||||
weight_version,
|
||||
return_logprob,
|
||||
&tx,
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e);
|
||||
let _ = tx.send(Ok(Bytes::from(error_chunk)));
|
||||
}
|
||||
|
||||
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Return SSE response
|
||||
build_sse_response(rx)
|
||||
}
|
||||
|
||||
//TODO add streaming logprob support
|
||||
/// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
|
||||
async fn process_generate_streaming(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut stream: Streaming<proto::GenerateResponse>,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
_include_logprobs: bool,
|
||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Track state per index for n>1 case
|
||||
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
|
||||
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
|
||||
|
||||
match gen_response.response {
|
||||
Some(Chunk(chunk)) => {
|
||||
let index = chunk.index;
|
||||
|
||||
// Update completion tokens for this index
|
||||
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
|
||||
*completion_tokens += chunk.token_ids.len() as u32;
|
||||
|
||||
// Decode tokens to text (skip_special_tokens=true to handle newlines correctly)
|
||||
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
|
||||
|
||||
// Accumulate text for this index
|
||||
let accumulated_text = accumulated_texts.entry(index).or_default();
|
||||
accumulated_text.push_str(&chunk_text);
|
||||
|
||||
// Generate unique ID per index
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
|
||||
// Build streaming response chunk (SGLang format)
|
||||
let chunk_response = serde_json::json!({
|
||||
"text": accumulated_text.clone(),
|
||||
"output_ids": chunk.token_ids,
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": null,
|
||||
"prompt_tokens": chunk.prompt_tokens,
|
||||
"weight_version": &weight_version,
|
||||
"completion_tokens": *completion_tokens,
|
||||
"cached_tokens": chunk.cached_tokens
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&chunk_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send chunk".to_string())?;
|
||||
}
|
||||
Some(Complete(complete)) => {
|
||||
let index = complete.index;
|
||||
let accumulated_text =
|
||||
accumulated_texts.get(&index).cloned().unwrap_or_default();
|
||||
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
let e2e_latency = start_time.elapsed().as_secs_f64();
|
||||
|
||||
// Send final chunk with finish_reason
|
||||
let finish_response = serde_json::json!({
|
||||
"text": accumulated_text,
|
||||
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": complete.finish_reason,
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"weight_version": &weight_version,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
"e2e_latency": e2e_latency
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&finish_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send finish chunk".to_string())?;
|
||||
|
||||
// Continue to process all completions if n>1
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
return Err(error.message);
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process dual streaming for generate endpoint (PD mode with logprobs support)
|
||||
async fn process_generate_streaming_dual(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut prefill_stream: Streaming<proto::GenerateResponse>,
|
||||
decode_stream: Streaming<proto::GenerateResponse>,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
return_logprob: bool,
|
||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
// Collect input_logprobs from prefill stream if requested
|
||||
let input_token_logprobs = if return_logprob {
|
||||
let mut input_logprobs = None;
|
||||
while let Some(response) = prefill_stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
|
||||
match gen_response.response {
|
||||
Some(Complete(complete)) => {
|
||||
// Extract input_logprobs from prefill Complete message (convert proto to SGLang format)
|
||||
input_logprobs = complete
|
||||
.input_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_input_logprobs);
|
||||
break;
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
return Err(format!("Prefill error: {}", error.message));
|
||||
}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
input_logprobs
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Process decode stream with input_logprobs prepended
|
||||
Self::process_generate_streaming_with_input_logprobs(
|
||||
tokenizer,
|
||||
decode_stream,
|
||||
request_id,
|
||||
weight_version,
|
||||
return_logprob,
|
||||
input_token_logprobs,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Process generate streaming with optional input_logprobs
|
||||
async fn process_generate_streaming_with_input_logprobs(
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
mut stream: Streaming<proto::GenerateResponse>,
|
||||
request_id: String,
|
||||
weight_version: String,
|
||||
_include_logprobs: bool,
|
||||
input_token_logprobs: Option<Vec<Vec<Option<f64>>>>,
|
||||
tx: &UnboundedSender<Result<Bytes, io::Error>>,
|
||||
) -> Result<(), String> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Track state per index for n>1 case
|
||||
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
|
||||
let mut accumulated_output_logprobs: HashMap<u32, Option<Vec<Vec<Option<f64>>>>> =
|
||||
HashMap::new();
|
||||
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
|
||||
|
||||
match gen_response.response {
|
||||
Some(Chunk(chunk)) => {
|
||||
let index = chunk.index;
|
||||
|
||||
// Update completion tokens for this index
|
||||
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
|
||||
*completion_tokens += chunk.token_ids.len() as u32;
|
||||
|
||||
// Decode tokens to text
|
||||
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
|
||||
|
||||
// Accumulate text for this index
|
||||
let accumulated_text = accumulated_texts.entry(index).or_default();
|
||||
accumulated_text.push_str(&chunk_text);
|
||||
|
||||
// Store latest output logprobs (cumulative from proto, convert to SGLang format)
|
||||
if let Some(ref output_logprobs) = chunk.output_logprobs {
|
||||
let converted =
|
||||
super::utils::convert_generate_output_logprobs(output_logprobs);
|
||||
accumulated_output_logprobs.insert(index, Some(converted));
|
||||
}
|
||||
|
||||
// Generate unique ID per index
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
|
||||
// Build streaming response chunk with cumulative logprobs
|
||||
let current_output_logprobs = accumulated_output_logprobs
|
||||
.get(&index)
|
||||
.and_then(|o| o.as_ref());
|
||||
|
||||
let chunk_response = serde_json::json!({
|
||||
"text": accumulated_text.clone(),
|
||||
"output_ids": chunk.token_ids,
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": null,
|
||||
"prompt_tokens": chunk.prompt_tokens,
|
||||
"weight_version": &weight_version,
|
||||
"input_token_logprobs": input_token_logprobs.as_ref(),
|
||||
"output_token_logprobs": current_output_logprobs,
|
||||
"completion_tokens": *completion_tokens,
|
||||
"cached_tokens": chunk.cached_tokens
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&chunk_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send chunk".to_string())?;
|
||||
}
|
||||
Some(Complete(complete)) => {
|
||||
let index = complete.index;
|
||||
let accumulated_text =
|
||||
accumulated_texts.get(&index).cloned().unwrap_or_default();
|
||||
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
|
||||
let final_output_logprobs = accumulated_output_logprobs
|
||||
.get(&index)
|
||||
.and_then(|o| o.as_ref());
|
||||
let index_id = format!("{}-{}", request_id, index);
|
||||
let e2e_latency = start_time.elapsed().as_secs_f64();
|
||||
|
||||
// Parse finish_reason
|
||||
let finish_reason = utils::parse_finish_reason(
|
||||
&complete.finish_reason,
|
||||
complete.completion_tokens,
|
||||
);
|
||||
|
||||
// Send final chunk with finish_reason
|
||||
let finish_response = serde_json::json!({
|
||||
"text": accumulated_text,
|
||||
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
|
||||
"meta_info": {
|
||||
"id": index_id,
|
||||
"finish_reason": finish_reason,
|
||||
"prompt_tokens": complete.prompt_tokens,
|
||||
"weight_version": &weight_version,
|
||||
"input_token_logprobs": input_token_logprobs.as_ref(),
|
||||
"output_token_logprobs": final_output_logprobs,
|
||||
"completion_tokens": completion_tokens,
|
||||
"cached_tokens": complete.cached_tokens,
|
||||
"e2e_latency": e2e_latency
|
||||
},
|
||||
"index": index
|
||||
});
|
||||
|
||||
let sse_chunk = format!(
|
||||
"data: {}\n\n",
|
||||
serde_json::to_string(&finish_response).unwrap()
|
||||
);
|
||||
tx.send(Ok(Bytes::from(sse_chunk)))
|
||||
.map_err(|_| "Failed to send finish chunk".to_string())?;
|
||||
|
||||
// Continue to process all completions if n>1
|
||||
}
|
||||
Some(Error(error)) => {
|
||||
return Err(error.message);
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Helper Methods
|
||||
@@ -842,9 +1179,7 @@ impl StreamingProcessor {
|
||||
}
|
||||
|
||||
/// Build SSE response with proper headers
|
||||
pub fn build_sse_response(
|
||||
rx: tokio::sync::mpsc::UnboundedReceiver<Result<Bytes, io::Error>>,
|
||||
) -> Response {
|
||||
pub fn build_sse_response(rx: mpsc::UnboundedReceiver<Result<Bytes, io::Error>>) -> Response {
|
||||
let stream = UnboundedReceiverStream::new(rx);
|
||||
let mut response = Response::new(Body::from_stream(stream));
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::core::Worker;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
|
||||
StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
|
||||
GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
|
||||
};
|
||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
@@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs(
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert proto::OutputLogProbs to Generate format Vec<Vec<Option<f64>>>
|
||||
///
|
||||
/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
|
||||
/// Each inner vec contains [logprob (f64), token_id (i32), ...]
|
||||
pub fn convert_generate_output_logprobs(
|
||||
proto_logprobs: &proto::OutputLogProbs,
|
||||
) -> Vec<Vec<Option<f64>>> {
|
||||
proto_logprobs
|
||||
.token_logprobs
|
||||
.iter()
|
||||
.zip(proto_logprobs.token_ids.iter())
|
||||
.map(|(&logprob, &token_id)| vec![Some(logprob as f64), Some(token_id as f64)])
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert proto::InputLogProbs to Generate format Vec<Vec<Option<f64>>>
|
||||
///
|
||||
/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
|
||||
/// First token has null logprob: [[null, token_id], [logprob, token_id], ...]
|
||||
pub fn convert_generate_input_logprobs(
|
||||
proto_logprobs: &proto::InputLogProbs,
|
||||
) -> Vec<Vec<Option<f64>>> {
|
||||
proto_logprobs
|
||||
.token_logprobs
|
||||
.iter()
|
||||
.zip(proto_logprobs.token_ids.iter())
|
||||
.map(|(token_logprob, &token_id)| {
|
||||
// InputTokenLogProb has optional value field
|
||||
let logprob_value = token_logprob.value.map(|v| v as f64);
|
||||
vec![logprob_value, Some(token_id as f64)]
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse finish_reason string into GenerateFinishReason enum
|
||||
///
|
||||
/// Uses serde to deserialize the finish_reason, which handles all tagged variants automatically.
|
||||
/// The GenerateFinishReason enum is tagged with `#[serde(tag = "type", rename_all = "lowercase")]`,
|
||||
/// so it expects JSON objects like:
|
||||
/// - `{"type":"stop"}` -> Stop
|
||||
/// - `{"type":"length","length":100}` -> Length { length: 100 }
|
||||
/// - Any other JSON -> Other(...)
|
||||
///
|
||||
/// For backward compatibility, also handles simple string "stop" -> Stop
|
||||
pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> GenerateFinishReason {
|
||||
if reason_str == "stop" {
|
||||
return GenerateFinishReason::Stop;
|
||||
}
|
||||
|
||||
if reason_str == "length" {
|
||||
return GenerateFinishReason::Length {
|
||||
length: completion_tokens.max(0) as u32,
|
||||
};
|
||||
}
|
||||
|
||||
match serde_json::from_str::<GenerateFinishReason>(reason_str) {
|
||||
Ok(finish_reason) => finish_reason,
|
||||
Err(_) => match serde_json::from_str::<Value>(reason_str) {
|
||||
Ok(json_value) => GenerateFinishReason::Other(json_value),
|
||||
Err(_) => GenerateFinishReason::Other(Value::String(reason_str.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user