[router][grpc] Further delegate non-stream processing to processing.rs (#11553)
This commit is contained in:
@@ -23,6 +23,9 @@ use std::sync::Arc;
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use super::context::SharedComponents;
|
||||
use super::pipeline::RequestPipeline;
|
||||
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
@@ -37,8 +40,8 @@ pub struct GrpcPDRouter {
|
||||
retry_config: RetryConfig,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
configured_tool_parser: Option<String>,
|
||||
pipeline: super::pipeline::ChatCompletionPipeline,
|
||||
shared_components: Arc<super::context::SharedComponents>,
|
||||
pipeline: RequestPipeline,
|
||||
shared_components: Arc<SharedComponents>,
|
||||
}
|
||||
|
||||
impl GrpcPDRouter {
|
||||
@@ -66,36 +69,21 @@ impl GrpcPDRouter {
|
||||
.clone();
|
||||
|
||||
// Create shared components for pipeline
|
||||
let shared_components = Arc::new(super::context::SharedComponents {
|
||||
let shared_components = Arc::new(SharedComponents {
|
||||
tokenizer: tokenizer.clone(),
|
||||
tool_parser_factory: tool_parser_factory.clone(),
|
||||
reasoning_parser_factory: reasoning_parser_factory.clone(),
|
||||
});
|
||||
|
||||
// Create response processor
|
||||
let processor = super::processing::ResponseProcessor::new(
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
ctx.configured_tool_parser.clone(),
|
||||
ctx.configured_reasoning_parser.clone(),
|
||||
);
|
||||
|
||||
// Create streaming processor
|
||||
let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
ctx.configured_tool_parser.clone(),
|
||||
ctx.configured_reasoning_parser.clone(),
|
||||
));
|
||||
|
||||
// Create PD pipeline
|
||||
let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
|
||||
let pipeline = RequestPipeline::new_pd(
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
processor,
|
||||
streaming_processor,
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
ctx.configured_tool_parser.clone(),
|
||||
ctx.configured_reasoning_parser.clone(),
|
||||
);
|
||||
|
||||
Ok(GrpcPDRouter {
|
||||
|
||||
@@ -14,13 +14,10 @@ use super::utils;
|
||||
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::grpc_client::proto;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
|
||||
GenerateResponse, InputIds, Usage,
|
||||
};
|
||||
use crate::tokenizer::stop::SequenceDecoderOutput;
|
||||
use crate::protocols::spec::{ChatCompletionRequest, GenerateRequest, InputIds};
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use proto::generate_complete::MatchedStop;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
use proto::DisaggregatedParams;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
@@ -790,114 +787,32 @@ impl ResponseProcessingStage {
|
||||
.take()
|
||||
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
|
||||
|
||||
if is_streaming {
|
||||
// Get dispatch metadata for consistent response fields
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
// Get dispatch metadata (needed by both streaming and non-streaming)
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
|
||||
.clone();
|
||||
|
||||
if is_streaming {
|
||||
// Streaming: Use StreamingProcessor and return SSE response (done)
|
||||
return Ok(Some(
|
||||
self.streaming_processor.clone().process_streaming_response(
|
||||
execution_result,
|
||||
ctx.chat_request_arc(), // Cheap Arc clone (8 bytes)
|
||||
dispatch.clone(),
|
||||
dispatch,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Non-streaming: Extract chat request details before mutable borrows
|
||||
// Non-streaming: Delegate to ResponseProcessor
|
||||
let request_logprobs = match &ctx.input.request_type {
|
||||
RequestType::Chat(req) => req.logprobs,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
// Collect all responses from the execution result
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { mut stream } => {
|
||||
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||
stream.mark_completed();
|
||||
responses
|
||||
}
|
||||
ExecutionResult::Dual {
|
||||
mut prefill,
|
||||
decode,
|
||||
} => {
|
||||
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||
let prefill_responses =
|
||||
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||
|
||||
// Collect decode for actual output (don't mark completed yet)
|
||||
let mut decode_stream = *decode;
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||
|
||||
// Mark both streams as completed now that both succeeded
|
||||
prefill.mark_completed();
|
||||
decode_stream.mark_completed();
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
if let Some(prefill_input_logprobs) = prefill_responses
|
||||
.first()
|
||||
.and_then(|r| r.input_logprobs.clone())
|
||||
{
|
||||
for response in &mut decode_responses {
|
||||
response.input_logprobs = Some(prefill_input_logprobs.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
decode_responses
|
||||
}
|
||||
};
|
||||
|
||||
if all_responses.is_empty() {
|
||||
return Err(utils::internal_error_static("No responses from server"));
|
||||
}
|
||||
|
||||
let chat_request = ctx.chat_request_arc();
|
||||
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
|
||||
|
||||
// Check parser availability once upfront (not per choice)
|
||||
let reasoning_parser_available = chat_request.separate_reasoning
|
||||
&& utils::check_reasoning_parser_availability(
|
||||
&self.processor.reasoning_parser_factory,
|
||||
self.processor.configured_reasoning_parser.as_ref(),
|
||||
&chat_request.model,
|
||||
);
|
||||
|
||||
let tool_choice_enabled = !matches!(
|
||||
&chat_request.tool_choice,
|
||||
Some(crate::protocols::spec::ToolChoice::Value(
|
||||
crate::protocols::spec::ToolChoiceValue::None
|
||||
))
|
||||
);
|
||||
|
||||
let tool_parser_available = tool_choice_enabled
|
||||
&& chat_request.tools.is_some()
|
||||
&& utils::check_tool_parser_availability(
|
||||
&self.processor.tool_parser_factory,
|
||||
self.processor.configured_tool_parser.as_ref(),
|
||||
&chat_request.model,
|
||||
);
|
||||
|
||||
// Log once per request (not per choice)
|
||||
if chat_request.separate_reasoning && !reasoning_parser_available {
|
||||
debug!(
|
||||
"No reasoning parser found for model '{}', skipping reasoning parsing",
|
||||
chat_request.model
|
||||
);
|
||||
}
|
||||
|
||||
if chat_request.tools.is_some() && tool_choice_enabled && !tool_parser_available {
|
||||
debug!(
|
||||
"No tool parser found for model '{}', skipping tool call parsing",
|
||||
chat_request.model
|
||||
);
|
||||
}
|
||||
|
||||
let stop_decoder = ctx
|
||||
.state
|
||||
@@ -906,60 +821,16 @@ impl ResponseProcessingStage {
|
||||
.as_mut()
|
||||
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
|
||||
|
||||
let mut choices = Vec::new();
|
||||
for (index, complete) in all_responses.iter().enumerate() {
|
||||
match self
|
||||
.processor
|
||||
.process_single_choice(
|
||||
complete,
|
||||
index,
|
||||
&chat_request,
|
||||
stop_decoder,
|
||||
history_tool_calls_count,
|
||||
reasoning_parser_available,
|
||||
tool_parser_available,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(choice) => choices.push(choice),
|
||||
Err(e) => {
|
||||
return Err(utils::internal_error_message(format!(
|
||||
"Failed to process choice {}: {}",
|
||||
index, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build usage
|
||||
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
|
||||
let total_completion_tokens: u32 = all_responses
|
||||
.iter()
|
||||
.map(|r| r.completion_tokens as u32)
|
||||
.sum();
|
||||
let usage = Usage {
|
||||
prompt_tokens: total_prompt_tokens,
|
||||
completion_tokens: total_completion_tokens,
|
||||
total_tokens: total_prompt_tokens + total_completion_tokens,
|
||||
completion_tokens_details: None,
|
||||
};
|
||||
|
||||
// Build final ChatCompletionResponse
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
let response = ChatCompletionResponse {
|
||||
id: dispatch.request_id.clone(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: dispatch.created,
|
||||
model: dispatch.model.clone(),
|
||||
choices,
|
||||
usage: Some(usage),
|
||||
system_fingerprint: dispatch.weight_version.clone(),
|
||||
};
|
||||
let response = self
|
||||
.processor
|
||||
.process_non_streaming_chat_response(
|
||||
execution_result,
|
||||
chat_request,
|
||||
dispatch,
|
||||
stop_decoder,
|
||||
request_logprobs,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Store the final response
|
||||
ctx.state.response.final_response = Some(FinalResponse::Chat(response));
|
||||
@@ -982,70 +853,29 @@ impl ResponseProcessingStage {
|
||||
.take()
|
||||
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
|
||||
|
||||
if is_streaming {
|
||||
// Get dispatch metadata for consistent response fields
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
// Get dispatch metadata (needed by both streaming and non-streaming)
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?
|
||||
.clone();
|
||||
|
||||
if is_streaming {
|
||||
// Streaming: Use StreamingProcessor and return SSE response (done)
|
||||
return Ok(Some(
|
||||
self.streaming_processor.clone().process_streaming_generate(
|
||||
execution_result,
|
||||
ctx.generate_request_arc(), // Cheap Arc clone (8 bytes)
|
||||
dispatch.clone(),
|
||||
dispatch,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Non-streaming: Collect all responses
|
||||
// Non-streaming: Delegate to ResponseProcessor
|
||||
let request_logprobs = ctx.generate_request().return_logprob;
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { mut stream } => {
|
||||
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||
stream.mark_completed();
|
||||
responses
|
||||
}
|
||||
ExecutionResult::Dual {
|
||||
mut prefill,
|
||||
decode,
|
||||
} => {
|
||||
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||
let prefill_responses =
|
||||
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||
let generate_request = ctx.generate_request_arc();
|
||||
|
||||
// Collect decode for actual output (don't mark completed yet)
|
||||
let mut decode_stream = *decode;
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||
|
||||
// Mark both streams as completed now that both succeeded
|
||||
prefill.mark_completed();
|
||||
decode_stream.mark_completed();
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
if let Some(prefill_input_logprobs) = prefill_responses
|
||||
.first()
|
||||
.and_then(|r| r.input_logprobs.clone())
|
||||
{
|
||||
for response in &mut decode_responses {
|
||||
response.input_logprobs = Some(prefill_input_logprobs.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
decode_responses
|
||||
}
|
||||
};
|
||||
|
||||
if all_responses.is_empty() {
|
||||
return Err(utils::internal_error_static("No responses from server"));
|
||||
}
|
||||
|
||||
// Get stop decoder for processing
|
||||
let stop_decoder = ctx
|
||||
.state
|
||||
.response
|
||||
@@ -1053,103 +883,17 @@ impl ResponseProcessingStage {
|
||||
.as_mut()
|
||||
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
|
||||
|
||||
// Get dispatch metadata
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
// Process each completion (similar to router.rs:336-400)
|
||||
let mut result_array = Vec::new();
|
||||
for mut complete in all_responses {
|
||||
stop_decoder.reset();
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||
Ok(outputs) => outputs,
|
||||
Err(e) => {
|
||||
return Err(utils::internal_error_message(format!(
|
||||
"Failed to process tokens: {}",
|
||||
e
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate text with early breaks
|
||||
let mut decoded_text = String::new();
|
||||
for output in outputs {
|
||||
match output {
|
||||
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
decoded_text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
SequenceDecoderOutput::Stopped => break,
|
||||
SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
|
||||
|
||||
// Parse finish_reason from string to proper type
|
||||
let finish_reason =
|
||||
utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
|
||||
|
||||
// Handle matched_stop if present
|
||||
let matched_stop = complete.matched_stop.take().map(|matched| match matched {
|
||||
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
|
||||
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
|
||||
});
|
||||
|
||||
// Extract logprobs if requested (convert proto types to Generate format)
|
||||
let input_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.input_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_input_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let output_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.output_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_output_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build GenerateResponse struct
|
||||
let meta_info = GenerateMetaInfo {
|
||||
id: dispatch.request_id.clone(),
|
||||
finish_reason,
|
||||
prompt_tokens: complete.prompt_tokens as u32,
|
||||
weight_version: dispatch
|
||||
.weight_version
|
||||
.clone()
|
||||
.unwrap_or_else(|| "default".to_string()),
|
||||
input_token_logprobs,
|
||||
output_token_logprobs,
|
||||
completion_tokens: complete.completion_tokens as u32,
|
||||
cached_tokens: complete.cached_tokens as u32,
|
||||
e2e_latency: start_time.elapsed().as_secs_f64(),
|
||||
matched_stop,
|
||||
};
|
||||
|
||||
result_array.push(GenerateResponse {
|
||||
text: decoded_text,
|
||||
output_ids,
|
||||
meta_info,
|
||||
});
|
||||
}
|
||||
let result_array = self
|
||||
.processor
|
||||
.process_non_streaming_generate_response(
|
||||
execution_result,
|
||||
generate_request,
|
||||
dispatch,
|
||||
stop_decoder,
|
||||
request_logprobs,
|
||||
start_time,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Store the final response
|
||||
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
|
||||
@@ -1162,23 +906,44 @@ impl ResponseProcessingStage {
|
||||
// Pipeline Orchestrator
|
||||
// ============================================================================
|
||||
|
||||
/// Complete chat completion pipeline
|
||||
/// Generic request pipeline for all request types
|
||||
///
|
||||
/// Orchestrates all stages from request preparation to response delivery.
|
||||
/// Configured differently for regular vs PD mode.
|
||||
#[derive(Clone)]
|
||||
pub struct ChatCompletionPipeline {
|
||||
pub struct RequestPipeline {
|
||||
stages: Arc<Vec<Box<dyn PipelineStage>>>,
|
||||
}
|
||||
|
||||
impl ChatCompletionPipeline {
|
||||
impl RequestPipeline {
|
||||
/// Create a regular (single-worker) pipeline
|
||||
pub fn new_regular(
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
processor: processing::ResponseProcessor,
|
||||
streaming_processor: Arc<streaming::StreamingProcessor>,
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
tool_parser_factory: ToolParserFactory,
|
||||
reasoning_parser_factory: ReasoningParserFactory,
|
||||
configured_tool_parser: Option<String>,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
) -> Self {
|
||||
// Create response processor
|
||||
let processor = processing::ResponseProcessor::new(
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
configured_tool_parser.clone(),
|
||||
configured_reasoning_parser.clone(),
|
||||
);
|
||||
|
||||
// Create streaming processor
|
||||
let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
|
||||
tokenizer,
|
||||
tool_parser_factory,
|
||||
reasoning_parser_factory,
|
||||
configured_tool_parser,
|
||||
configured_reasoning_parser,
|
||||
));
|
||||
|
||||
let stages: Vec<Box<dyn PipelineStage>> = vec![
|
||||
Box::new(PreparationStage),
|
||||
Box::new(WorkerSelectionStage::new(
|
||||
@@ -1190,10 +955,7 @@ impl ChatCompletionPipeline {
|
||||
Box::new(RequestBuildingStage::new(false)), // No PD metadata
|
||||
Box::new(DispatchMetadataStage),
|
||||
Box::new(RequestExecutionStage::new(ExecutionMode::Single)),
|
||||
Box::new(ResponseProcessingStage::new(
|
||||
processor,
|
||||
streaming_processor.clone(),
|
||||
)),
|
||||
Box::new(ResponseProcessingStage::new(processor, streaming_processor)),
|
||||
];
|
||||
|
||||
Self {
|
||||
@@ -1205,9 +967,30 @@ impl ChatCompletionPipeline {
|
||||
pub fn new_pd(
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
processor: processing::ResponseProcessor,
|
||||
streaming_processor: Arc<streaming::StreamingProcessor>,
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
tool_parser_factory: ToolParserFactory,
|
||||
reasoning_parser_factory: ReasoningParserFactory,
|
||||
configured_tool_parser: Option<String>,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
) -> Self {
|
||||
// Create response processor
|
||||
let processor = processing::ResponseProcessor::new(
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
configured_tool_parser.clone(),
|
||||
configured_reasoning_parser.clone(),
|
||||
);
|
||||
|
||||
// Create streaming processor
|
||||
let streaming_processor = Arc::new(streaming::StreamingProcessor::new(
|
||||
tokenizer,
|
||||
tool_parser_factory,
|
||||
reasoning_parser_factory,
|
||||
configured_tool_parser,
|
||||
configured_reasoning_parser,
|
||||
));
|
||||
|
||||
let stages: Vec<Box<dyn PipelineStage>> = vec![
|
||||
Box::new(PreparationStage),
|
||||
Box::new(WorkerSelectionStage::new(
|
||||
@@ -1219,10 +1002,7 @@ impl ChatCompletionPipeline {
|
||||
Box::new(RequestBuildingStage::new(true)), // Inject PD metadata
|
||||
Box::new(DispatchMetadataStage),
|
||||
Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)),
|
||||
Box::new(ResponseProcessingStage::new(
|
||||
processor,
|
||||
streaming_processor.clone(),
|
||||
)),
|
||||
Box::new(ResponseProcessingStage::new(processor, streaming_processor)),
|
||||
];
|
||||
|
||||
Self {
|
||||
|
||||
@@ -10,14 +10,18 @@ use tracing::error;
|
||||
|
||||
use crate::grpc_client::proto;
|
||||
use crate::protocols::spec::{
|
||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall,
|
||||
ToolChoice, ToolChoiceValue,
|
||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||
FunctionCallResponse, GenerateMetaInfo, GenerateRequest, GenerateResponse, ToolCall,
|
||||
ToolChoice, ToolChoiceValue, Usage,
|
||||
};
|
||||
use crate::reasoning_parser::ParserFactory as ReasoningParserFactory;
|
||||
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
use proto::generate_complete::MatchedStop;
|
||||
use std::time::Instant;
|
||||
|
||||
use super::context::{DispatchMetadata, ExecutionResult};
|
||||
use super::utils;
|
||||
|
||||
// ============================================================================
|
||||
@@ -51,6 +55,57 @@ impl ResponseProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to collect responses from execution result and merge logprobs if needed
|
||||
async fn collect_and_merge_responses(
|
||||
execution_result: ExecutionResult,
|
||||
request_logprobs: bool,
|
||||
) -> Result<Vec<proto::GenerateComplete>, axum::response::Response> {
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { mut stream } => {
|
||||
let responses = utils::collect_stream_responses(&mut stream, "Single").await?;
|
||||
stream.mark_completed();
|
||||
responses
|
||||
}
|
||||
ExecutionResult::Dual {
|
||||
mut prefill,
|
||||
decode,
|
||||
} => {
|
||||
// Collect prefill for input_logprobs (don't mark completed yet)
|
||||
let prefill_responses =
|
||||
utils::collect_stream_responses(&mut prefill, "Prefill").await?;
|
||||
|
||||
// Collect decode for actual output (don't mark completed yet)
|
||||
let mut decode_stream = *decode;
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(&mut decode_stream, "Decode").await?;
|
||||
|
||||
// Mark both streams as completed now that both succeeded
|
||||
prefill.mark_completed();
|
||||
decode_stream.mark_completed();
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
if let Some(prefill_input_logprobs) = prefill_responses
|
||||
.first()
|
||||
.and_then(|r| r.input_logprobs.clone())
|
||||
{
|
||||
for response in &mut decode_responses {
|
||||
response.input_logprobs = Some(prefill_input_logprobs.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
decode_responses
|
||||
}
|
||||
};
|
||||
|
||||
if all_responses.is_empty() {
|
||||
return Err(utils::internal_error_static("No responses from server"));
|
||||
}
|
||||
|
||||
Ok(all_responses)
|
||||
}
|
||||
|
||||
/// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725)
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn process_single_choice(
|
||||
@@ -158,12 +213,10 @@ impl ResponseProcessor {
|
||||
|
||||
// Extract matched_stop information from proto
|
||||
let matched_stop = match &complete.matched_stop {
|
||||
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
|
||||
Some(MatchedStop::MatchedTokenId(token_id)) => {
|
||||
Some(Value::Number(serde_json::Number::from(*token_id)))
|
||||
}
|
||||
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
|
||||
Some(Value::String(stop_str.clone()))
|
||||
}
|
||||
Some(MatchedStop::MatchedStopStr(stop_str)) => Some(Value::String(stop_str.clone())),
|
||||
None => None,
|
||||
};
|
||||
|
||||
@@ -205,6 +258,109 @@ impl ResponseProcessor {
|
||||
Ok(choice)
|
||||
}
|
||||
|
||||
/// Process non-streaming chat response (collects all responses and builds final response)
|
||||
pub async fn process_non_streaming_chat_response(
|
||||
&self,
|
||||
execution_result: ExecutionResult,
|
||||
chat_request: Arc<ChatCompletionRequest>,
|
||||
dispatch: DispatchMetadata,
|
||||
stop_decoder: &mut StopSequenceDecoder,
|
||||
request_logprobs: bool,
|
||||
) -> Result<ChatCompletionResponse, axum::response::Response> {
|
||||
// Collect all responses from the execution result
|
||||
let all_responses =
|
||||
Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
|
||||
|
||||
let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request);
|
||||
|
||||
// Check parser availability once upfront (not per choice)
|
||||
let reasoning_parser_available = chat_request.separate_reasoning
|
||||
&& utils::check_reasoning_parser_availability(
|
||||
&self.reasoning_parser_factory,
|
||||
self.configured_reasoning_parser.as_ref(),
|
||||
&chat_request.model,
|
||||
);
|
||||
|
||||
let tool_choice_enabled = !matches!(
|
||||
&chat_request.tool_choice,
|
||||
Some(ToolChoice::Value(ToolChoiceValue::None))
|
||||
);
|
||||
|
||||
let tool_parser_available = tool_choice_enabled
|
||||
&& chat_request.tools.is_some()
|
||||
&& utils::check_tool_parser_availability(
|
||||
&self.tool_parser_factory,
|
||||
self.configured_tool_parser.as_ref(),
|
||||
&chat_request.model,
|
||||
);
|
||||
|
||||
// Log once per request (not per choice)
|
||||
if chat_request.separate_reasoning && !reasoning_parser_available {
|
||||
tracing::debug!(
|
||||
"No reasoning parser found for model '{}', skipping reasoning parsing",
|
||||
chat_request.model
|
||||
);
|
||||
}
|
||||
|
||||
if chat_request.tools.is_some() && tool_choice_enabled && !tool_parser_available {
|
||||
tracing::debug!(
|
||||
"No tool parser found for model '{}', skipping tool call parsing",
|
||||
chat_request.model
|
||||
);
|
||||
}
|
||||
|
||||
// Process all choices
|
||||
let mut choices = Vec::new();
|
||||
for (index, complete) in all_responses.iter().enumerate() {
|
||||
match self
|
||||
.process_single_choice(
|
||||
complete,
|
||||
index,
|
||||
&chat_request,
|
||||
stop_decoder,
|
||||
history_tool_calls_count,
|
||||
reasoning_parser_available,
|
||||
tool_parser_available,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(choice) => choices.push(choice),
|
||||
Err(e) => {
|
||||
return Err(utils::internal_error_message(format!(
|
||||
"Failed to process choice {}: {}",
|
||||
index, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build usage
|
||||
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
|
||||
let total_completion_tokens: u32 = all_responses
|
||||
.iter()
|
||||
.map(|r| r.completion_tokens as u32)
|
||||
.sum();
|
||||
let usage = Usage {
|
||||
prompt_tokens: total_prompt_tokens,
|
||||
completion_tokens: total_completion_tokens,
|
||||
total_tokens: total_prompt_tokens + total_completion_tokens,
|
||||
completion_tokens_details: None,
|
||||
};
|
||||
|
||||
// Build final ChatCompletionResponse
|
||||
let response = ChatCompletionResponse {
|
||||
id: dispatch.request_id.clone(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: dispatch.created,
|
||||
model: dispatch.model.clone(),
|
||||
choices,
|
||||
usage: Some(usage),
|
||||
system_fingerprint: dispatch.weight_version.clone(),
|
||||
};
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
/// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361)
|
||||
pub async fn parse_tool_calls(
|
||||
&self,
|
||||
@@ -264,4 +420,112 @@ impl ResponseProcessor {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process non-streaming generate response (collects all responses and builds final response array)
|
||||
pub async fn process_non_streaming_generate_response(
|
||||
&self,
|
||||
execution_result: ExecutionResult,
|
||||
_generate_request: Arc<GenerateRequest>,
|
||||
dispatch: DispatchMetadata,
|
||||
stop_decoder: &mut StopSequenceDecoder,
|
||||
request_logprobs: bool,
|
||||
start_time: Instant,
|
||||
) -> Result<Vec<GenerateResponse>, axum::response::Response> {
|
||||
// Collect all responses from the execution result
|
||||
let all_responses =
|
||||
Self::collect_and_merge_responses(execution_result, request_logprobs).await?;
|
||||
|
||||
// Process each completion
|
||||
let mut result_array = Vec::new();
|
||||
for mut complete in all_responses {
|
||||
stop_decoder.reset();
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||
Ok(outputs) => outputs,
|
||||
Err(e) => {
|
||||
return Err(utils::internal_error_message(format!(
|
||||
"Failed to process tokens: {}",
|
||||
e
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate text with early breaks
|
||||
let mut decoded_text = String::new();
|
||||
for output in outputs {
|
||||
match output {
|
||||
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
decoded_text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
SequenceDecoderOutput::Stopped => break,
|
||||
SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
|
||||
|
||||
// Parse finish_reason from string to proper type
|
||||
let finish_reason =
|
||||
utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
|
||||
|
||||
// Handle matched_stop if present
|
||||
let matched_stop = complete.matched_stop.take().map(|matched| match matched {
|
||||
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
|
||||
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
|
||||
});
|
||||
|
||||
// Extract logprobs if requested (convert proto types to Generate format)
|
||||
let input_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.input_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_input_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let output_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.output_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_output_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build GenerateResponse struct
|
||||
let meta_info = GenerateMetaInfo {
|
||||
id: dispatch.request_id.clone(),
|
||||
finish_reason,
|
||||
prompt_tokens: complete.prompt_tokens as u32,
|
||||
weight_version: dispatch
|
||||
.weight_version
|
||||
.clone()
|
||||
.unwrap_or_else(|| "default".to_string()),
|
||||
input_token_logprobs,
|
||||
output_token_logprobs,
|
||||
completion_tokens: complete.completion_tokens as u32,
|
||||
cached_tokens: complete.cached_tokens as u32,
|
||||
e2e_latency: start_time.elapsed().as_secs_f64(),
|
||||
matched_stop,
|
||||
};
|
||||
|
||||
result_array.push(GenerateResponse {
|
||||
text: decoded_text,
|
||||
output_ids,
|
||||
meta_info,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(result_array)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,9 @@ use crate::server::AppContext;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserFactory as ToolParserFactory;
|
||||
|
||||
use super::context::SharedComponents;
|
||||
use super::pipeline::RequestPipeline;
|
||||
|
||||
/// gRPC router implementation for SGLang
|
||||
#[derive(Clone)]
|
||||
#[allow(dead_code)]
|
||||
@@ -38,8 +41,8 @@ pub struct GrpcRouter {
|
||||
retry_config: RetryConfig,
|
||||
configured_reasoning_parser: Option<String>,
|
||||
configured_tool_parser: Option<String>,
|
||||
pipeline: super::pipeline::ChatCompletionPipeline,
|
||||
shared_components: Arc<super::context::SharedComponents>,
|
||||
pipeline: RequestPipeline,
|
||||
shared_components: Arc<SharedComponents>,
|
||||
}
|
||||
|
||||
impl GrpcRouter {
|
||||
@@ -66,36 +69,21 @@ impl GrpcRouter {
|
||||
let policy_registry = ctx.policy_registry.clone();
|
||||
|
||||
// Create shared components for pipeline
|
||||
let shared_components = Arc::new(super::context::SharedComponents {
|
||||
let shared_components = Arc::new(SharedComponents {
|
||||
tokenizer: tokenizer.clone(),
|
||||
tool_parser_factory: tool_parser_factory.clone(),
|
||||
reasoning_parser_factory: reasoning_parser_factory.clone(),
|
||||
});
|
||||
|
||||
// Create response processor
|
||||
let processor = super::processing::ResponseProcessor::new(
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
ctx.configured_tool_parser.clone(),
|
||||
ctx.configured_reasoning_parser.clone(),
|
||||
);
|
||||
|
||||
// Create streaming processor
|
||||
let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
ctx.configured_tool_parser.clone(),
|
||||
ctx.configured_reasoning_parser.clone(),
|
||||
));
|
||||
|
||||
// Create pipeline
|
||||
let pipeline = super::pipeline::ChatCompletionPipeline::new_regular(
|
||||
let pipeline = RequestPipeline::new_regular(
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
processor,
|
||||
streaming_processor,
|
||||
tokenizer.clone(),
|
||||
tool_parser_factory.clone(),
|
||||
reasoning_parser_factory.clone(),
|
||||
ctx.configured_tool_parser.clone(),
|
||||
ctx.configured_reasoning_parser.clone(),
|
||||
);
|
||||
|
||||
Ok(GrpcRouter {
|
||||
|
||||
Reference in New Issue
Block a user