[router][grpc] Add logprobs support to router (#11082)
This commit is contained in:
@@ -174,11 +174,14 @@ message GenerateStreamChunk {
|
||||
int32 completion_tokens = 3;
|
||||
int32 cached_tokens = 4;
|
||||
|
||||
// Logprobs (if requested)
|
||||
LogProbs logprobs = 5;
|
||||
// Output logprobs (if requested) - incremental for streaming
|
||||
LogProbs output_logprobs = 5;
|
||||
|
||||
// Hidden states (if requested)
|
||||
repeated float hidden_states = 6;
|
||||
|
||||
// Input logprobs (if requested) - only in first chunk
|
||||
LogProbs input_logprobs = 7;
|
||||
}
|
||||
|
||||
message GenerateComplete {
|
||||
@@ -193,8 +196,8 @@ message GenerateComplete {
|
||||
int32 completion_tokens = 4;
|
||||
int32 cached_tokens = 5;
|
||||
|
||||
// All logprobs if requested
|
||||
repeated LogProbs all_logprobs = 6;
|
||||
// Output logprobs if requested (cumulative)
|
||||
LogProbs output_logprobs = 6;
|
||||
|
||||
// All hidden states if requested
|
||||
repeated HiddenStates all_hidden_states = 7;
|
||||
@@ -204,6 +207,9 @@ message GenerateComplete {
|
||||
uint32 matched_token_id = 8;
|
||||
string matched_stop_str = 9;
|
||||
}
|
||||
|
||||
// Input logprobs if requested (for prompt tokens)
|
||||
LogProbs input_logprobs = 10;
|
||||
}
|
||||
|
||||
message GenerateError {
|
||||
@@ -218,15 +224,11 @@ message LogProbs {
|
||||
|
||||
// Top logprobs at each position
|
||||
repeated TopLogProbs top_logprobs = 3;
|
||||
|
||||
// Decoded text for tokens
|
||||
repeated string token_texts = 4;
|
||||
}
|
||||
|
||||
message TopLogProbs {
|
||||
repeated float values = 1;
|
||||
repeated int32 token_ids = 2;
|
||||
repeated string token_texts = 3;
|
||||
}
|
||||
|
||||
message HiddenStates {
|
||||
|
||||
@@ -730,6 +730,73 @@ impl GrpcRouter {
|
||||
Json(response).into_response()
|
||||
}
|
||||
|
||||
/// Convert proto LogProbs to OpenAI ChatLogProbs format
|
||||
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
|
||||
fn convert_proto_to_openai_logprobs(
|
||||
&self,
|
||||
proto_logprobs: &proto::LogProbs,
|
||||
) -> Result<crate::protocols::spec::ChatLogProbs, String> {
|
||||
let mut content_items = Vec::new();
|
||||
|
||||
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
|
||||
let token_texts: Vec<String> = proto_logprobs
|
||||
.token_ids
|
||||
.iter()
|
||||
.map(|&token_id| {
|
||||
self.tokenizer
|
||||
.decode(&[token_id as u32], false)
|
||||
.unwrap_or_else(|_| format!("<token_{}>", token_id))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Build ChatLogProbsContent for each token
|
||||
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
|
||||
let token_text = token_texts.get(i).cloned().unwrap_or_default();
|
||||
let bytes = Some(token_text.as_bytes().to_vec());
|
||||
|
||||
// Build top_logprobs for this position
|
||||
let mut top_logprobs = Vec::new();
|
||||
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
|
||||
// Decode top token IDs (always with skip_special_tokens=false)
|
||||
let top_token_texts: Vec<String> = top_logprobs_entry
|
||||
.token_ids
|
||||
.iter()
|
||||
.map(|&tid| {
|
||||
self.tokenizer
|
||||
.decode(&[tid as u32], false)
|
||||
.unwrap_or_else(|_| format!("<token_{}>", tid))
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
|
||||
.values
|
||||
.iter()
|
||||
.zip(top_logprobs_entry.token_ids.iter())
|
||||
.enumerate()
|
||||
{
|
||||
if let Some(top_token_text) = top_token_texts.get(j) {
|
||||
top_logprobs.push(crate::protocols::spec::TopLogProb {
|
||||
token: top_token_text.clone(),
|
||||
logprob: top_logprob,
|
||||
bytes: Some(top_token_text.as_bytes().to_vec()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
content_items.push(crate::protocols::spec::ChatLogProbsContent {
|
||||
token: token_text,
|
||||
logprob,
|
||||
bytes,
|
||||
top_logprobs,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(crate::protocols::spec::ChatLogProbs::Detailed {
|
||||
content: (!content_items.is_empty()).then_some(content_items),
|
||||
})
|
||||
}
|
||||
|
||||
/// Process a single GenerateComplete response into a ChatChoice
|
||||
async fn process_single_choice(
|
||||
&self,
|
||||
@@ -855,7 +922,22 @@ impl GrpcRouter {
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Step 4: Build ChatCompletionMessage (proper response message type)
|
||||
// Step 4: Convert output logprobs if present
|
||||
// Note: complete.input_logprobs exists in proto but is not used for chat completions
|
||||
// (input logprobs are only used in /v1/completions endpoint with echo=true)
|
||||
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
|
||||
match self.convert_proto_to_openai_logprobs(proto_logprobs) {
|
||||
Ok(logprobs) => Some(logprobs),
|
||||
Err(e) => {
|
||||
error!("Failed to convert logprobs: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Step 5: Build ChatCompletionMessage (proper response message type)
|
||||
let chat_message = ChatCompletionMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: if processed_text.is_empty() {
|
||||
@@ -867,11 +949,11 @@ impl GrpcRouter {
|
||||
reasoning_content: reasoning_text,
|
||||
};
|
||||
|
||||
// Step 5: Build ChatChoice
|
||||
// Step 6: Build ChatChoice
|
||||
let choice = ChatChoice {
|
||||
index: index as u32,
|
||||
message: chat_message,
|
||||
logprobs: None,
|
||||
logprobs,
|
||||
finish_reason: Some(final_finish_reason_str.to_string()),
|
||||
matched_stop,
|
||||
hidden_states: None,
|
||||
|
||||
Reference in New Issue
Block a user