[router][grpc] Support E2E non-stream chat completions (#10980)
This commit is contained in:
@@ -13,7 +13,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
import grpc
|
import grpc
|
||||||
import zmq
|
import zmq
|
||||||
@@ -156,7 +156,7 @@ class GrpcRequestManager:
|
|||||||
obj: TokenizedGenerateReqInput,
|
obj: TokenizedGenerateReqInput,
|
||||||
request_id: Optional[str] = None,
|
request_id: Optional[str] = None,
|
||||||
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
grpc_context: Optional[grpc.aio.ServicerContext] = None,
|
||||||
):
|
) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
|
||||||
"""
|
"""
|
||||||
Submit a generation request to the scheduler with n>1 parallel sampling support.
|
Submit a generation request to the scheduler with n>1 parallel sampling support.
|
||||||
|
|
||||||
|
|||||||
@@ -321,14 +321,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
logger.info(f"Sending health check request to request manager...")
|
logger.info(f"Sending health check request to request manager...")
|
||||||
|
|
||||||
# Submit and wait for response
|
# Submit and wait for response
|
||||||
output_queue = await self.request_manager.generate_request(
|
output_generator = self.request_manager.generate_request(
|
||||||
health_request, request_id=rid
|
health_request, request_id=rid
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wait for response with configurable timeout
|
# Get first response with timeout
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
|
output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
|
||||||
)
|
)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
@@ -492,13 +492,32 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
) -> sglang_scheduler_pb2.GenerateResponse:
|
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||||
"""Create a completion response."""
|
"""Create a completion response."""
|
||||||
|
|
||||||
# Determine finish reason
|
# Extract meta info and finish reason details
|
||||||
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
|
|
||||||
meta_info = output.get("meta_info", {})
|
meta_info = output.get("meta_info", {})
|
||||||
if meta_info.get("finish_reason") == "length":
|
finish_reason_data = meta_info.get("finish_reason")
|
||||||
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
|
|
||||||
elif meta_info.get("finish_reason") == "eos_token":
|
# Determine finish reason, default is stop
|
||||||
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
|
finish_reason = "stop"
|
||||||
|
if finish_reason_data:
|
||||||
|
if isinstance(finish_reason_data, dict):
|
||||||
|
finish_reason_type = finish_reason_data.get("type")
|
||||||
|
else:
|
||||||
|
# Handle legacy string format
|
||||||
|
finish_reason_type = finish_reason_data
|
||||||
|
|
||||||
|
if finish_reason_type == "length":
|
||||||
|
finish_reason = "length"
|
||||||
|
elif finish_reason_type == "abort":
|
||||||
|
finish_reason = "abort"
|
||||||
|
|
||||||
|
# Extract matched_stop information
|
||||||
|
matched_stop_kwargs = {}
|
||||||
|
if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
|
||||||
|
matched = finish_reason_data["matched"]
|
||||||
|
if isinstance(matched, int):
|
||||||
|
matched_stop_kwargs["matched_token_id"] = matched
|
||||||
|
elif isinstance(matched, str):
|
||||||
|
matched_stop_kwargs["matched_stop_str"] = matched
|
||||||
|
|
||||||
return sglang_scheduler_pb2.GenerateResponse(
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
@@ -510,6 +529,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
"completion_tokens", len(output.get("token_ids", []))
|
"completion_tokens", len(output.get("token_ids", []))
|
||||||
),
|
),
|
||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||||
|
**matched_stop_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -185,20 +185,8 @@ message GenerateComplete {
|
|||||||
// Final output
|
// Final output
|
||||||
repeated uint32 output_ids = 1;
|
repeated uint32 output_ids = 1;
|
||||||
|
|
||||||
// Finish reason
|
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
|
||||||
enum FinishReason {
|
string finish_reason = 2;
|
||||||
// The model generated a stop sequence.
|
|
||||||
STOP = 0;
|
|
||||||
// The model reached the maximum generation length.
|
|
||||||
LENGTH = 1;
|
|
||||||
// The model generated an end-of-sequence (EOS) token.
|
|
||||||
EOS_TOKEN = 2;
|
|
||||||
// The model generated a user-provided stop string.
|
|
||||||
STOP_STR = 3;
|
|
||||||
// The request was aborted by the user or system.
|
|
||||||
ABORT = 4;
|
|
||||||
}
|
|
||||||
FinishReason finish_reason = 2;
|
|
||||||
|
|
||||||
// Token usage counts
|
// Token usage counts
|
||||||
int32 prompt_tokens = 3;
|
int32 prompt_tokens = 3;
|
||||||
@@ -210,6 +198,12 @@ message GenerateComplete {
|
|||||||
|
|
||||||
// All hidden states if requested
|
// All hidden states if requested
|
||||||
repeated HiddenStates all_hidden_states = 7;
|
repeated HiddenStates all_hidden_states = 7;
|
||||||
|
|
||||||
|
// Matched stop information (for stop sequences)
|
||||||
|
oneof matched_stop {
|
||||||
|
uint32 matched_token_id = 8;
|
||||||
|
string matched_stop_str = 9;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -3,7 +3,6 @@ import datetime
|
|||||||
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
||||||
from google.protobuf import struct_pb2 as _struct_pb2
|
from google.protobuf import struct_pb2 as _struct_pb2
|
||||||
from google.protobuf.internal import containers as _containers
|
from google.protobuf.internal import containers as _containers
|
||||||
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
from google.protobuf import descriptor as _descriptor
|
||||||
from google.protobuf import message as _message
|
from google.protobuf import message as _message
|
||||||
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
||||||
@@ -179,19 +178,7 @@ class GenerateStreamChunk(_message.Message):
|
|||||||
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
|
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateComplete(_message.Message):
|
class GenerateComplete(_message.Message):
|
||||||
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
|
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
|
||||||
class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
|
||||||
__slots__ = ()
|
|
||||||
STOP: _ClassVar[GenerateComplete.FinishReason]
|
|
||||||
LENGTH: _ClassVar[GenerateComplete.FinishReason]
|
|
||||||
EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
|
|
||||||
STOP_STR: _ClassVar[GenerateComplete.FinishReason]
|
|
||||||
ABORT: _ClassVar[GenerateComplete.FinishReason]
|
|
||||||
STOP: GenerateComplete.FinishReason
|
|
||||||
LENGTH: GenerateComplete.FinishReason
|
|
||||||
EOS_TOKEN: GenerateComplete.FinishReason
|
|
||||||
STOP_STR: GenerateComplete.FinishReason
|
|
||||||
ABORT: GenerateComplete.FinishReason
|
|
||||||
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -199,14 +186,18 @@ class GenerateComplete(_message.Message):
|
|||||||
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
||||||
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
finish_reason: GenerateComplete.FinishReason
|
finish_reason: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
cached_tokens: int
|
cached_tokens: int
|
||||||
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
|
all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
|
||||||
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
|
||||||
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
|
matched_token_id: int
|
||||||
|
matched_stop_str: str
|
||||||
|
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateError(_message.Message):
|
class GenerateError(_message.Message):
|
||||||
__slots__ = ("message", "http_status_code", "details")
|
__slots__ = ("message", "http_status_code", "details")
|
||||||
|
|||||||
@@ -185,20 +185,8 @@ message GenerateComplete {
|
|||||||
// Final output
|
// Final output
|
||||||
repeated uint32 output_ids = 1;
|
repeated uint32 output_ids = 1;
|
||||||
|
|
||||||
// Finish reason
|
// Finish reason as OpenAI-compatible string ("stop", "length", "abort")
|
||||||
enum FinishReason {
|
string finish_reason = 2;
|
||||||
// The model generated a stop sequence.
|
|
||||||
STOP = 0;
|
|
||||||
// The model reached the maximum generation length.
|
|
||||||
LENGTH = 1;
|
|
||||||
// The model generated an end-of-sequence (EOS) token.
|
|
||||||
EOS_TOKEN = 2;
|
|
||||||
// The model generated a user-provided stop string.
|
|
||||||
STOP_STR = 3;
|
|
||||||
// The request was aborted by the user or system.
|
|
||||||
ABORT = 4;
|
|
||||||
}
|
|
||||||
FinishReason finish_reason = 2;
|
|
||||||
|
|
||||||
// Token usage counts
|
// Token usage counts
|
||||||
int32 prompt_tokens = 3;
|
int32 prompt_tokens = 3;
|
||||||
@@ -210,6 +198,12 @@ message GenerateComplete {
|
|||||||
|
|
||||||
// All hidden states if requested
|
// All hidden states if requested
|
||||||
repeated HiddenStates all_hidden_states = 7;
|
repeated HiddenStates all_hidden_states = 7;
|
||||||
|
|
||||||
|
// Matched stop information (for stop sequences)
|
||||||
|
oneof matched_stop {
|
||||||
|
uint32 matched_token_id = 8;
|
||||||
|
string matched_stop_str = 9;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
|
|||||||
@@ -423,10 +423,25 @@ pub struct ChatCompletionResponse {
|
|||||||
pub system_fingerprint: Option<String>,
|
pub system_fingerprint: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Response message structure for ChatCompletionResponse (different from request ChatMessage)
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionMessage {
|
||||||
|
pub role: String, // Always "assistant" for responses
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
/// Reasoning content for O1-style models (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
|
// Note: function_call is deprecated and not included
|
||||||
|
// Note: refusal, annotations, audio are not added yet
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ChatChoice {
|
pub struct ChatChoice {
|
||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub message: ChatMessage,
|
pub message: ChatCompletionMessage,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use axum::{
|
|||||||
extract::Request,
|
extract::Request,
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
|
Json,
|
||||||
};
|
};
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
@@ -18,8 +19,9 @@ use crate::metrics::RouterMetrics;
|
|||||||
use crate::policies::PolicyRegistry;
|
use crate::policies::PolicyRegistry;
|
||||||
use crate::protocols::spec::ChatMessage;
|
use crate::protocols::spec::ChatMessage;
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||||
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice,
|
CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
|
||||||
|
ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
|
||||||
};
|
};
|
||||||
use crate::reasoning_parser::ParserFactory;
|
use crate::reasoning_parser::ParserFactory;
|
||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
@@ -30,6 +32,7 @@ use crate::tokenizer::traits::Tokenizer;
|
|||||||
use crate::tokenizer::HuggingFaceTokenizer;
|
use crate::tokenizer::HuggingFaceTokenizer;
|
||||||
use crate::tool_parser::ParserRegistry;
|
use crate::tool_parser::ParserRegistry;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -648,35 +651,98 @@ impl GrpcRouter {
|
|||||||
Err(e) => return fail_fmt("Failed to start generation: ", &e),
|
Err(e) => return fail_fmt("Failed to start generation: ", &e),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get the single Complete response
|
// Collect all responses (for n>1 support)
|
||||||
let gen_response = match stream.next().await {
|
let mut all_responses = Vec::new();
|
||||||
Some(Ok(r)) => r,
|
while let Some(response) = stream.next().await {
|
||||||
Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e),
|
match response {
|
||||||
None => return fail_str("No response from server"),
|
Ok(gen_response) => match gen_response.response {
|
||||||
|
Some(proto::generate_response::Response::Complete(complete)) => {
|
||||||
|
all_responses.push(complete);
|
||||||
|
}
|
||||||
|
Some(proto::generate_response::Response::Error(err)) => {
|
||||||
|
error!("Generation failed for one choice: {}", err.message);
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Generation failed: {}", err.message),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
Some(proto::generate_response::Response::Chunk(_)) => {
|
||||||
|
return fail_str("Unexpected chunk response for non-streaming request")
|
||||||
|
}
|
||||||
|
None => return fail_str("Empty response from server"),
|
||||||
|
},
|
||||||
|
Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if all_responses.is_empty() {
|
||||||
|
return fail_str("No responses from server");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each response into a ChatChoice
|
||||||
|
let mut choices = Vec::new();
|
||||||
|
for (index, complete) in all_responses.iter().enumerate() {
|
||||||
|
match self
|
||||||
|
.process_single_choice(complete, index, original_request, &mut stop_decoder)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(choice) => choices.push(choice),
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to process choice {}: {}", index, e);
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to process choice {}: {}", index, e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate usage information from all responses
|
||||||
|
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
|
||||||
|
let total_completion_tokens: u32 = all_responses
|
||||||
|
.iter()
|
||||||
|
.map(|r| r.completion_tokens as u32)
|
||||||
|
.sum();
|
||||||
|
let usage = Usage {
|
||||||
|
prompt_tokens: total_prompt_tokens,
|
||||||
|
completion_tokens: total_completion_tokens,
|
||||||
|
total_tokens: total_prompt_tokens + total_completion_tokens,
|
||||||
|
completion_tokens_details: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Extract the expected variant early
|
// Build final ChatCompletionResponse
|
||||||
let complete = match gen_response.response {
|
let response = ChatCompletionResponse {
|
||||||
Some(proto::generate_response::Response::Complete(c)) => c,
|
id: format!("chatcmpl-{}", Uuid::new_v4()),
|
||||||
Some(proto::generate_response::Response::Error(err)) => {
|
object: "chat.completion".to_string(),
|
||||||
error!("Generation failed: {}", err.message);
|
created: SystemTime::now()
|
||||||
return (
|
.duration_since(UNIX_EPOCH)
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
.unwrap_or_default()
|
||||||
format!("Generation failed: {}", err.message),
|
.as_secs(),
|
||||||
)
|
model: original_request.model.clone(),
|
||||||
.into_response();
|
choices,
|
||||||
}
|
usage: Some(usage),
|
||||||
Some(proto::generate_response::Response::Chunk(_)) => {
|
system_fingerprint: None,
|
||||||
return fail_str("Unexpected chunk response for non-streaming request")
|
|
||||||
}
|
|
||||||
None => return fail_str("Empty response from server"),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Serialize and return JSON response
|
||||||
|
Json(response).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a single GenerateComplete response into a ChatChoice
|
||||||
|
async fn process_single_choice(
|
||||||
|
&self,
|
||||||
|
complete: &proto::GenerateComplete,
|
||||||
|
index: usize,
|
||||||
|
original_request: &ChatCompletionRequest,
|
||||||
|
stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
|
||||||
|
) -> Result<ChatChoice, String> {
|
||||||
|
stop_decoder.reset();
|
||||||
// Decode tokens
|
// Decode tokens
|
||||||
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
|
let outputs = stop_decoder
|
||||||
Ok(o) => o,
|
.process_tokens(&complete.output_ids)
|
||||||
Err(e) => return fail_fmt("Failed to process tokens: ", &e),
|
.map_err(|e| format!("Failed to process tokens: {}", e))?;
|
||||||
};
|
|
||||||
|
|
||||||
// Accumulate text with early breaks
|
// Accumulate text with early breaks
|
||||||
let mut final_text = String::new();
|
let mut final_text = String::new();
|
||||||
@@ -697,8 +763,119 @@ impl GrpcRouter {
|
|||||||
final_text.push_str(&t);
|
final_text.push_str(&t);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Create proper OpenAI-compatible response
|
// Step 1: Handle reasoning content parsing
|
||||||
(StatusCode::OK, format!("Final text: {}", final_text)).into_response()
|
let mut reasoning_text: Option<String> = None;
|
||||||
|
let mut processed_text = final_text;
|
||||||
|
|
||||||
|
// Check if reasoning parsing is enabled and separate_reasoning is requested
|
||||||
|
if original_request.separate_reasoning {
|
||||||
|
if let Ok(mut parser) = self
|
||||||
|
.reasoning_parser_factory
|
||||||
|
.create(&original_request.model)
|
||||||
|
{
|
||||||
|
match parser.detect_and_parse_reasoning(&processed_text) {
|
||||||
|
Ok(result) => {
|
||||||
|
if !result.reasoning_text.is_empty() {
|
||||||
|
reasoning_text = Some(result.reasoning_text);
|
||||||
|
}
|
||||||
|
processed_text = result.normal_text;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(format!("Reasoning parsing error: {}", e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Handle tool call parsing
|
||||||
|
let mut tool_calls: Option<Vec<crate::protocols::spec::ToolCall>> = None;
|
||||||
|
|
||||||
|
// Check if tool calls should be processed
|
||||||
|
let tool_choice_enabled = !matches!(
|
||||||
|
&original_request.tool_choice,
|
||||||
|
Some(ToolChoice::Value(
|
||||||
|
crate::protocols::spec::ToolChoiceValue::None
|
||||||
|
))
|
||||||
|
);
|
||||||
|
|
||||||
|
if tool_choice_enabled && original_request.tools.is_some() {
|
||||||
|
if let Some(parser) = self
|
||||||
|
.tool_parser_registry
|
||||||
|
.get_parser(&original_request.model)
|
||||||
|
{
|
||||||
|
match parser.parse_complete(&processed_text).await {
|
||||||
|
Ok(parsed_tool_calls) => {
|
||||||
|
if !parsed_tool_calls.is_empty() {
|
||||||
|
let spec_tool_calls = parsed_tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|tc| crate::protocols::spec::ToolCall {
|
||||||
|
id: tc.id,
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: crate::protocols::spec::FunctionCallResponse {
|
||||||
|
name: tc.function.name,
|
||||||
|
arguments: Some(
|
||||||
|
serde_json::to_string(&tc.function.arguments)
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
tool_calls = Some(spec_tool_calls);
|
||||||
|
processed_text = String::new();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Tool call parsing error: {}", e);
|
||||||
|
// Continue without tool calls rather than failing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
|
||||||
|
let finish_reason_str = &complete.finish_reason;
|
||||||
|
|
||||||
|
// Override finish reason if we have tool calls
|
||||||
|
let final_finish_reason_str = if tool_calls.is_some() {
|
||||||
|
"tool_calls"
|
||||||
|
} else {
|
||||||
|
finish_reason_str
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract matched_stop information from proto
|
||||||
|
let matched_stop = match &complete.matched_stop {
|
||||||
|
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some(
|
||||||
|
serde_json::Value::Number(serde_json::Number::from(*token_id)),
|
||||||
|
),
|
||||||
|
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
|
||||||
|
Some(serde_json::Value::String(stop_str.clone()))
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 4: Build ChatCompletionMessage (proper response message type)
|
||||||
|
let chat_message = ChatCompletionMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: if processed_text.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(processed_text)
|
||||||
|
},
|
||||||
|
tool_calls,
|
||||||
|
reasoning_content: reasoning_text,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 5: Build ChatChoice
|
||||||
|
let choice = ChatChoice {
|
||||||
|
index: index as u32,
|
||||||
|
message: chat_message,
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: Some(final_finish_reason_str.to_string()),
|
||||||
|
matched_stop,
|
||||||
|
hidden_states: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(choice)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user