[router] refactor generate to use new pipeline arch (#11323)

This commit is contained in:
Simo Lin
2025-10-08 12:38:50 -04:00
committed by GitHub
parent d6837aea4d
commit 01c9ee1ab4
7 changed files with 713 additions and 1181 deletions

View File

@@ -11,15 +11,20 @@ use super::context::*;
use super::processing;
use super::streaming;
use super::utils;
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
use crate::grpc_client::proto;
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage,
ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
GenerateResponse, InputIds, Usage,
};
use crate::tokenizer::stop::SequenceDecoderOutput;
use crate::tokenizer::traits::Tokenizer;
use proto::generate_complete::MatchedStop;
use proto::DisaggregatedParams;
use rand::Rng;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use uuid::Uuid;
// ============================================================================
@@ -208,7 +213,7 @@ impl PreparationStage {
fn tokenize_single_text(
&self,
tokenizer: &Arc<dyn crate::tokenizer::traits::Tokenizer>,
tokenizer: &Arc<dyn Tokenizer>,
text: &str,
) -> Result<(String, Vec<u32>), String> {
let encoding = tokenizer
@@ -302,7 +307,7 @@ impl WorkerSelectionStage {
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
) -> Option<Arc<dyn Worker>> {
// Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered(
model_id,
@@ -312,7 +317,7 @@ impl WorkerSelectionStage {
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn crate::core::Worker>> = workers
let available: Vec<Arc<dyn Worker>> = workers
.iter()
.filter(|w| w.is_available())
.cloned()
@@ -337,7 +342,7 @@ impl WorkerSelectionStage {
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<(Arc<dyn crate::core::Worker>, Arc<dyn crate::core::Worker>)> {
) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
// Get prefill workers - use None for WorkerType filter to get all types,
// then filter manually (since Prefill is a struct variant)
let all_workers = self.worker_registry.get_workers_filtered(
@@ -537,10 +542,8 @@ impl RequestBuildingStage {
fn inject_bootstrap_metadata(
&self,
request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn crate::core::Worker>,
prefill_worker: &Arc<dyn Worker>,
) {
use proto::DisaggregatedParams;
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
@@ -935,40 +938,183 @@ impl ResponseProcessingStage {
/// Turns the execution result of a generate request into a response.
///
/// Streaming requests return `Ok(Some(..))` carrying whatever
/// `StreamingProcessor::process_streaming_generate` produces. Non-streaming
/// requests collect every backend response, run each one through the stop
/// decoder, build one `GenerateResponse` per completion, store the list in
/// `ctx.state.response.final_response` and return `Ok(None)`.
///
/// # Errors
///
/// Returns an internal-error response when the execution result, dispatch
/// metadata or stop decoder is missing from `ctx`, when collecting a stream
/// fails, when no responses were collected, or when the stop decoder fails to
/// process the output tokens.
//
// NOTE(review): this listing is a diff rendered without +/- markers. The
// `_ctx` parameter, the TODO block and the 501 `NOT_IMPLEMENTED` return appear
// to be the removed stub, interleaved with the added implementation; the
// comments added below describe the implementation lines only.
async fn process_generate_response(
&self,
_ctx: &mut RequestContext,
ctx: &mut RequestContext,
) -> Result<Option<Response>, Response> {
// TODO(generate): Implement generate response processing
//
// Required implementation:
// 1. Extract execution_result from ctx
// 2. Check is_streaming flag
// 3. For streaming:
// - Add StreamingProcessor::process_streaming_generate() method
// - Similar to process_streaming_response but WITHOUT tool/reasoning parsing
// - Return Err(sse_response) for early exit
// 4. For non-streaming:
// - Collect stream responses using utils::collect_stream_responses()
// - Process through stop decoder (sequential with reset for n>1, like chat)
// - Build GenerateResponse struct (see TODO in protocols/spec.rs)
// - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response))
//
// Reference implementation: router.rs:297-595
// Key differences from chat:
// - No tool parsing
// - No reasoning parsing
// - Different response format (GenerateResponse instead of ChatCompletionResponse)
// - Still needs: stop decoder, logprobs, finish_reason, matched_stop
Err((
axum::http::StatusCode::NOT_IMPLEMENTED,
axum::Json(serde_json::json!({
"error": {
"message": "Generate response processing not yet implemented in pipeline",
"type": "not_implemented",
"code": 501
// The timer starts on entry to this function, so the `e2e_latency` reported
// below is measured from here.
// NOTE(review): confirm whether it is meant to cover the whole request instead.
let start_time = Instant::now();
let is_streaming = ctx.is_streaming();
// Extract execution result
// `take()` moves the result out of the context, leaving `None` behind.
let execution_result = ctx
.state
.response
.execution_result
.take()
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
if is_streaming {
// Get dispatch metadata for consistent response fields
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// The request and dispatch metadata are cloned and handed to the streaming
// processor by value.
let generate_request = ctx.generate_request().clone();
// Streaming: Use StreamingProcessor and return SSE response (done)
return Ok(Some(
self.streaming_processor.clone().process_streaming_generate(
execution_result,
generate_request,
dispatch.clone(),
),
));
}
// Non-streaming: Collect all responses
let request_logprobs = ctx.generate_request().return_logprob;
let all_responses = match execution_result {
ExecutionResult::Single { stream } => {
utils::collect_stream_responses(stream, "Single").await?
}
ExecutionResult::Dual { prefill, decode } => {
// Collect prefill for input_logprobs
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
// Collect decode for actual output
let mut decode_responses =
utils::collect_stream_responses(*decode, "Decode").await?;
// Merge prefill input_logprobs if requested
// Only the first prefill response is consulted; its input_logprobs are
// copied onto every decode response.
if request_logprobs {
if let Some(prefill_input_logprobs) = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone())
{
for response in &mut decode_responses {
response.input_logprobs = Some(prefill_input_logprobs.clone());
}
}
}
})),
)
.into_response())
decode_responses
}
};
if all_responses.is_empty() {
return Err(utils::internal_error_static("No responses from server"));
}
// Get stop decoder for processing
let stop_decoder = ctx
.state
.response
.stop_decoder
.as_mut()
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
// Get dispatch metadata
let dispatch = ctx
.state
.dispatch
.as_ref()
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
// Process each completion (similar to router.rs:336-400)
let mut result_array = Vec::new();
for mut complete in all_responses {
// One decoder is shared by all completions, so it is reset before each one.
stop_decoder.reset();
// Process tokens through stop decoder
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(outputs) => outputs,
Err(e) => {
return Err(utils::internal_error_message(format!(
"Failed to process tokens: {}",
e
)))
}
};
// Accumulate text with early breaks
let mut decoded_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
decoded_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
// NOTE(review): flush() also runs after a Stopped/StoppedWithText break above;
// confirm it cannot emit text past the stop point.
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
decoded_text.push_str(&t);
}
// `mem::take` moves these fields out of `complete` without cloning, leaving
// default (empty) values behind.
let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
// Parse finish_reason from string to proper type
let finish_reason =
utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
// Handle matched_stop if present
let matched_stop = complete.matched_stop.take().map(|matched| match matched {
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
});
// Extract logprobs if requested (convert proto types to Generate format)
let input_token_logprobs = if request_logprobs {
complete
.input_logprobs
.as_ref()
.map(utils::convert_generate_input_logprobs)
} else {
None
};
let output_token_logprobs = if request_logprobs {
complete
.output_logprobs
.as_ref()
.map(utils::convert_generate_output_logprobs)
} else {
None
};
// Build GenerateResponse struct
// NOTE(review): the `as u32` casts below truncate silently if the proto
// counters are wider than 32 bits or negative — confirm the source types.
let meta_info = GenerateMetaInfo {
id: dispatch.request_id.clone(),
finish_reason,
prompt_tokens: complete.prompt_tokens as u32,
weight_version: dispatch
.weight_version
.clone()
.unwrap_or_else(|| "default".to_string()),
input_token_logprobs,
output_token_logprobs,
completion_tokens: complete.completion_tokens as u32,
cached_tokens: complete.cached_tokens as u32,
e2e_latency: start_time.elapsed().as_secs_f64(),
matched_stop,
};
result_array.push(GenerateResponse {
text: decoded_text,
output_ids,
meta_info,
});
}
// Store the final response
// `Ok(None)` is returned here, while the streaming branch above returns
// `Ok(Some(..))` with a finished response.
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
Ok(None)
}
}
@@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline {
// Extract final response
match ctx.state.response.final_response {
Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(),
Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
Some(FinalResponse::Chat(_)) => {
utils::internal_error_static("Internal error: wrong response type")
}