[router] refactor generate to use new pipeline arch (#11323)

This commit is contained in:
Simo Lin
2025-10-08 12:38:50 -04:00
committed by GitHub
parent d6837aea4d
commit 01c9ee1ab4
7 changed files with 713 additions and 1181 deletions

View File

@@ -5,7 +5,7 @@ use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
};
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer;
@@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs(
})
}
/// Build the Generate-API logprob rows from a `proto::OutputLogProbs` message.
///
/// Every row has the shape `[logprob, token_id]`, with both entries widened to
/// `f64` and wrapped in `Some`. Rows are produced by pairing `token_logprobs`
/// with `token_ids` position by position; if the two lists differ in length,
/// the extra trailing entries of the longer one are dropped.
pub fn convert_generate_output_logprobs(
    proto_logprobs: &proto::OutputLogProbs,
) -> Vec<Vec<Option<f64>>> {
    let pairs = proto_logprobs
        .token_logprobs
        .iter()
        .zip(proto_logprobs.token_ids.iter());

    // Reserve up front using the iterator's lower bound on remaining pairs.
    let mut rows = Vec::with_capacity(pairs.size_hint().0);
    for (logprob, token_id) in pairs {
        rows.push(vec![Some(*logprob as f64), Some(*token_id as f64)]);
    }
    rows
}
/// Build the Generate-API logprob rows from a `proto::InputLogProbs` message.
///
/// Every row has the shape `[logprob, token_id]`. The logprob slot is `None`
/// whenever the corresponding entry carries no `value` (the Generate format
/// expects this for the first prompt token: `[[null, token_id], ...]`);
/// otherwise it is the value widened to `f64`. The token id is always present.
///
/// Entries are paired position by position; if `token_logprobs` and
/// `token_ids` differ in length, the surplus tail of the longer one is dropped.
pub fn convert_generate_input_logprobs(
    proto_logprobs: &proto::InputLogProbs,
) -> Vec<Vec<Option<f64>>> {
    let pairs = proto_logprobs
        .token_logprobs
        .iter()
        .zip(proto_logprobs.token_ids.iter());

    // Reserve up front using the iterator's lower bound on remaining pairs.
    let mut rows = Vec::with_capacity(pairs.size_hint().0);
    for (entry, token_id) in pairs {
        // `value` is optional on the proto entry; keep the absence as `None`.
        let logprob = entry.value.map(|raw| raw as f64);
        rows.push(vec![logprob, Some(*token_id as f64)]);
    }
    rows
}
/// Turn a backend `finish_reason` string into a [`GenerateFinishReason`].
///
/// Resolution order:
/// 1. The bare string `stop` maps to `Stop`.
/// 2. The bare string `length` maps to `Length`, with `length` taken from
///    `completion_tokens` (negative counts are clamped to `0`).
/// 3. Otherwise the string is deserialized as JSON directly into
///    `GenerateFinishReason`, e.g. `{"type":"stop"}` or
///    `{"type":"length","length":100}`.
/// 4. If that fails but the string is still valid JSON, the parsed value is
///    wrapped in `Other`.
/// 5. If it is not JSON at all, the raw text is wrapped in `Other` as a JSON
///    string.
///
/// This function never fails; unrecognized input always ends up in `Other`.
pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> GenerateFinishReason {
    match reason_str {
        // Plain-string forms, handled without going through serde.
        "stop" => GenerateFinishReason::Stop,
        "length" => GenerateFinishReason::Length {
            // Clamped to zero first, so the cast to `u32` cannot wrap.
            length: completion_tokens.max(0) as u32,
        },
        raw => serde_json::from_str::<GenerateFinishReason>(raw).unwrap_or_else(|_| {
            // Not a recognized finish reason: preserve whatever was sent,
            // as structured JSON when possible, else as the literal text.
            let payload = serde_json::from_str::<Value>(raw)
                .unwrap_or_else(|_| Value::String(raw.to_string()));
            GenerateFinishReason::Other(payload)
        }),
    }
}
#[cfg(test)]
mod tests {
use super::*;