[router][grpc] Add helpfer functions for decoder in router.rs and fix specs (#10971)
This commit is contained in:
@@ -36,9 +36,9 @@ message SamplingParams {
|
|||||||
float presence_penalty = 6;
|
float presence_penalty = 6;
|
||||||
float repetition_penalty = 7;
|
float repetition_penalty = 7;
|
||||||
|
|
||||||
int32 max_new_tokens = 8;
|
optional int32 max_new_tokens = 8;
|
||||||
repeated string stop = 9;
|
repeated string stop = 9;
|
||||||
repeated int32 stop_token_ids = 10;
|
repeated uint32 stop_token_ids = 10;
|
||||||
bool skip_special_tokens = 11;
|
bool skip_special_tokens = 11;
|
||||||
bool spaces_between_special_tokens = 12;
|
bool spaces_between_special_tokens = 12;
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ message GenerateRequest {
|
|||||||
bool return_logprob = 5;
|
bool return_logprob = 5;
|
||||||
int32 logprob_start_len = 6;
|
int32 logprob_start_len = 6;
|
||||||
int32 top_logprobs_num = 7;
|
int32 top_logprobs_num = 7;
|
||||||
repeated int32 token_ids_logprob = 8;
|
repeated uint32 token_ids_logprob = 8;
|
||||||
bool return_hidden_states = 9;
|
bool return_hidden_states = 9;
|
||||||
|
|
||||||
// For disaggregated serving
|
// For disaggregated serving
|
||||||
@@ -129,7 +129,7 @@ message GenerateRequest {
|
|||||||
|
|
||||||
message TokenizedInput {
|
message TokenizedInput {
|
||||||
string original_text = 1; // For reference
|
string original_text = 1; // For reference
|
||||||
repeated int32 input_ids = 2;
|
repeated uint32 input_ids = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message MultimodalInputs {
|
message MultimodalInputs {
|
||||||
@@ -167,7 +167,7 @@ message GenerateResponse {
|
|||||||
|
|
||||||
message GenerateStreamChunk {
|
message GenerateStreamChunk {
|
||||||
// Generated tokens (incremental chunk)
|
// Generated tokens (incremental chunk)
|
||||||
repeated int32 token_ids = 1;
|
repeated uint32 token_ids = 1;
|
||||||
|
|
||||||
// Cumulative counts
|
// Cumulative counts
|
||||||
int32 prompt_tokens = 2;
|
int32 prompt_tokens = 2;
|
||||||
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
|
|||||||
|
|
||||||
message GenerateComplete {
|
message GenerateComplete {
|
||||||
// Final output
|
// Final output
|
||||||
repeated int32 output_ids = 1;
|
repeated uint32 output_ids = 1;
|
||||||
|
|
||||||
// Finish reason
|
// Finish reason
|
||||||
enum FinishReason {
|
enum FinishReason {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -20,7 +20,7 @@ pub struct SglangSchedulerClient {
|
|||||||
|
|
||||||
impl SglangSchedulerClient {
|
impl SglangSchedulerClient {
|
||||||
/// Create a new client and connect to the scheduler
|
/// Create a new client and connect to the scheduler
|
||||||
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
||||||
|
|
||||||
// Convert grpc:// to http:// for tonic
|
// Convert grpc:// to http:// for tonic
|
||||||
@@ -41,10 +41,11 @@ impl SglangSchedulerClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Submit a generation request (returns streaming response)
|
/// Submit a generation request (returns streaming response)
|
||||||
pub async fn generate_stream(
|
pub async fn generate(
|
||||||
&mut self,
|
&mut self,
|
||||||
req: proto::GenerateRequest,
|
req: proto::GenerateRequest,
|
||||||
) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error>> {
|
) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error + Send + Sync>>
|
||||||
|
{
|
||||||
let request = Request::new(req);
|
let request = Request::new(req);
|
||||||
let response = self.client.generate(request).await?;
|
let response = self.client.generate(request).await?;
|
||||||
Ok(response.into_inner())
|
Ok(response.into_inner())
|
||||||
@@ -53,7 +54,7 @@ impl SglangSchedulerClient {
|
|||||||
/// Perform health check
|
/// Perform health check
|
||||||
pub async fn health_check(
|
pub async fn health_check(
|
||||||
&mut self,
|
&mut self,
|
||||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
debug!("Sending health check request");
|
debug!("Sending health check request");
|
||||||
let request = Request::new(proto::HealthCheckRequest {
|
let request = Request::new(proto::HealthCheckRequest {
|
||||||
tokenized: Some(proto::TokenizedInput {
|
tokenized: Some(proto::TokenizedInput {
|
||||||
@@ -72,7 +73,7 @@ impl SglangSchedulerClient {
|
|||||||
&mut self,
|
&mut self,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
reason: String,
|
reason: String,
|
||||||
) -> Result<(), Box<dyn std::error::Error>> {
|
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let request = Request::new(proto::AbortRequest { request_id, reason });
|
let request = Request::new(proto::AbortRequest { request_id, reason });
|
||||||
|
|
||||||
self.client.abort(request).await?;
|
self.client.abort(request).await?;
|
||||||
@@ -85,7 +86,7 @@ impl SglangSchedulerClient {
|
|||||||
request_id: String,
|
request_id: String,
|
||||||
body: &ChatCompletionRequest,
|
body: &ChatCompletionRequest,
|
||||||
processed_text: String,
|
processed_text: String,
|
||||||
token_ids: Vec<i32>,
|
token_ids: Vec<u32>,
|
||||||
multimodal_inputs: Option<proto::MultimodalInputs>,
|
multimodal_inputs: Option<proto::MultimodalInputs>,
|
||||||
tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value)
|
tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value)
|
||||||
) -> Result<proto::GenerateRequest, String> {
|
) -> Result<proto::GenerateRequest, String> {
|
||||||
@@ -153,6 +154,8 @@ impl SglangSchedulerClient {
|
|||||||
stop: stop_sequences,
|
stop: stop_sequences,
|
||||||
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
||||||
skip_special_tokens,
|
skip_special_tokens,
|
||||||
|
ignore_eos: request.ignore_eos,
|
||||||
|
no_stop_trim: request.no_stop_trim,
|
||||||
n: request.n.unwrap_or(1) as i32,
|
n: request.n.unwrap_or(1) as i32,
|
||||||
constraint: self.build_constraint(request, tool_call_constraint)?,
|
constraint: self.build_constraint(request, tool_call_constraint)?,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ message SamplingParams {
|
|||||||
|
|
||||||
optional int32 max_new_tokens = 8;
|
optional int32 max_new_tokens = 8;
|
||||||
repeated string stop = 9;
|
repeated string stop = 9;
|
||||||
repeated int32 stop_token_ids = 10;
|
repeated uint32 stop_token_ids = 10;
|
||||||
bool skip_special_tokens = 11;
|
bool skip_special_tokens = 11;
|
||||||
bool spaces_between_special_tokens = 12;
|
bool spaces_between_special_tokens = 12;
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ message GenerateRequest {
|
|||||||
bool return_logprob = 5;
|
bool return_logprob = 5;
|
||||||
int32 logprob_start_len = 6;
|
int32 logprob_start_len = 6;
|
||||||
int32 top_logprobs_num = 7;
|
int32 top_logprobs_num = 7;
|
||||||
repeated int32 token_ids_logprob = 8;
|
repeated uint32 token_ids_logprob = 8;
|
||||||
bool return_hidden_states = 9;
|
bool return_hidden_states = 9;
|
||||||
|
|
||||||
// For disaggregated serving
|
// For disaggregated serving
|
||||||
@@ -129,7 +129,7 @@ message GenerateRequest {
|
|||||||
|
|
||||||
message TokenizedInput {
|
message TokenizedInput {
|
||||||
string original_text = 1; // For reference
|
string original_text = 1; // For reference
|
||||||
repeated int32 input_ids = 2;
|
repeated uint32 input_ids = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message MultimodalInputs {
|
message MultimodalInputs {
|
||||||
@@ -167,7 +167,7 @@ message GenerateResponse {
|
|||||||
|
|
||||||
message GenerateStreamChunk {
|
message GenerateStreamChunk {
|
||||||
// Generated tokens (incremental chunk)
|
// Generated tokens (incremental chunk)
|
||||||
repeated int32 token_ids = 1;
|
repeated uint32 token_ids = 1;
|
||||||
|
|
||||||
// Cumulative counts
|
// Cumulative counts
|
||||||
int32 prompt_tokens = 2;
|
int32 prompt_tokens = 2;
|
||||||
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
|
|||||||
|
|
||||||
message GenerateComplete {
|
message GenerateComplete {
|
||||||
// Final output
|
// Final output
|
||||||
repeated int32 output_ids = 1;
|
repeated uint32 output_ids = 1;
|
||||||
|
|
||||||
// Finish reason
|
// Finish reason
|
||||||
enum FinishReason {
|
enum FinishReason {
|
||||||
|
|||||||
@@ -313,7 +313,7 @@ pub struct ChatCompletionRequest {
|
|||||||
|
|
||||||
/// Specific token IDs to use as stop conditions
|
/// Specific token IDs to use as stop conditions
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
pub stop_token_ids: Option<Vec<u32>>,
|
||||||
|
|
||||||
/// Skip trimming stop tokens from output
|
/// Skip trimming stop tokens from output
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -564,7 +564,7 @@ pub struct CompletionRequest {
|
|||||||
|
|
||||||
/// Specific token IDs to use as stop conditions
|
/// Specific token IDs to use as stop conditions
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
pub stop_token_ids: Option<Vec<u32>>,
|
||||||
|
|
||||||
/// Skip trimming stop tokens from output
|
/// Skip trimming stop tokens from output
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -1864,7 +1864,7 @@ pub struct SamplingParams {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub min_tokens: Option<u32>,
|
pub min_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
pub stop_token_ids: Option<Vec<u32>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub no_stop_trim: Option<bool>,
|
pub no_stop_trim: Option<bool>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
|||||||
@@ -17,19 +17,20 @@ use crate::grpc_client::{proto, SglangSchedulerClient};
|
|||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::PolicyRegistry;
|
use crate::policies::PolicyRegistry;
|
||||||
use crate::protocols::spec::ChatMessage;
|
use crate::protocols::spec::ChatMessage;
|
||||||
use crate::protocols::spec::{ChatCompletionRequest, StringOrArray};
|
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||||
ResponsesRequest, Tool, ToolChoice,
|
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice,
|
||||||
};
|
};
|
||||||
use crate::reasoning_parser::ParserFactory;
|
use crate::reasoning_parser::ParserFactory;
|
||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
use crate::server::AppContext;
|
use crate::server::AppContext;
|
||||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||||
|
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};
|
||||||
use crate::tokenizer::traits::Tokenizer;
|
use crate::tokenizer::traits::Tokenizer;
|
||||||
use crate::tokenizer::HuggingFaceTokenizer;
|
use crate::tokenizer::HuggingFaceTokenizer;
|
||||||
use crate::tool_parser::ParserRegistry;
|
use crate::tool_parser::ParserRegistry;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use tokio_stream::StreamExt;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
// Data structures for processing
|
// Data structures for processing
|
||||||
@@ -182,7 +183,7 @@ impl GrpcRouter {
|
|||||||
request_id,
|
request_id,
|
||||||
body,
|
body,
|
||||||
processed_messages.text.clone(),
|
processed_messages.text.clone(),
|
||||||
token_ids.into_iter().map(|id| id as i32).collect(),
|
token_ids,
|
||||||
processed_messages.multimodal_inputs,
|
processed_messages.multimodal_inputs,
|
||||||
tool_call_constraint, // Pass the full tuple (type, value)
|
tool_call_constraint, // Pass the full tuple (type, value)
|
||||||
) {
|
) {
|
||||||
@@ -479,28 +480,225 @@ impl GrpcRouter {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Placeholder for streaming handler (to be implemented in Phase 2)
|
/// Create a StopSequenceDecoder from the chat completion request
|
||||||
async fn handle_streaming_chat(
|
fn create_stop_decoder(
|
||||||
&self,
|
&self,
|
||||||
_client: SglangSchedulerClient,
|
original_request: &ChatCompletionRequest,
|
||||||
_request: proto::GenerateRequest,
|
) -> crate::tokenizer::stop::StopSequenceDecoder {
|
||||||
_original_request: &ChatCompletionRequest,
|
// Extract stop sequences from request
|
||||||
) -> Response {
|
let stop_sequences: Vec<String> = match &original_request.stop {
|
||||||
(StatusCode::NOT_IMPLEMENTED, "Streaming not yet implemented").into_response()
|
Some(StringOrArray::String(s)) => vec![s.clone()],
|
||||||
|
Some(StringOrArray::Array(arr)) => arr.clone(),
|
||||||
|
None => vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build stop sequence decoder
|
||||||
|
let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone())
|
||||||
|
.skip_special_tokens(original_request.skip_special_tokens);
|
||||||
|
|
||||||
|
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
|
||||||
|
for seq in stop_sequences {
|
||||||
|
builder = if original_request.no_stop_trim {
|
||||||
|
builder.visible_stop_sequence(seq)
|
||||||
|
} else {
|
||||||
|
builder.stop_sequence(seq)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
|
||||||
|
if let Some(stop_token_ids) = &original_request.stop_token_ids {
|
||||||
|
for &token_id in stop_token_ids {
|
||||||
|
builder = if original_request.no_stop_trim {
|
||||||
|
builder.visible_stop_token(token_id)
|
||||||
|
} else {
|
||||||
|
builder.stop_token(token_id)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
builder.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Placeholder for non-streaming handler (to be implemented in Phase 3)
|
/// Process a chunk of tokens through the stop decoder
|
||||||
|
fn process_chunk_tokens(
|
||||||
|
stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
|
||||||
|
token_ids: &[u32],
|
||||||
|
) -> (String, bool) {
|
||||||
|
let mut chunk_text = String::new();
|
||||||
|
|
||||||
|
for &token_id in token_ids {
|
||||||
|
match stop_decoder.process_token(token_id).unwrap_or_else(|e| {
|
||||||
|
debug!(
|
||||||
|
"Error processing token {}: {}. Treating as Held.",
|
||||||
|
token_id, e
|
||||||
|
);
|
||||||
|
SequenceDecoderOutput::Held
|
||||||
|
}) {
|
||||||
|
SequenceDecoderOutput::Text(text) => {
|
||||||
|
chunk_text.push_str(&text);
|
||||||
|
}
|
||||||
|
SequenceDecoderOutput::StoppedWithText(text) => {
|
||||||
|
chunk_text.push_str(&text);
|
||||||
|
return (chunk_text, true); // Return text and signal to stop
|
||||||
|
}
|
||||||
|
SequenceDecoderOutput::Stopped => {
|
||||||
|
return (chunk_text, true); // Return text and signal to stop
|
||||||
|
}
|
||||||
|
SequenceDecoderOutput::Held => {
|
||||||
|
// Text held for potential stop sequence match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(chunk_text, false) // Return text and continue processing
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Submit request and handle streaming response for chat completions route
|
||||||
|
async fn handle_streaming_chat(
|
||||||
|
&self,
|
||||||
|
mut client: SglangSchedulerClient,
|
||||||
|
request: proto::GenerateRequest,
|
||||||
|
original_request: &ChatCompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
let mut stop_decoder = self.create_stop_decoder(original_request);
|
||||||
|
|
||||||
|
// Process streaming tokens
|
||||||
|
let mut grpc_stream = match client.generate(request).await {
|
||||||
|
Ok(stream) => stream,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to start generation: {}", e);
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Generation failed: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut decoded_text = String::new();
|
||||||
|
|
||||||
|
while let Some(response) = grpc_stream.next().await {
|
||||||
|
let gen_response = match response {
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Stream error: {}", e);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match gen_response.response {
|
||||||
|
Some(proto::generate_response::Response::Chunk(chunk)) => {
|
||||||
|
// Process tokens and check if we should stop
|
||||||
|
let (chunk_text, should_stop) =
|
||||||
|
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
|
||||||
|
decoded_text.push_str(&chunk_text);
|
||||||
|
if should_stop {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Some(proto::generate_response::Response::Complete(_complete)) => {
|
||||||
|
// Flush any remaining text
|
||||||
|
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
|
||||||
|
if !text.is_empty() {
|
||||||
|
decoded_text.push_str(&text);
|
||||||
|
debug!("Flushed text: {}", text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Some(proto::generate_response::Response::Error(error)) => {
|
||||||
|
error!("Generation error: {}", error.message);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
None => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Replace with proper SSE streaming response
|
||||||
|
// For now, return the complete decoded text
|
||||||
|
(StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Submit request and handle non-streaming response for chat completions route
|
||||||
async fn handle_non_streaming_chat(
|
async fn handle_non_streaming_chat(
|
||||||
&self,
|
&self,
|
||||||
_client: SglangSchedulerClient,
|
mut client: SglangSchedulerClient,
|
||||||
_request: proto::GenerateRequest,
|
request: proto::GenerateRequest,
|
||||||
_original_request: &ChatCompletionRequest,
|
original_request: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
(
|
let mut stop_decoder = self.create_stop_decoder(original_request);
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
|
||||||
"Non-streaming not yet implemented",
|
// Small helpers to log + return a uniform 500
|
||||||
)
|
let fail_str = |msg: &'static str| -> Response {
|
||||||
.into_response()
|
error!("{}", msg);
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
|
||||||
|
};
|
||||||
|
let fail_fmt = |prefix: &str, e: &dyn std::fmt::Display| -> Response {
|
||||||
|
error!("{}{}", prefix, e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("{}{}", prefix, e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start generation
|
||||||
|
let mut stream = match client.generate(request).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => return fail_fmt("Failed to start generation: ", &e),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get the single Complete response
|
||||||
|
let gen_response = match stream.next().await {
|
||||||
|
Some(Ok(r)) => r,
|
||||||
|
Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e),
|
||||||
|
None => return fail_str("No response from server"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract the expected variant early
|
||||||
|
let complete = match gen_response.response {
|
||||||
|
Some(proto::generate_response::Response::Complete(c)) => c,
|
||||||
|
Some(proto::generate_response::Response::Error(err)) => {
|
||||||
|
error!("Generation failed: {}", err.message);
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Generation failed: {}", err.message),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
Some(proto::generate_response::Response::Chunk(_)) => {
|
||||||
|
return fail_str("Unexpected chunk response for non-streaming request")
|
||||||
|
}
|
||||||
|
None => return fail_str("Empty response from server"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Decode tokens
|
||||||
|
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
||||||
|
Ok(o) => o,
|
||||||
|
Err(e) => return fail_fmt("Failed to process tokens: ", &e),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Accumulate text with early breaks
|
||||||
|
let mut final_text = String::new();
|
||||||
|
for output in outputs {
|
||||||
|
match output {
|
||||||
|
SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
|
||||||
|
SequenceDecoderOutput::StoppedWithText(t) => {
|
||||||
|
final_text.push_str(&t);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
SequenceDecoderOutput::Stopped => break,
|
||||||
|
SequenceDecoderOutput::Held => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush remaining text
|
||||||
|
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
|
||||||
|
final_text.push_str(&t);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Create proper OpenAI-compatible response
|
||||||
|
(StatusCode::OK, format!("Final text: {}", final_text)).into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user