[router] refactor generate to use new pipeline arch (#11323)
This commit is contained in:
@@ -11,15 +11,20 @@ use super::context::*;
|
||||
use super::processing;
|
||||
use super::streaming;
|
||||
use super::utils;
|
||||
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
|
||||
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::grpc_client::proto;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage,
|
||||
ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest,
|
||||
GenerateResponse, InputIds, Usage,
|
||||
};
|
||||
use crate::tokenizer::stop::SequenceDecoderOutput;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use proto::generate_complete::MatchedStop;
|
||||
use proto::DisaggregatedParams;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||
use uuid::Uuid;
|
||||
|
||||
// ============================================================================
|
||||
@@ -208,7 +213,7 @@ impl PreparationStage {
|
||||
|
||||
fn tokenize_single_text(
|
||||
&self,
|
||||
tokenizer: &Arc<dyn crate::tokenizer::traits::Tokenizer>,
|
||||
tokenizer: &Arc<dyn Tokenizer>,
|
||||
text: &str,
|
||||
) -> Result<(String, Vec<u32>), String> {
|
||||
let encoding = tokenizer
|
||||
@@ -302,7 +307,7 @@ impl WorkerSelectionStage {
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<Arc<dyn crate::core::Worker>> {
|
||||
) -> Option<Arc<dyn Worker>> {
|
||||
// Get workers for the specified model, filtered by connection mode
|
||||
let workers = self.worker_registry.get_workers_filtered(
|
||||
model_id,
|
||||
@@ -312,7 +317,7 @@ impl WorkerSelectionStage {
|
||||
);
|
||||
|
||||
// Filter by availability (health + circuit breaker)
|
||||
let available: Vec<Arc<dyn crate::core::Worker>> = workers
|
||||
let available: Vec<Arc<dyn Worker>> = workers
|
||||
.iter()
|
||||
.filter(|w| w.is_available())
|
||||
.cloned()
|
||||
@@ -337,7 +342,7 @@ impl WorkerSelectionStage {
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<(Arc<dyn crate::core::Worker>, Arc<dyn crate::core::Worker>)> {
|
||||
) -> Option<(Arc<dyn Worker>, Arc<dyn Worker>)> {
|
||||
// Get prefill workers - use None for WorkerType filter to get all types,
|
||||
// then filter manually (since Prefill is a struct variant)
|
||||
let all_workers = self.worker_registry.get_workers_filtered(
|
||||
@@ -537,10 +542,8 @@ impl RequestBuildingStage {
|
||||
fn inject_bootstrap_metadata(
|
||||
&self,
|
||||
request: &mut proto::GenerateRequest,
|
||||
prefill_worker: &Arc<dyn crate::core::Worker>,
|
||||
prefill_worker: &Arc<dyn Worker>,
|
||||
) {
|
||||
use proto::DisaggregatedParams;
|
||||
|
||||
let hostname = prefill_worker.bootstrap_host();
|
||||
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
|
||||
|
||||
@@ -935,40 +938,183 @@ impl ResponseProcessingStage {
|
||||
|
||||
async fn process_generate_response(
|
||||
&self,
|
||||
_ctx: &mut RequestContext,
|
||||
ctx: &mut RequestContext,
|
||||
) -> Result<Option<Response>, Response> {
|
||||
// TODO(generate): Implement generate response processing
|
||||
//
|
||||
// Required implementation:
|
||||
// 1. Extract execution_result from ctx
|
||||
// 2. Check is_streaming flag
|
||||
// 3. For streaming:
|
||||
// - Add StreamingProcessor::process_streaming_generate() method
|
||||
// - Similar to process_streaming_response but WITHOUT tool/reasoning parsing
|
||||
// - Return Err(sse_response) for early exit
|
||||
// 4. For non-streaming:
|
||||
// - Collect stream responses using utils::collect_stream_responses()
|
||||
// - Process through stop decoder (sequential with reset for n>1, like chat)
|
||||
// - Build GenerateResponse struct (see TODO in protocols/spec.rs)
|
||||
// - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response))
|
||||
//
|
||||
// Reference implementation: router.rs:297-595
|
||||
// Key differences from chat:
|
||||
// - No tool parsing
|
||||
// - No reasoning parsing
|
||||
// - Different response format (GenerateResponse instead of ChatCompletionResponse)
|
||||
// - Still needs: stop decoder, logprobs, finish_reason, matched_stop
|
||||
Err((
|
||||
axum::http::StatusCode::NOT_IMPLEMENTED,
|
||||
axum::Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "Generate response processing not yet implemented in pipeline",
|
||||
"type": "not_implemented",
|
||||
"code": 501
|
||||
let start_time = Instant::now();
|
||||
let is_streaming = ctx.is_streaming();
|
||||
|
||||
// Extract execution result
|
||||
let execution_result = ctx
|
||||
.state
|
||||
.response
|
||||
.execution_result
|
||||
.take()
|
||||
.ok_or_else(|| utils::internal_error_static("No execution result"))?;
|
||||
|
||||
if is_streaming {
|
||||
// Get dispatch metadata for consistent response fields
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
let generate_request = ctx.generate_request().clone();
|
||||
|
||||
// Streaming: Use StreamingProcessor and return SSE response (done)
|
||||
return Ok(Some(
|
||||
self.streaming_processor.clone().process_streaming_generate(
|
||||
execution_result,
|
||||
generate_request,
|
||||
dispatch.clone(),
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Non-streaming: Collect all responses
|
||||
let request_logprobs = ctx.generate_request().return_logprob;
|
||||
let all_responses = match execution_result {
|
||||
ExecutionResult::Single { stream } => {
|
||||
utils::collect_stream_responses(stream, "Single").await?
|
||||
}
|
||||
ExecutionResult::Dual { prefill, decode } => {
|
||||
// Collect prefill for input_logprobs
|
||||
let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?;
|
||||
|
||||
// Collect decode for actual output
|
||||
let mut decode_responses =
|
||||
utils::collect_stream_responses(*decode, "Decode").await?;
|
||||
|
||||
// Merge prefill input_logprobs if requested
|
||||
if request_logprobs {
|
||||
if let Some(prefill_input_logprobs) = prefill_responses
|
||||
.first()
|
||||
.and_then(|r| r.input_logprobs.clone())
|
||||
{
|
||||
for response in &mut decode_responses {
|
||||
response.input_logprobs = Some(prefill_input_logprobs.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response())
|
||||
|
||||
decode_responses
|
||||
}
|
||||
};
|
||||
|
||||
if all_responses.is_empty() {
|
||||
return Err(utils::internal_error_static("No responses from server"));
|
||||
}
|
||||
|
||||
// Get stop decoder for processing
|
||||
let stop_decoder = ctx
|
||||
.state
|
||||
.response
|
||||
.stop_decoder
|
||||
.as_mut()
|
||||
.ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?;
|
||||
|
||||
// Get dispatch metadata
|
||||
let dispatch = ctx
|
||||
.state
|
||||
.dispatch
|
||||
.as_ref()
|
||||
.ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?;
|
||||
|
||||
// Process each completion (similar to router.rs:336-400)
|
||||
let mut result_array = Vec::new();
|
||||
for mut complete in all_responses {
|
||||
stop_decoder.reset();
|
||||
|
||||
// Process tokens through stop decoder
|
||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||
Ok(outputs) => outputs,
|
||||
Err(e) => {
|
||||
return Err(utils::internal_error_message(format!(
|
||||
"Failed to process tokens: {}",
|
||||
e
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// Accumulate text with early breaks
|
||||
let mut decoded_text = String::new();
|
||||
for output in outputs {
|
||||
match output {
|
||||
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
|
||||
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||
decoded_text.push_str(&t);
|
||||
break;
|
||||
}
|
||||
SequenceDecoderOutput::Stopped => break,
|
||||
SequenceDecoderOutput::Held => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush remaining text
|
||||
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||
decoded_text.push_str(&t);
|
||||
}
|
||||
|
||||
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||
let finish_reason_str = std::mem::take(&mut complete.finish_reason);
|
||||
|
||||
// Parse finish_reason from string to proper type
|
||||
let finish_reason =
|
||||
utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens);
|
||||
|
||||
// Handle matched_stop if present
|
||||
let matched_stop = complete.matched_stop.take().map(|matched| match matched {
|
||||
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
|
||||
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
|
||||
});
|
||||
|
||||
// Extract logprobs if requested (convert proto types to Generate format)
|
||||
let input_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.input_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_input_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let output_token_logprobs = if request_logprobs {
|
||||
complete
|
||||
.output_logprobs
|
||||
.as_ref()
|
||||
.map(utils::convert_generate_output_logprobs)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build GenerateResponse struct
|
||||
let meta_info = GenerateMetaInfo {
|
||||
id: dispatch.request_id.clone(),
|
||||
finish_reason,
|
||||
prompt_tokens: complete.prompt_tokens as u32,
|
||||
weight_version: dispatch
|
||||
.weight_version
|
||||
.clone()
|
||||
.unwrap_or_else(|| "default".to_string()),
|
||||
input_token_logprobs,
|
||||
output_token_logprobs,
|
||||
completion_tokens: complete.completion_tokens as u32,
|
||||
cached_tokens: complete.cached_tokens as u32,
|
||||
e2e_latency: start_time.elapsed().as_secs_f64(),
|
||||
matched_stop,
|
||||
};
|
||||
|
||||
result_array.push(GenerateResponse {
|
||||
text: decoded_text,
|
||||
output_ids,
|
||||
meta_info,
|
||||
});
|
||||
}
|
||||
|
||||
// Store the final response
|
||||
ctx.state.response.final_response = Some(FinalResponse::Generate(result_array));
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline {
|
||||
|
||||
// Extract final response
|
||||
match ctx.state.response.final_response {
|
||||
Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(),
|
||||
Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(),
|
||||
Some(FinalResponse::Chat(_)) => {
|
||||
utils::internal_error_static("Internal error: wrong response type")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user