[router][grpc] Support streaming for v1/chat/completions (#11179)
This commit is contained in:
@@ -578,7 +578,7 @@ class GrpcRequestManager:
|
|||||||
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
|
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
|
||||||
),
|
),
|
||||||
"finish_reason": (
|
"finish_reason": (
|
||||||
str(batch_out.finished_reasons[i])
|
batch_out.finished_reasons[i]
|
||||||
if batch_out.finished_reasons[i]
|
if batch_out.finished_reasons[i]
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ def _launch_scheduler_process_only(
|
|||||||
pp_rank,
|
pp_rank,
|
||||||
None,
|
None,
|
||||||
writer,
|
writer,
|
||||||
None,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||||
output_logprobs=output_logprobs_proto,
|
output_logprobs=output_logprobs_proto,
|
||||||
input_logprobs=input_logprobs_proto,
|
input_logprobs=input_logprobs_proto,
|
||||||
|
index=output.get("index", 0),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
|||||||
cached_tokens=meta_info.get("cached_tokens", 0),
|
cached_tokens=meta_info.get("cached_tokens", 0),
|
||||||
output_logprobs=output_logprobs_proto,
|
output_logprobs=output_logprobs_proto,
|
||||||
input_logprobs=input_logprobs_proto,
|
input_logprobs=input_logprobs_proto,
|
||||||
|
index=output.get("index", 0),
|
||||||
**matched_stop_kwargs,
|
**matched_stop_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -179,6 +179,9 @@ message GenerateStreamChunk {
|
|||||||
|
|
||||||
// Input logprobs (if requested) - only in first chunk
|
// Input logprobs (if requested) - only in first chunk
|
||||||
InputLogProbs input_logprobs = 7;
|
InputLogProbs input_logprobs = 7;
|
||||||
|
|
||||||
|
// Index for ordering when n>1 (for parallel request multiplexing)
|
||||||
|
uint32 index = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateComplete {
|
message GenerateComplete {
|
||||||
@@ -207,6 +210,9 @@ message GenerateComplete {
|
|||||||
|
|
||||||
// Input logprobs if requested (for prompt tokens)
|
// Input logprobs if requested (for prompt tokens)
|
||||||
InputLogProbs input_logprobs = 10;
|
InputLogProbs input_logprobs = 10;
|
||||||
|
|
||||||
|
// Index for ordering when n>1 (for parallel request multiplexing)
|
||||||
|
uint32 index = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message):
|
|||||||
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateStreamChunk(_message.Message):
|
class GenerateStreamChunk(_message.Message):
|
||||||
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
|
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
|
||||||
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message):
|
|||||||
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
|
||||||
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
INDEX_FIELD_NUMBER: _ClassVar[int]
|
||||||
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
token_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int
|
completion_tokens: int
|
||||||
@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message):
|
|||||||
output_logprobs: OutputLogProbs
|
output_logprobs: OutputLogProbs
|
||||||
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
hidden_states: _containers.RepeatedScalarFieldContainer[float]
|
||||||
input_logprobs: InputLogProbs
|
input_logprobs: InputLogProbs
|
||||||
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
|
index: int
|
||||||
|
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateComplete(_message.Message):
|
class GenerateComplete(_message.Message):
|
||||||
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
|
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
|
||||||
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
|
||||||
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
|
||||||
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
|
||||||
@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message):
|
|||||||
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
|
||||||
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
|
||||||
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
|
||||||
|
INDEX_FIELD_NUMBER: _ClassVar[int]
|
||||||
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
output_ids: _containers.RepeatedScalarFieldContainer[int]
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message):
|
|||||||
matched_token_id: int
|
matched_token_id: int
|
||||||
matched_stop_str: str
|
matched_stop_str: str
|
||||||
input_logprobs: InputLogProbs
|
input_logprobs: InputLogProbs
|
||||||
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
|
index: int
|
||||||
|
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
|
||||||
|
|
||||||
class GenerateError(_message.Message):
|
class GenerateError(_message.Message):
|
||||||
__slots__ = ("message", "http_status_code", "details")
|
__slots__ = ("message", "http_status_code", "details")
|
||||||
|
|||||||
@@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
|
|||||||
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
|
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -179,6 +179,9 @@ message GenerateStreamChunk {
|
|||||||
|
|
||||||
// Input logprobs (if requested) - only in first chunk
|
// Input logprobs (if requested) - only in first chunk
|
||||||
InputLogProbs input_logprobs = 7;
|
InputLogProbs input_logprobs = 7;
|
||||||
|
|
||||||
|
// Index for ordering when n>1 (for parallel request multiplexing)
|
||||||
|
uint32 index = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateComplete {
|
message GenerateComplete {
|
||||||
@@ -207,6 +210,9 @@ message GenerateComplete {
|
|||||||
|
|
||||||
// Input logprobs if requested (for prompt tokens)
|
// Input logprobs if requested (for prompt tokens)
|
||||||
InputLogProbs input_logprobs = 10;
|
InputLogProbs input_logprobs = 10;
|
||||||
|
|
||||||
|
// Index for ordering when n>1 (for parallel request multiplexing)
|
||||||
|
uint32 index = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
message GenerateError {
|
message GenerateError {
|
||||||
|
|||||||
@@ -72,8 +72,6 @@ pub enum ChatMessage {
|
|||||||
name: Option<String>,
|
name: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
function_call: Option<FunctionCallResponse>,
|
|
||||||
/// Reasoning content for O1-style models (SGLang extension)
|
/// Reasoning content for O1-style models (SGLang extension)
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
reasoning_content: Option<String>,
|
reasoning_content: Option<String>,
|
||||||
@@ -140,8 +138,6 @@ pub struct ChatMessageDelta {
|
|||||||
pub content: Option<String>,
|
pub content: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function_call: Option<FunctionCallDelta>,
|
|
||||||
/// Reasoning content delta for O1-style models (SGLang extension)
|
/// Reasoning content delta for O1-style models (SGLang extension)
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub reasoning_content: Option<String>,
|
pub reasoning_content: Option<String>,
|
||||||
@@ -473,6 +469,8 @@ pub struct ChatStreamChoice {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub matched_stop: Option<Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ graph TB
|
|||||||
end
|
end
|
||||||
|
|
||||||
subgraph Factory Layer
|
subgraph Factory Layer
|
||||||
MID --> PF[ParserFactory]
|
MID --> PF[ReasoningParserFactory]
|
||||||
PF --> REG[ParserRegistry]
|
PF --> REG[ParserRegistry]
|
||||||
REG --> PM[Pattern Matching]
|
REG --> PM[Pattern Matching]
|
||||||
PM --> PP[Parser Pool]
|
PM --> PP[Parser Pool]
|
||||||
@@ -93,7 +93,7 @@ graph TB
|
|||||||
```mermaid
|
```mermaid
|
||||||
sequenceDiagram
|
sequenceDiagram
|
||||||
participant C as Client
|
participant C as Client
|
||||||
participant F as ParserFactory
|
participant F as ReasoningParserFactory
|
||||||
participant R as Registry
|
participant R as Registry
|
||||||
participant P as Parser Pool
|
participant P as Parser Pool
|
||||||
participant BP as BaseParser
|
participant BP as BaseParser
|
||||||
@@ -206,7 +206,7 @@ classDiagram
|
|||||||
+new() Self
|
+new() Self
|
||||||
}
|
}
|
||||||
|
|
||||||
class ParserFactory {
|
class ReasoningParserFactory {
|
||||||
-registry: ParserRegistry
|
-registry: ParserRegistry
|
||||||
+new() Self
|
+new() Self
|
||||||
+get_pooled(model_id: &str) PooledParser
|
+get_pooled(model_id: &str) PooledParser
|
||||||
@@ -240,7 +240,7 @@ classDiagram
|
|||||||
Step3Parser o-- BaseReasoningParser
|
Step3Parser o-- BaseReasoningParser
|
||||||
|
|
||||||
BaseReasoningParser o-- ParserConfig
|
BaseReasoningParser o-- ParserConfig
|
||||||
ParserFactory o-- ParserRegistry
|
ReasoningParserFactory o-- ParserRegistry
|
||||||
ParserRegistry o-- ReasoningParser
|
ParserRegistry o-- ReasoningParser
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -302,7 +302,7 @@ classDiagram
|
|||||||
- Delegate to get_pooled_parser
|
- Delegate to get_pooled_parser
|
||||||
- Case-insensitive comparison
|
- Case-insensitive comparison
|
||||||
|
|
||||||
**ParserFactory Methods**:
|
**ReasoningParserFactory Methods**:
|
||||||
|
|
||||||
1. **`new()`**:
|
1. **`new()`**:
|
||||||
- Register all built-in parsers
|
- Register all built-in parsers
|
||||||
@@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser {
|
|||||||
**Step 2: Register in Factory**
|
**Step 2: Register in Factory**
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
// In factory.rs ParserFactory::new()
|
// In factory.rs ReasoningParserFactory::new()
|
||||||
registry.register_parser("mymodel", || {
|
registry.register_parser("mymodel", || {
|
||||||
Box::new(MyModelParser::new())
|
Box::new(MyModelParser::new())
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -128,11 +128,11 @@ impl Default for ParserRegistry {
|
|||||||
|
|
||||||
/// Factory for creating reasoning parsers based on model type.
|
/// Factory for creating reasoning parsers based on model type.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ParserFactory {
|
pub struct ReasoningParserFactory {
|
||||||
registry: ParserRegistry,
|
registry: ParserRegistry,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ParserFactory {
|
impl ReasoningParserFactory {
|
||||||
/// Create a new factory with default parsers registered.
|
/// Create a new factory with default parsers registered.
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let registry = ParserRegistry::new();
|
let registry = ParserRegistry::new();
|
||||||
@@ -237,7 +237,7 @@ impl ParserFactory {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ParserFactory {
|
impl Default for ReasoningParserFactory {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self::new()
|
Self::new()
|
||||||
}
|
}
|
||||||
@@ -249,35 +249,35 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_factory_creates_deepseek_r1() {
|
fn test_factory_creates_deepseek_r1() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser = factory.create("deepseek-r1-distill").unwrap();
|
let parser = factory.create("deepseek-r1-distill").unwrap();
|
||||||
assert_eq!(parser.model_type(), "deepseek_r1");
|
assert_eq!(parser.model_type(), "deepseek_r1");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_factory_creates_qwen3() {
|
fn test_factory_creates_qwen3() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser = factory.create("qwen3-7b").unwrap();
|
let parser = factory.create("qwen3-7b").unwrap();
|
||||||
assert_eq!(parser.model_type(), "qwen3");
|
assert_eq!(parser.model_type(), "qwen3");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_factory_creates_kimi() {
|
fn test_factory_creates_kimi() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser = factory.create("kimi-chat").unwrap();
|
let parser = factory.create("kimi-chat").unwrap();
|
||||||
assert_eq!(parser.model_type(), "kimi");
|
assert_eq!(parser.model_type(), "kimi");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_factory_fallback_to_passthrough() {
|
fn test_factory_fallback_to_passthrough() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser = factory.create("unknown-model").unwrap();
|
let parser = factory.create("unknown-model").unwrap();
|
||||||
assert_eq!(parser.model_type(), "passthrough");
|
assert_eq!(parser.model_type(), "passthrough");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_case_insensitive_matching() {
|
fn test_case_insensitive_matching() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser1 = factory.create("DeepSeek-R1").unwrap();
|
let parser1 = factory.create("DeepSeek-R1").unwrap();
|
||||||
let parser2 = factory.create("QWEN3").unwrap();
|
let parser2 = factory.create("QWEN3").unwrap();
|
||||||
let parser3 = factory.create("Kimi").unwrap();
|
let parser3 = factory.create("Kimi").unwrap();
|
||||||
@@ -289,21 +289,21 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_step3_model() {
|
fn test_step3_model() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let step3 = factory.create("step3-model").unwrap();
|
let step3 = factory.create("step3-model").unwrap();
|
||||||
assert_eq!(step3.model_type(), "step3");
|
assert_eq!(step3.model_type(), "step3");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_glm45_model() {
|
fn test_glm45_model() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let glm45 = factory.create("glm45-v2").unwrap();
|
let glm45 = factory.create("glm45-v2").unwrap();
|
||||||
assert_eq!(glm45.model_type(), "glm45");
|
assert_eq!(glm45.model_type(), "glm45");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pooled_parser_reuse() {
|
fn test_pooled_parser_reuse() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
|
|
||||||
// Get the same parser twice - should be the same instance
|
// Get the same parser twice - should be the same instance
|
||||||
let parser1 = factory.get_pooled("deepseek-r1");
|
let parser1 = factory.get_pooled("deepseek-r1");
|
||||||
@@ -321,7 +321,7 @@ mod tests {
|
|||||||
fn test_pooled_parser_concurrent_access() {
|
fn test_pooled_parser_concurrent_access() {
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let parser = factory.get_pooled("deepseek-r1");
|
let parser = factory.get_pooled("deepseek-r1");
|
||||||
|
|
||||||
// Spawn multiple threads that use the same parser
|
// Spawn multiple threads that use the same parser
|
||||||
@@ -347,7 +347,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pool_clearing() {
|
fn test_pool_clearing() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
|
|
||||||
// Get a pooled parser
|
// Get a pooled parser
|
||||||
let parser1 = factory.get_pooled("deepseek-r1");
|
let parser1 = factory.get_pooled("deepseek-r1");
|
||||||
@@ -364,7 +364,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_passthrough_parser_pooling() {
|
fn test_passthrough_parser_pooling() {
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
|
|
||||||
// Unknown models should get passthrough parser
|
// Unknown models should get passthrough parser
|
||||||
let parser1 = factory.get_pooled("unknown-model-1");
|
let parser1 = factory.get_pooled("unknown-model-1");
|
||||||
@@ -383,7 +383,7 @@ mod tests {
|
|||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let num_threads = 100;
|
let num_threads = 100;
|
||||||
let requests_per_thread = 50;
|
let requests_per_thread = 50;
|
||||||
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
|
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
|
||||||
@@ -527,7 +527,7 @@ mod tests {
|
|||||||
fn test_concurrent_pool_modifications() {
|
fn test_concurrent_pool_modifications() {
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
|
||||||
let factory = ParserFactory::new();
|
let factory = ReasoningParserFactory::new();
|
||||||
let mut handles = vec![];
|
let mut handles = vec![];
|
||||||
|
|
||||||
// Thread 1: Continuously get parsers
|
// Thread 1: Continuously get parsers
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ pub mod factory;
|
|||||||
pub mod parsers;
|
pub mod parsers;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
|
pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory};
|
||||||
pub use parsers::{
|
pub use parsers::{
|
||||||
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
|
||||||
QwenThinkingParser, Step3Parser,
|
QwenThinkingParser, Step3Parser,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
|
|||||||
use crate::core::{WorkerRegistry, WorkerType};
|
use crate::core::{WorkerRegistry, WorkerType};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::PolicyRegistry;
|
use crate::policies::PolicyRegistry;
|
||||||
use crate::reasoning_parser::ParserFactory;
|
use crate::reasoning_parser::ReasoningParserFactory;
|
||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
use crate::tokenizer::traits::Tokenizer;
|
use crate::tokenizer::traits::Tokenizer;
|
||||||
use crate::tool_parser::ToolParserFactory;
|
use crate::tool_parser::ToolParserFactory;
|
||||||
@@ -24,7 +24,7 @@ pub struct GrpcPDRouter {
|
|||||||
worker_registry: Arc<WorkerRegistry>,
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
policy_registry: Arc<PolicyRegistry>,
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
tokenizer: Arc<dyn Tokenizer>,
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
reasoning_parser_factory: ParserFactory,
|
reasoning_parser_factory: ReasoningParserFactory,
|
||||||
tool_parser_factory: ToolParserFactory,
|
tool_parser_factory: ToolParserFactory,
|
||||||
|
|
||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
|
|||||||
@@ -7,10 +7,14 @@ use async_trait::async_trait;
|
|||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
extract::Request,
|
extract::Request,
|
||||||
http::{HeaderMap, StatusCode},
|
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
|
use bytes::Bytes;
|
||||||
|
use std::io;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
@@ -21,11 +25,12 @@ use crate::policies::PolicyRegistry;
|
|||||||
use crate::protocols::spec::ChatMessage;
|
use crate::protocols::spec::ChatMessage;
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
|
||||||
CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
|
ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest,
|
||||||
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
|
EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest,
|
||||||
|
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice,
|
||||||
ToolChoiceValue, Usage,
|
ToolChoiceValue, Usage,
|
||||||
};
|
};
|
||||||
use crate::reasoning_parser::ParserFactory;
|
use crate::reasoning_parser::{ParserResult, ReasoningParserFactory};
|
||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
use crate::server::AppContext;
|
use crate::server::AppContext;
|
||||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||||
@@ -34,7 +39,7 @@ use crate::tokenizer::stop::{
|
|||||||
};
|
};
|
||||||
use crate::tokenizer::traits::Tokenizer;
|
use crate::tokenizer::traits::Tokenizer;
|
||||||
use crate::tokenizer::HuggingFaceTokenizer;
|
use crate::tokenizer::HuggingFaceTokenizer;
|
||||||
use crate::tool_parser::ToolParserFactory;
|
use crate::tool_parser::{StreamingParseResult, ToolParserFactory};
|
||||||
use proto::generate_response::Response::{Chunk, Complete, Error};
|
use proto::generate_response::Response::{Chunk, Complete, Error};
|
||||||
use serde_json::{json, Map, Value};
|
use serde_json::{json, Map, Value};
|
||||||
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
use std::time::{Instant, SystemTime, UNIX_EPOCH};
|
||||||
@@ -50,12 +55,13 @@ pub struct ProcessedMessages {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// gRPC router implementation for SGLang
|
/// gRPC router implementation for SGLang
|
||||||
|
#[derive(Clone)]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub struct GrpcRouter {
|
pub struct GrpcRouter {
|
||||||
worker_registry: Arc<WorkerRegistry>,
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
policy_registry: Arc<PolicyRegistry>,
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
tokenizer: Arc<dyn Tokenizer>,
|
tokenizer: Arc<dyn Tokenizer>,
|
||||||
reasoning_parser_factory: ParserFactory,
|
reasoning_parser_factory: ReasoningParserFactory,
|
||||||
tool_parser_factory: ToolParserFactory,
|
tool_parser_factory: ToolParserFactory,
|
||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
@@ -776,10 +782,11 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Parse tool calls using model-specific parser
|
/// Parse tool calls using model-specific parser
|
||||||
async fn parse_with_model_parser(
|
async fn parse_tool_calls(
|
||||||
&self,
|
&self,
|
||||||
processed_text: &str,
|
processed_text: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
|
history_tool_calls_count: usize,
|
||||||
) -> (Option<Vec<ToolCall>>, String) {
|
) -> (Option<Vec<ToolCall>>, String) {
|
||||||
// Get pooled parser for this model
|
// Get pooled parser for this model
|
||||||
let pooled_parser = self.tool_parser_factory.get_pooled(model);
|
let pooled_parser = self.tool_parser_factory.get_pooled(model);
|
||||||
@@ -810,16 +817,26 @@ impl GrpcRouter {
|
|||||||
|
|
||||||
let spec_tool_calls = parsed_tool_calls
|
let spec_tool_calls = parsed_tool_calls
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|tc| ToolCall {
|
.enumerate()
|
||||||
id: tc.id,
|
.map(|(index, tc)| {
|
||||||
tool_type: "function".to_string(),
|
// Generate ID for this tool call
|
||||||
function: FunctionCallResponse {
|
let id = Self::generate_tool_call_id(
|
||||||
name: tc.function.name,
|
model,
|
||||||
arguments: Some(
|
&tc.function.name,
|
||||||
serde_json::to_string(&tc.function.arguments)
|
index,
|
||||||
.unwrap_or_else(|_| "{}".to_string()),
|
history_tool_calls_count,
|
||||||
),
|
);
|
||||||
},
|
ToolCall {
|
||||||
|
id,
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: FunctionCallResponse {
|
||||||
|
name: tc.function.name,
|
||||||
|
arguments: Some(
|
||||||
|
serde_json::to_string(&tc.function.arguments)
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
(Some(spec_tool_calls), normal_text)
|
(Some(spec_tool_calls), normal_text)
|
||||||
@@ -920,6 +937,47 @@ impl GrpcRouter {
|
|||||||
builder.build()
|
builder.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Count the number of tool calls in the request message history
|
||||||
|
/// This is used for KimiK2 format which needs globally unique indices
|
||||||
|
fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
|
||||||
|
request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.filter_map(|msg| {
|
||||||
|
if let ChatMessage::Assistant { tool_calls, .. } = msg {
|
||||||
|
tool_calls.as_ref().map(|calls| calls.len())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generate a tool call ID based on model format
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `model` - Model name to determine ID format
|
||||||
|
/// * `tool_name` - Name of the tool being called
|
||||||
|
/// * `tool_index` - Index of this tool call within the current message
|
||||||
|
/// * `history_count` - Number of tool calls in previous messages
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
|
||||||
|
fn generate_tool_call_id(
|
||||||
|
model: &str,
|
||||||
|
tool_name: &str,
|
||||||
|
tool_index: usize,
|
||||||
|
history_count: usize,
|
||||||
|
) -> String {
|
||||||
|
if model.to_lowercase().contains("kimi") {
|
||||||
|
// KimiK2 format: functions.{name}:{global_index}
|
||||||
|
format!("functions.{}:{}", tool_name, history_count + tool_index)
|
||||||
|
} else {
|
||||||
|
// Standard OpenAI format: call_{24-char-uuid}
|
||||||
|
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Process a chunk of tokens through the stop decoder
|
/// Process a chunk of tokens through the stop decoder
|
||||||
fn process_chunk_tokens(
|
fn process_chunk_tokens(
|
||||||
stop_decoder: &mut StopSequenceDecoder,
|
stop_decoder: &mut StopSequenceDecoder,
|
||||||
@@ -953,6 +1011,230 @@ impl GrpcRouter {
|
|||||||
(chunk_text, false) // Return text and continue processing
|
(chunk_text, false) // Return text and continue processing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper: Process reasoning content in streaming mode
|
||||||
|
/// Returns (modified_delta, optional_reasoning_chunk)
|
||||||
|
fn process_reasoning_stream(
|
||||||
|
&self,
|
||||||
|
delta: &str,
|
||||||
|
index: u32,
|
||||||
|
reasoning_parsers: &mut HashMap<
|
||||||
|
u32,
|
||||||
|
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>,
|
||||||
|
>,
|
||||||
|
request_id: &str,
|
||||||
|
model: &str,
|
||||||
|
created: u64,
|
||||||
|
) -> (String, Option<ChatCompletionStreamResponse>) {
|
||||||
|
// Get or create parser for this index
|
||||||
|
reasoning_parsers
|
||||||
|
.entry(index)
|
||||||
|
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
|
||||||
|
|
||||||
|
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
|
||||||
|
let parse_result = {
|
||||||
|
let mut parser = pooled_parser.lock().unwrap();
|
||||||
|
parser.parse_reasoning_streaming_incremental(delta)
|
||||||
|
};
|
||||||
|
|
||||||
|
match parse_result {
|
||||||
|
Ok(ParserResult {
|
||||||
|
reasoning_text,
|
||||||
|
normal_text,
|
||||||
|
}) => {
|
||||||
|
let chunk = if !reasoning_text.is_empty() {
|
||||||
|
Some(ChatCompletionStreamResponse {
|
||||||
|
id: request_id.to_string(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.to_string(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: None,
|
||||||
|
tool_calls: None,
|
||||||
|
reasoning_content: Some(reasoning_text),
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
return (normal_text, chunk);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Reasoning parsing error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(delta.to_string(), None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: Process tool calls in streaming mode
|
||||||
|
/// Returns (should_skip_content, chunks_to_emit)
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
async fn process_tool_calls_stream(
|
||||||
|
&self,
|
||||||
|
delta: &str,
|
||||||
|
index: u32,
|
||||||
|
tool_parsers: &mut HashMap<
|
||||||
|
u32,
|
||||||
|
Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>,
|
||||||
|
>,
|
||||||
|
has_tool_calls: &mut HashMap<u32, bool>,
|
||||||
|
tools: &[crate::protocols::spec::Tool],
|
||||||
|
request_id: &str,
|
||||||
|
model: &str,
|
||||||
|
created: u64,
|
||||||
|
history_tool_calls_count: usize,
|
||||||
|
) -> (bool, Vec<ChatCompletionStreamResponse>) {
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
|
||||||
|
// Get or create parser for this index
|
||||||
|
tool_parsers
|
||||||
|
.entry(index)
|
||||||
|
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
|
||||||
|
|
||||||
|
if let Some(pooled_parser) = tool_parsers.get(&index) {
|
||||||
|
let mut parser = pooled_parser.lock().await;
|
||||||
|
match parser.parse_incremental(delta, tools).await {
|
||||||
|
Ok(StreamingParseResult { normal_text, calls }) => {
|
||||||
|
// Emit normal text if present
|
||||||
|
if !normal_text.is_empty() {
|
||||||
|
chunks.push(ChatCompletionStreamResponse {
|
||||||
|
id: request_id.to_string(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.to_string(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: Some(normal_text),
|
||||||
|
tool_calls: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit tool call chunks
|
||||||
|
for tool_call_item in calls {
|
||||||
|
has_tool_calls.insert(index, true);
|
||||||
|
|
||||||
|
let tool_call_id = if let Some(ref name) = tool_call_item.name {
|
||||||
|
Some(Self::generate_tool_call_id(
|
||||||
|
model,
|
||||||
|
name,
|
||||||
|
tool_call_item.tool_index,
|
||||||
|
history_tool_calls_count,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_call_delta = ToolCallDelta {
|
||||||
|
index: tool_call_item.tool_index as u32,
|
||||||
|
id: tool_call_id,
|
||||||
|
tool_type: if tool_call_item.name.is_some() {
|
||||||
|
Some("function".to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
function: Some(FunctionCallDelta {
|
||||||
|
name: tool_call_item.name,
|
||||||
|
arguments: if !tool_call_item.parameters.is_empty() {
|
||||||
|
Some(tool_call_item.parameters)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
chunks.push(ChatCompletionStreamResponse {
|
||||||
|
id: request_id.to_string(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.to_string(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: None,
|
||||||
|
tool_calls: Some(vec![tool_call_delta]),
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we emitted chunks, skip regular content
|
||||||
|
return (!chunks.is_empty(), chunks);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Tool call parsing error: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(false, chunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: Create content chunk
|
||||||
|
fn create_content_chunk(
|
||||||
|
content: String,
|
||||||
|
index: u32,
|
||||||
|
request_id: &str,
|
||||||
|
model: &str,
|
||||||
|
created: u64,
|
||||||
|
logprobs: Option<crate::protocols::spec::ChatLogProbs>,
|
||||||
|
) -> ChatCompletionStreamResponse {
|
||||||
|
ChatCompletionStreamResponse {
|
||||||
|
id: request_id.to_string(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.to_string(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: Some(content),
|
||||||
|
tool_calls: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper: Format response as SSE chunk
|
||||||
|
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
|
||||||
|
format!(
|
||||||
|
"data: {}\n\n",
|
||||||
|
serde_json::to_string(response).unwrap_or_default()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Submit request and handle streaming response for chat completions route
|
/// Submit request and handle streaming response for chat completions route
|
||||||
async fn handle_streaming_chat(
|
async fn handle_streaming_chat(
|
||||||
&self,
|
&self,
|
||||||
@@ -960,14 +1242,13 @@ impl GrpcRouter {
|
|||||||
request: proto::GenerateRequest,
|
request: proto::GenerateRequest,
|
||||||
original_request: &ChatCompletionRequest,
|
original_request: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let mut stop_decoder = self.create_stop_decoder(
|
let request_id = request.request_id.clone();
|
||||||
original_request.stop.as_ref(),
|
let model = original_request.model.clone();
|
||||||
original_request.stop_token_ids.as_ref(),
|
|
||||||
original_request.skip_special_tokens,
|
|
||||||
original_request.no_stop_trim,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Process streaming tokens
|
// Create channel for SSE streaming
|
||||||
|
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
|
||||||
|
|
||||||
|
// Start the gRPC stream
|
||||||
let mut grpc_stream = match client.generate(request).await {
|
let mut grpc_stream = match client.generate(request).await {
|
||||||
Ok(stream) => stream,
|
Ok(stream) => stream,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -980,49 +1261,414 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut decoded_text = String::new();
|
let stop_params = (
|
||||||
|
original_request.stop.clone(),
|
||||||
|
original_request.stop_token_ids.clone(),
|
||||||
|
original_request.skip_special_tokens,
|
||||||
|
original_request.no_stop_trim,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Spawn processing task
|
||||||
|
let self_clone = self.clone();
|
||||||
|
let original_request_clone = original_request.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let result = Self::process_streaming_chunks(
|
||||||
|
&self_clone,
|
||||||
|
&mut grpc_stream,
|
||||||
|
request_id,
|
||||||
|
model,
|
||||||
|
stop_params,
|
||||||
|
original_request_clone,
|
||||||
|
&tx,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
let error_chunk = format!(
|
||||||
|
"data: {}\n\n",
|
||||||
|
json!({
|
||||||
|
"error": {
|
||||||
|
"message": e,
|
||||||
|
"type": "internal_error"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
);
|
||||||
|
let _ = tx.send(Ok(Bytes::from(error_chunk)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send DONE marker
|
||||||
|
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create response with SSE headers
|
||||||
|
let stream = UnboundedReceiverStream::new(rx);
|
||||||
|
let mut response = Response::new(Body::from_stream(stream));
|
||||||
|
*response.status_mut() = StatusCode::OK;
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert("Connection", HeaderValue::from_static("keep-alive"));
|
||||||
|
response
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process streaming chunks and send SSE events
|
||||||
|
async fn process_streaming_chunks(
|
||||||
|
router: &GrpcRouter,
|
||||||
|
grpc_stream: &mut (impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
|
||||||
|
+ Unpin),
|
||||||
|
request_id: String,
|
||||||
|
model: String,
|
||||||
|
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
|
||||||
|
original_request: ChatCompletionRequest,
|
||||||
|
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// Extract request parameters
|
||||||
|
let separate_reasoning = original_request.separate_reasoning;
|
||||||
|
let tool_choice = &original_request.tool_choice;
|
||||||
|
let tools = &original_request.tools;
|
||||||
|
let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request);
|
||||||
|
let stream_options = &original_request.stream_options;
|
||||||
|
|
||||||
|
// Phase 1: Initialize state tracking (per-index for n>1 support)
|
||||||
|
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
|
||||||
|
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
|
||||||
|
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
|
||||||
|
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
|
||||||
|
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
|
||||||
|
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
|
||||||
|
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
|
||||||
|
|
||||||
|
// Parser state (lazy initialization per index)
|
||||||
|
type PooledReasoningParser =
|
||||||
|
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>;
|
||||||
|
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
|
||||||
|
|
||||||
|
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>;
|
||||||
|
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
|
||||||
|
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
|
||||||
|
|
||||||
|
// Create stop decoder
|
||||||
|
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
|
||||||
|
let mut stop_decoder = router.create_stop_decoder(
|
||||||
|
stop.as_ref(),
|
||||||
|
stop_token_ids.as_ref(),
|
||||||
|
skip_special_tokens,
|
||||||
|
no_stop_trim,
|
||||||
|
);
|
||||||
|
|
||||||
|
let created = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs();
|
||||||
|
|
||||||
|
// Phase 2: Main streaming loop
|
||||||
while let Some(response) = grpc_stream.next().await {
|
while let Some(response) = grpc_stream.next().await {
|
||||||
let gen_response = match response {
|
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
|
||||||
Ok(resp) => resp,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Stream error: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match gen_response.response {
|
match gen_response.response {
|
||||||
Some(Chunk(chunk)) => {
|
Some(Chunk(chunk)) => {
|
||||||
// Process tokens and check if we should stop
|
let index = chunk.index;
|
||||||
let (chunk_text, should_stop) =
|
|
||||||
|
// Process tokens through stop decoder
|
||||||
|
let (chunk_text, _should_stop) =
|
||||||
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
|
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
|
||||||
decoded_text.push_str(&chunk_text);
|
|
||||||
if should_stop {
|
if chunk_text.is_empty() {
|
||||||
break;
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process logprobs if present
|
||||||
|
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
|
||||||
|
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
|
||||||
|
Ok(logprobs) => Some(logprobs),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to process logprobs: {}", e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initialize stream buffer if first time
|
||||||
|
let stream_buffer = stream_buffers.entry(index).or_default();
|
||||||
|
|
||||||
|
// Send first chunk with role
|
||||||
|
if is_firsts.get(&index).copied().unwrap_or(true) {
|
||||||
|
let first_chunk = ChatCompletionStreamResponse {
|
||||||
|
id: request_id.clone(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.clone(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: None,
|
||||||
|
tool_calls: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
};
|
||||||
|
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
|
||||||
|
.map_err(|_| "Failed to send first chunk".to_string())?;
|
||||||
|
is_firsts.insert(index, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate delta
|
||||||
|
let mut delta = chunk_text;
|
||||||
|
stream_buffer.push_str(&delta);
|
||||||
|
|
||||||
|
// Reasoning content handling
|
||||||
|
if separate_reasoning {
|
||||||
|
let (normal_text, reasoning_chunk) = router.process_reasoning_stream(
|
||||||
|
&delta,
|
||||||
|
index,
|
||||||
|
&mut reasoning_parsers,
|
||||||
|
&request_id,
|
||||||
|
&model,
|
||||||
|
created,
|
||||||
|
);
|
||||||
|
if let Some(chunk) = reasoning_chunk {
|
||||||
|
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
||||||
|
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
|
||||||
|
}
|
||||||
|
delta = normal_text;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool call handling
|
||||||
|
let tool_choice_enabled =
|
||||||
|
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
|
||||||
|
|
||||||
|
if tool_choice_enabled && tools.is_some() {
|
||||||
|
let (should_skip, tool_chunks) = router
|
||||||
|
.process_tool_calls_stream(
|
||||||
|
&delta,
|
||||||
|
index,
|
||||||
|
&mut tool_parsers,
|
||||||
|
&mut has_tool_calls,
|
||||||
|
tools.as_ref().unwrap(),
|
||||||
|
&request_id,
|
||||||
|
&model,
|
||||||
|
created,
|
||||||
|
history_tool_calls_count,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
for chunk in tool_chunks {
|
||||||
|
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
|
||||||
|
.map_err(|_| "Failed to send tool call chunk".to_string())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if should_skip {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular content emission
|
||||||
|
if !delta.is_empty() {
|
||||||
|
let content_chunk = Self::create_content_chunk(
|
||||||
|
delta,
|
||||||
|
index,
|
||||||
|
&request_id,
|
||||||
|
&model,
|
||||||
|
created,
|
||||||
|
choice_logprobs,
|
||||||
|
);
|
||||||
|
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
|
||||||
|
.map_err(|_| "Failed to send content chunk".to_string())?;
|
||||||
}
|
}
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
Some(Complete(_complete)) => {
|
Some(Complete(complete)) => {
|
||||||
// Flush any remaining text
|
// Flush any remaining text
|
||||||
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
|
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
decoded_text.push_str(&text);
|
let index = complete.index;
|
||||||
debug!("Flushed text: {}", text);
|
let stream_buffer = stream_buffers.entry(index).or_default();
|
||||||
|
stream_buffer.push_str(&text);
|
||||||
|
|
||||||
|
let content_chunk = ChatCompletionStreamResponse {
|
||||||
|
id: request_id.clone(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.clone(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: Some(text),
|
||||||
|
tool_calls: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let sse_chunk = serde_json::to_string(&content_chunk)
|
||||||
|
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
|
||||||
|
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||||
|
.map_err(|_| "Failed to send flushed content".to_string())?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store metadata
|
||||||
|
let index = complete.index;
|
||||||
|
prompt_tokens.insert(index, complete.prompt_tokens as u32);
|
||||||
|
completion_tokens.insert(index, complete.completion_tokens as u32);
|
||||||
|
cached_tokens.insert(index, complete.cached_tokens as u32);
|
||||||
|
finish_reasons.insert(index, complete.finish_reason.clone());
|
||||||
|
|
||||||
|
// Extract matched_stop
|
||||||
|
let matched_stop_value = match &complete.matched_stop {
|
||||||
|
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
|
||||||
|
Some(Value::Number(serde_json::Number::from(*token_id)))
|
||||||
|
}
|
||||||
|
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
|
||||||
|
Some(Value::String(stop_str.clone()))
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
matched_stops.insert(index, matched_stop_value);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Some(Error(error)) => {
|
Some(Error(error)) => {
|
||||||
error!("Generation error: {}", error.message);
|
return Err(error.message);
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
None => continue,
|
None => continue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Replace with proper SSE streaming response
|
// Phase 3: Check unstreamed tool args
|
||||||
// For now, return the complete decoded text
|
// Check if parsers have any remaining arguments that haven't been streamed yet
|
||||||
(StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response()
|
for (index, parser) in &tool_parsers {
|
||||||
|
let parser_guard = parser.lock().await;
|
||||||
|
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
|
||||||
|
for tool_call_item in unstreamed_items {
|
||||||
|
let tool_call_delta = ToolCallDelta {
|
||||||
|
index: tool_call_item.tool_index as u32,
|
||||||
|
id: None,
|
||||||
|
tool_type: None, // No type for argument deltas
|
||||||
|
function: Some(FunctionCallDelta {
|
||||||
|
name: None, // No name for argument deltas
|
||||||
|
arguments: if !tool_call_item.parameters.is_empty() {
|
||||||
|
Some(tool_call_item.parameters)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_chunk = ChatCompletionStreamResponse {
|
||||||
|
id: request_id.clone(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.clone(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index: *index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: None,
|
||||||
|
tool_calls: Some(vec![tool_call_delta]),
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: None,
|
||||||
|
matched_stop: None,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let sse_chunk = serde_json::to_string(&tool_chunk)
|
||||||
|
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
|
||||||
|
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||||
|
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 4: Finish reason chunks
|
||||||
|
for (index, finish_reason) in finish_reasons.iter() {
|
||||||
|
let final_finish_reason =
|
||||||
|
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
|
||||||
|
"tool_calls".to_string()
|
||||||
|
} else {
|
||||||
|
finish_reason.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
|
||||||
|
|
||||||
|
let finish_chunk = ChatCompletionStreamResponse {
|
||||||
|
id: request_id.clone(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.clone(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![ChatStreamChoice {
|
||||||
|
index: *index,
|
||||||
|
delta: ChatMessageDelta {
|
||||||
|
role: Some("assistant".to_string()),
|
||||||
|
content: None,
|
||||||
|
tool_calls: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
},
|
||||||
|
logprobs: None,
|
||||||
|
finish_reason: Some(final_finish_reason),
|
||||||
|
matched_stop: matched_stop_value,
|
||||||
|
}],
|
||||||
|
usage: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let sse_chunk = serde_json::to_string(&finish_chunk)
|
||||||
|
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
|
||||||
|
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||||
|
.map_err(|_| "Failed to send finish chunk".to_string())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 5: Usage chunk
|
||||||
|
if let Some(stream_opts) = stream_options {
|
||||||
|
if stream_opts.include_usage.unwrap_or(false) {
|
||||||
|
let total_prompt: u32 = prompt_tokens.values().sum();
|
||||||
|
let total_completion: u32 = completion_tokens.values().sum();
|
||||||
|
|
||||||
|
let usage_chunk = ChatCompletionStreamResponse {
|
||||||
|
id: request_id.clone(),
|
||||||
|
object: "chat.completion.chunk".to_string(),
|
||||||
|
created,
|
||||||
|
model: model.clone(),
|
||||||
|
system_fingerprint: None,
|
||||||
|
choices: vec![],
|
||||||
|
usage: Some(Usage {
|
||||||
|
prompt_tokens: total_prompt,
|
||||||
|
completion_tokens: total_completion,
|
||||||
|
total_tokens: total_prompt + total_completion,
|
||||||
|
completion_tokens_details: None,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
let sse_chunk = serde_json::to_string(&usage_chunk)
|
||||||
|
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
|
||||||
|
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
|
||||||
|
.map_err(|_| "Failed to send usage chunk".to_string())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Submit request and handle non-streaming response for chat completions route
|
/// Submit request and handle non-streaming response for chat completions route
|
||||||
@@ -1082,10 +1728,17 @@ impl GrpcRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process each response into a ChatChoice
|
// Process each response into a ChatChoice
|
||||||
|
let history_tool_calls_count = Self::get_history_tool_calls_count(original_request);
|
||||||
let mut choices = Vec::new();
|
let mut choices = Vec::new();
|
||||||
for (index, complete) in all_responses.iter().enumerate() {
|
for (index, complete) in all_responses.iter().enumerate() {
|
||||||
match self
|
match self
|
||||||
.process_single_choice(complete, index, original_request, &mut stop_decoder)
|
.process_single_choice(
|
||||||
|
complete,
|
||||||
|
index,
|
||||||
|
original_request,
|
||||||
|
&mut stop_decoder,
|
||||||
|
history_tool_calls_count,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(choice) => choices.push(choice),
|
Ok(choice) => choices.push(choice),
|
||||||
@@ -1216,11 +1869,12 @@ impl GrpcRouter {
|
|||||||
decoded_text.push_str(&t);
|
decoded_text.push_str(&t);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output_ids = complete.output_ids.clone();
|
let output_ids = std::mem::take(&mut complete.output_ids);
|
||||||
|
let finish_reason = std::mem::take(&mut complete.finish_reason);
|
||||||
|
|
||||||
// Build base meta_info using json! macro
|
// Build base meta_info using json! macro
|
||||||
let mut meta_info = json!({
|
let mut meta_info = json!({
|
||||||
"finish_reason": complete.finish_reason.clone(),
|
"finish_reason": finish_reason,
|
||||||
"prompt_tokens": complete.prompt_tokens,
|
"prompt_tokens": complete.prompt_tokens,
|
||||||
"completion_tokens": complete.completion_tokens,
|
"completion_tokens": complete.completion_tokens,
|
||||||
"cached_tokens": complete.cached_tokens,
|
"cached_tokens": complete.cached_tokens,
|
||||||
@@ -1269,9 +1923,13 @@ impl GrpcRouter {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Build ChatLogProbsContent for each token
|
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
|
||||||
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
|
for (i, (&logprob, token_text)) in proto_logprobs
|
||||||
let token_text = token_texts.get(i).cloned().unwrap_or_default();
|
.token_logprobs
|
||||||
|
.iter()
|
||||||
|
.zip(token_texts.into_iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
let bytes = Some(token_text.as_bytes().to_vec());
|
let bytes = Some(token_text.as_bytes().to_vec());
|
||||||
|
|
||||||
// Build top_logprobs for this position
|
// Build top_logprobs for this position
|
||||||
@@ -1324,6 +1982,7 @@ impl GrpcRouter {
|
|||||||
index: usize,
|
index: usize,
|
||||||
original_request: &ChatCompletionRequest,
|
original_request: &ChatCompletionRequest,
|
||||||
stop_decoder: &mut StopSequenceDecoder,
|
stop_decoder: &mut StopSequenceDecoder,
|
||||||
|
history_tool_calls_count: usize,
|
||||||
) -> Result<ChatChoice, String> {
|
) -> Result<ChatChoice, String> {
|
||||||
stop_decoder.reset();
|
stop_decoder.reset();
|
||||||
// Decode tokens
|
// Decode tokens
|
||||||
@@ -1401,7 +2060,11 @@ impl GrpcRouter {
|
|||||||
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
|
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
|
||||||
} else {
|
} else {
|
||||||
(tool_calls, processed_text) = self
|
(tool_calls, processed_text) = self
|
||||||
.parse_with_model_parser(&processed_text, &original_request.model)
|
.parse_tool_calls(
|
||||||
|
&processed_text,
|
||||||
|
&original_request.model,
|
||||||
|
history_tool_calls_count,
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1686,7 +2349,6 @@ mod tests {
|
|||||||
content: Some("Assistant response".to_string()),
|
content: Some("Assistant response".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ use crate::{
|
|||||||
},
|
},
|
||||||
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
||||||
},
|
},
|
||||||
reasoning_parser::ParserFactory,
|
reasoning_parser::ReasoningParserFactory,
|
||||||
routers::{router_manager::RouterManager, RouterTrait},
|
routers::{router_manager::RouterManager, RouterTrait},
|
||||||
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
|
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
|
||||||
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
|
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
|
||||||
@@ -45,7 +45,7 @@ pub struct AppContext {
|
|||||||
pub router_config: RouterConfig,
|
pub router_config: RouterConfig,
|
||||||
pub rate_limiter: Arc<TokenBucket>,
|
pub rate_limiter: Arc<TokenBucket>,
|
||||||
pub tokenizer: Option<Arc<dyn Tokenizer>>,
|
pub tokenizer: Option<Arc<dyn Tokenizer>>,
|
||||||
pub reasoning_parser_factory: Option<ParserFactory>,
|
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
|
||||||
pub tool_parser_factory: Option<ToolParserFactory>,
|
pub tool_parser_factory: Option<ToolParserFactory>,
|
||||||
pub worker_registry: Arc<WorkerRegistry>,
|
pub worker_registry: Arc<WorkerRegistry>,
|
||||||
pub policy_registry: Arc<PolicyRegistry>,
|
pub policy_registry: Arc<PolicyRegistry>,
|
||||||
@@ -79,7 +79,7 @@ impl AppContext {
|
|||||||
tokenizer_factory::create_tokenizer(&tokenizer_path)
|
tokenizer_factory::create_tokenizer(&tokenizer_path)
|
||||||
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
|
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
|
||||||
);
|
);
|
||||||
let reasoning_parser_factory = Some(ParserFactory::new());
|
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
|
||||||
let tool_parser_factory = Some(ToolParserFactory::new());
|
let tool_parser_factory = Some(ToolParserFactory::new());
|
||||||
|
|
||||||
(tokenizer, reasoning_parser_factory, tool_parser_factory)
|
(tokenizer, reasoning_parser_factory, tool_parser_factory)
|
||||||
|
|||||||
@@ -123,12 +123,7 @@ impl DeepSeekParser {
|
|||||||
let arguments = serde_json::to_string(&args)
|
let arguments = serde_json::to_string(&args)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate ID
|
|
||||||
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
|
|
||||||
|
|
||||||
Ok(ToolCall {
|
Ok(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: func_name.to_string(),
|
name: func_name.to_string(),
|
||||||
arguments,
|
arguments,
|
||||||
@@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser {
|
|||||||
fn detect_format(&self, text: &str) -> bool {
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
self.has_tool_markers(text)
|
self.has_tool_markers(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -129,12 +129,7 @@ impl Glm4MoeParser {
|
|||||||
let arguments_str = serde_json::to_string(&arguments)
|
let arguments_str = serde_json::to_string(&arguments)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate ID
|
|
||||||
let id = format!("glm4_call_{}", uuid::Uuid::new_v4());
|
|
||||||
|
|
||||||
Ok(Some(ToolCall {
|
Ok(Some(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: func_name.to_string(),
|
name: func_name.to_string(),
|
||||||
arguments: arguments_str,
|
arguments: arguments_str,
|
||||||
@@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser {
|
|||||||
fn detect_format(&self, text: &str) -> bool {
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
self.has_tool_markers(text)
|
self.has_tool_markers(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -113,12 +113,7 @@ impl ToolParser for GptOssParser {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Generate unique ID
|
|
||||||
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
|
|
||||||
|
|
||||||
tools.push(ToolCall {
|
tools.push(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: function_name,
|
name: function_name,
|
||||||
arguments,
|
arguments,
|
||||||
|
|||||||
@@ -14,6 +14,48 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get unstreamed tool call arguments
|
||||||
|
/// Returns tool call items for arguments that have been parsed but not yet streamed
|
||||||
|
/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
|
||||||
|
pub fn get_unstreamed_args(
|
||||||
|
prev_tool_call_arr: &[Value],
|
||||||
|
streamed_args_for_tool: &[String],
|
||||||
|
) -> Option<Vec<ToolCallItem>> {
|
||||||
|
// Check if we have tool calls being tracked
|
||||||
|
if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the last tool call that was being processed
|
||||||
|
let tool_index = prev_tool_call_arr.len() - 1;
|
||||||
|
if tool_index >= streamed_args_for_tool.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get expected vs actual arguments
|
||||||
|
let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
|
||||||
|
let expected_str = serde_json::to_string(expected_args).ok()?;
|
||||||
|
let actual_str = &streamed_args_for_tool[tool_index];
|
||||||
|
|
||||||
|
// Check if there are remaining arguments to send
|
||||||
|
let remaining = if expected_str.starts_with(actual_str) {
|
||||||
|
&expected_str[actual_str.len()..]
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
if remaining.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the remaining arguments as a ToolCallItem
|
||||||
|
Some(vec![ToolCallItem {
|
||||||
|
tool_index,
|
||||||
|
name: None, // No name for argument deltas
|
||||||
|
parameters: remaining.to_string(),
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a buffer ends with a partial occurrence of a token
|
/// Check if a buffer ends with a partial occurrence of a token
|
||||||
/// Returns Some(length) if there's a partial match, None otherwise
|
/// Returns Some(length) if there's a partial match, None otherwise
|
||||||
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
|
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ use crate::tool_parser::{
|
|||||||
parsers::helpers,
|
parsers::helpers,
|
||||||
partial_json::PartialJson,
|
partial_json::PartialJson,
|
||||||
traits::ToolParser,
|
traits::ToolParser,
|
||||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// JSON format parser for tool calls
|
/// JSON format parser for tool calls
|
||||||
@@ -136,16 +136,7 @@ impl JsonParser {
|
|||||||
let arguments = serde_json::to_string(args)
|
let arguments = serde_json::to_string(args)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate a unique ID if not provided
|
|
||||||
let id = obj
|
|
||||||
.get("id")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from)
|
|
||||||
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
|
|
||||||
|
|
||||||
Ok(Some(ToolCall {
|
Ok(Some(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
arguments,
|
arguments,
|
||||||
@@ -274,4 +265,8 @@ impl ToolParser for JsonParser {
|
|||||||
let trimmed = text.trim();
|
let trimmed = text.trim();
|
||||||
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
|
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser {
|
|||||||
// Try to parse JSON arguments
|
// Try to parse JSON arguments
|
||||||
match serde_json::from_str::<serde_json::Value>(function_args) {
|
match serde_json::from_str::<serde_json::Value>(function_args) {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
// Generate unique ID
|
|
||||||
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
|
|
||||||
|
|
||||||
tools.push(ToolCall {
|
tools.push(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: func_name,
|
name: func_name,
|
||||||
arguments: function_args.to_string(),
|
arguments: function_args.to_string(),
|
||||||
@@ -339,4 +334,8 @@ impl ToolParser for KimiK2Parser {
|
|||||||
fn detect_format(&self, text: &str) -> bool {
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
|
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use uuid;
|
|
||||||
|
|
||||||
use crate::protocols::spec::Tool;
|
use crate::protocols::spec::Tool;
|
||||||
|
|
||||||
@@ -84,16 +83,7 @@ impl LlamaParser {
|
|||||||
let arguments = serde_json::to_string(parameters)
|
let arguments = serde_json::to_string(parameters)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate a unique ID for Llama calls
|
|
||||||
let id = obj
|
|
||||||
.get("id")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from)
|
|
||||||
.unwrap_or_else(|| format!("llama_call_{}", uuid::Uuid::new_v4()));
|
|
||||||
|
|
||||||
Ok(Some(ToolCall {
|
Ok(Some(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
arguments,
|
arguments,
|
||||||
@@ -243,4 +233,8 @@ impl ToolParser for LlamaParser {
|
|||||||
text.contains("<|python_tag|>")
|
text.contains("<|python_tag|>")
|
||||||
|| (text.trim_start().starts_with('{') && text.contains(r#""name""#))
|
|| (text.trim_start().starts_with('{') && text.contains(r#""name""#))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,16 +146,7 @@ impl MistralParser {
|
|||||||
let arguments = serde_json::to_string(args)
|
let arguments = serde_json::to_string(args)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate unique ID
|
|
||||||
let id = obj
|
|
||||||
.get("id")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from)
|
|
||||||
.unwrap_or_else(|| format!("mistral_call_{}", uuid::Uuid::new_v4()));
|
|
||||||
|
|
||||||
Ok(Some(ToolCall {
|
Ok(Some(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
arguments,
|
arguments,
|
||||||
@@ -266,4 +257,8 @@ impl ToolParser for MistralParser {
|
|||||||
fn detect_format(&self, text: &str) -> bool {
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
self.has_tool_markers(text)
|
self.has_tool_markers(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -244,7 +244,7 @@ fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> {
|
fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
|
||||||
match expr {
|
match expr {
|
||||||
Expr::Call(call_expr) => {
|
Expr::Call(call_expr) => {
|
||||||
if !call_expr.args.is_empty() {
|
if !call_expr.args.is_empty() {
|
||||||
@@ -277,8 +277,6 @@ fn build_tool_call(expr: Expr, index: usize) -> ToolParserResult<ToolCall> {
|
|||||||
let arguments_string = serde_json::to_string(&arguments_json)?;
|
let arguments_string = serde_json::to_string(&arguments_json)?;
|
||||||
|
|
||||||
Ok(ToolCall {
|
Ok(ToolCall {
|
||||||
id: format!("call-{}", index + 1),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: function_name,
|
name: function_name,
|
||||||
arguments: arguments_string,
|
arguments: arguments_string,
|
||||||
|
|||||||
@@ -88,16 +88,7 @@ impl QwenParser {
|
|||||||
let arguments = serde_json::to_string(args)
|
let arguments = serde_json::to_string(args)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate unique ID
|
|
||||||
let id = obj
|
|
||||||
.get("id")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from)
|
|
||||||
.unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));
|
|
||||||
|
|
||||||
Ok(Some(ToolCall {
|
Ok(Some(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
arguments,
|
arguments,
|
||||||
@@ -255,4 +246,8 @@ impl ToolParser for QwenParser {
|
|||||||
fn detect_format(&self, text: &str) -> bool {
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
self.has_tool_markers(text)
|
self.has_tool_markers(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -400,12 +400,7 @@ impl Step3Parser {
|
|||||||
let arguments_str = serde_json::to_string(¶meters)
|
let arguments_str = serde_json::to_string(¶meters)
|
||||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||||
|
|
||||||
// Generate ID
|
|
||||||
let id = format!("step3_call_{}", uuid::Uuid::new_v4());
|
|
||||||
|
|
||||||
Ok(Some(ToolCall {
|
Ok(Some(ToolCall {
|
||||||
id,
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: func_name.to_string(),
|
name: func_name.to_string(),
|
||||||
arguments: arguments_str,
|
arguments: arguments_str,
|
||||||
@@ -561,4 +556,8 @@ impl ToolParser for Step3Parser {
|
|||||||
fn detect_format(&self, text: &str) -> bool {
|
fn detect_format(&self, text: &str) -> bool {
|
||||||
self.has_tool_markers(text)
|
self.has_tool_markers(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
|
||||||
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,8 +31,6 @@ async fn test_tool_parser_factory_model_mapping() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_tool_call_serialization() {
|
fn test_tool_call_serialization() {
|
||||||
let tool_call = ToolCall {
|
let tool_call = ToolCall {
|
||||||
id: "call-123".to_string(),
|
|
||||||
r#type: "function".to_string(),
|
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: "search".to_string(),
|
name: "search".to_string(),
|
||||||
arguments: r#"{"query": "rust programming"}"#.to_string(),
|
arguments: r#"{"query": "rust programming"}"#.to_string(),
|
||||||
@@ -40,13 +38,15 @@ fn test_tool_call_serialization() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&tool_call).unwrap();
|
let json = serde_json::to_string(&tool_call).unwrap();
|
||||||
assert!(json.contains("call-123"));
|
|
||||||
assert!(json.contains("search"));
|
assert!(json.contains("search"));
|
||||||
assert!(json.contains("rust programming"));
|
assert!(json.contains("rust programming"));
|
||||||
|
|
||||||
let parsed: ToolCall = serde_json::from_str(&json).unwrap();
|
let parsed: ToolCall = serde_json::from_str(&json).unwrap();
|
||||||
assert_eq!(parsed.id, "call-123");
|
|
||||||
assert_eq!(parsed.function.name, "search");
|
assert_eq!(parsed.function.name, "search");
|
||||||
|
assert_eq!(
|
||||||
|
parsed.function.arguments,
|
||||||
|
r#"{"query": "rust programming"}"#
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -32,6 +32,12 @@ pub trait ToolParser: Send + Sync {
|
|||||||
fn as_token_parser(&self) -> Option<&dyn TokenToolParser> {
|
fn as_token_parser(&self) -> Option<&dyn TokenToolParser> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get unstreamed tool call arguments
|
||||||
|
/// Returns tool call items for arguments that have been parsed but not yet streamed
|
||||||
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for partial JSON parsing
|
/// Trait for partial JSON parsing
|
||||||
|
|||||||
@@ -1,13 +1,8 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Parsed tool call from model output (OpenAI format)
|
/// Parsed tool call from model output
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
/// Unique identifier for the tool call
|
|
||||||
pub id: String,
|
|
||||||
/// Type of tool call (currently always "function")
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub r#type: String,
|
|
||||||
/// Function call details
|
/// Function call details
|
||||||
pub function: FunctionCall,
|
pub function: FunctionCall,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -181,7 +181,6 @@ fn test_chatml_template() {
|
|||||||
content: Some("Hi there!".to_string()),
|
content: Some("Hi there!".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
},
|
},
|
||||||
spec::ChatMessage::User {
|
spec::ChatMessage::User {
|
||||||
|
|||||||
@@ -68,7 +68,6 @@ mod tests {
|
|||||||
content: Some("Hi there".to_string()),
|
content: Some("Hi there".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
@@ -213,7 +212,6 @@ mod tests {
|
|||||||
content: Some("World".to_string()),
|
content: Some("World".to_string()),
|
||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
|
||||||
reasoning_content: None,
|
reasoning_content: None,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
Reference in New Issue
Block a user