[router] refactor generate to use new pipeline arch (#11323)
This commit is contained in:
@@ -5,7 +5,7 @@ use crate::core::Worker;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
|
||||
StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
|
||||
GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
|
||||
};
|
||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
@@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs(
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert proto::OutputLogProbs to Generate format Vec<Vec<Option<f64>>>
|
||||
///
|
||||
/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
|
||||
/// Each inner vec contains [logprob (f64), token_id (i32), ...]
|
||||
pub fn convert_generate_output_logprobs(
|
||||
proto_logprobs: &proto::OutputLogProbs,
|
||||
) -> Vec<Vec<Option<f64>>> {
|
||||
proto_logprobs
|
||||
.token_logprobs
|
||||
.iter()
|
||||
.zip(proto_logprobs.token_ids.iter())
|
||||
.map(|(&logprob, &token_id)| vec![Some(logprob as f64), Some(token_id as f64)])
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert proto::InputLogProbs to Generate format Vec<Vec<Option<f64>>>
|
||||
///
|
||||
/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...]
|
||||
/// First token has null logprob: [[null, token_id], [logprob, token_id], ...]
|
||||
pub fn convert_generate_input_logprobs(
|
||||
proto_logprobs: &proto::InputLogProbs,
|
||||
) -> Vec<Vec<Option<f64>>> {
|
||||
proto_logprobs
|
||||
.token_logprobs
|
||||
.iter()
|
||||
.zip(proto_logprobs.token_ids.iter())
|
||||
.map(|(token_logprob, &token_id)| {
|
||||
// InputTokenLogProb has optional value field
|
||||
let logprob_value = token_logprob.value.map(|v| v as f64);
|
||||
vec![logprob_value, Some(token_id as f64)]
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse finish_reason string into GenerateFinishReason enum
|
||||
///
|
||||
/// Uses serde to deserialize the finish_reason, which handles all tagged variants automatically.
|
||||
/// The GenerateFinishReason enum is tagged with `#[serde(tag = "type", rename_all = "lowercase")]`,
|
||||
/// so it expects JSON objects like:
|
||||
/// - `{"type":"stop"}` -> Stop
|
||||
/// - `{"type":"length","length":100}` -> Length { length: 100 }
|
||||
/// - Any other JSON -> Other(...)
|
||||
///
|
||||
/// For backward compatibility, also handles simple string "stop" -> Stop
|
||||
pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> GenerateFinishReason {
|
||||
if reason_str == "stop" {
|
||||
return GenerateFinishReason::Stop;
|
||||
}
|
||||
|
||||
if reason_str == "length" {
|
||||
return GenerateFinishReason::Length {
|
||||
length: completion_tokens.max(0) as u32,
|
||||
};
|
||||
}
|
||||
|
||||
match serde_json::from_str::<GenerateFinishReason>(reason_str) {
|
||||
Ok(finish_reason) => finish_reason,
|
||||
Err(_) => match serde_json::from_str::<Value>(reason_str) {
|
||||
Ok(json_value) => GenerateFinishReason::Other(json_value),
|
||||
Err(_) => GenerateFinishReason::Other(Value::String(reason_str.to_string())),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user