[router] add grpc router pd mode for chat and generate (#11140)
This commit is contained in:
@@ -67,8 +67,8 @@ dependencies = [
|
||||
"uvicorn",
|
||||
"uvloop",
|
||||
"xgrammar==0.1.24",
|
||||
"grpcio==1.74.0", # keep it align with compile_proto.py
|
||||
"grpcio-tools==1.74.0" # keep it align with compile_proto.py
|
||||
"grpcio==1.75.1", # keep it align with compile_proto.py
|
||||
"grpcio-tools==1.75.1" # keep it align with compile_proto.py
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -19,7 +19,6 @@ import grpc
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOutput,
|
||||
@@ -111,6 +110,7 @@ class GrpcRequestManager:
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
bootstrap_server=None,
|
||||
):
|
||||
"""Initialize the gRPC request manager."""
|
||||
self.server_args = server_args
|
||||
@@ -147,8 +147,8 @@ class GrpcRequestManager:
|
||||
self.crash_dump_request_list = []
|
||||
self.crash_dump_performed = False
|
||||
|
||||
# Bootstrap server for disaggregation mode
|
||||
self.bootstrap_server = start_disagg_service(server_args)
|
||||
# Bootstrap server (passed from serve_grpc, not started here)
|
||||
self.bootstrap_server = bootstrap_server
|
||||
|
||||
logger.info(
|
||||
f"GrpcRequestManager initialized with ZMQ IPC: "
|
||||
@@ -157,7 +157,7 @@ class GrpcRequestManager:
|
||||
)
|
||||
if self.bootstrap_server:
|
||||
logger.info(
|
||||
f"Bootstrap server started for disaggregation mode: "
|
||||
f"Bootstrap server initialized for disaggregation mode: "
|
||||
f"{server_args.disaggregation_mode}"
|
||||
)
|
||||
|
||||
|
||||
@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
||||
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
|
||||
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
||||
from sglang.srt.managers.data_parallel_controller import (
|
||||
run_data_parallel_controller_process,
|
||||
)
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
token_ids_logprob=None,
|
||||
)
|
||||
|
||||
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
|
||||
health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||
health_request.bootstrap_room = 0
|
||||
|
||||
logger.info(f"Sending health check request to request manager...")
|
||||
|
||||
# Submit and wait for response
|
||||
@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
# Convert sampling params
|
||||
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
|
||||
|
||||
# Extract disaggregated params if present
|
||||
bootstrap_host = None
|
||||
bootstrap_port = None
|
||||
bootstrap_room = None
|
||||
if grpc_req.HasField("disaggregated_params"):
|
||||
bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
|
||||
bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
|
||||
bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
|
||||
|
||||
# Create request
|
||||
return TokenizedGenerateReqInput(
|
||||
rid=grpc_req.request_id,
|
||||
@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
|
||||
token_ids_logprob=(
|
||||
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
||||
),
|
||||
bootstrap_host=bootstrap_host,
|
||||
bootstrap_port=bootstrap_port,
|
||||
bootstrap_room=bootstrap_room,
|
||||
)
|
||||
|
||||
def _convert_embed_request(
|
||||
@@ -659,6 +677,16 @@ async def serve_grpc(
|
||||
):
|
||||
"""Start the standalone gRPC server with integrated scheduler."""
|
||||
|
||||
# Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
|
||||
# This ensures the bootstrap server is ready when prefill schedulers try to register
|
||||
bootstrap_server = None
|
||||
if server_args.disaggregation_mode == "prefill":
|
||||
bootstrap_server = start_disagg_service(server_args)
|
||||
if bootstrap_server:
|
||||
logger.info(
|
||||
f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
|
||||
)
|
||||
|
||||
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
|
||||
logger.info("Launching scheduler process(es)...")
|
||||
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
|
||||
@@ -682,9 +710,11 @@ async def serve_grpc(
|
||||
}
|
||||
|
||||
# Create request manager with the correct port args
|
||||
# Note: We pass None for bootstrap_server since it's already started above
|
||||
request_manager = GrpcRequestManager(
|
||||
server_args=server_args,
|
||||
port_args=port_args,
|
||||
bootstrap_server=bootstrap_server,
|
||||
)
|
||||
|
||||
# Create gRPC server
|
||||
@@ -764,79 +794,9 @@ def main():
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
|
||||
|
||||
# Server arguments
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
|
||||
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument("--model-path", type=str, required=True, help="Model path")
|
||||
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
|
||||
parser.add_argument("--context-length", type=int, help="Context length")
|
||||
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
|
||||
|
||||
# Runtime arguments
|
||||
parser.add_argument(
|
||||
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
|
||||
)
|
||||
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
|
||||
|
||||
# Logging
|
||||
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
|
||||
|
||||
# Disaggregation mode arguments
|
||||
parser.add_argument(
|
||||
"--disaggregation-mode",
|
||||
type=str,
|
||||
default="null",
|
||||
choices=["null", "prefill", "decode"],
|
||||
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-transfer-backend",
|
||||
type=str,
|
||||
default="mooncake",
|
||||
choices=["mooncake", "nixl", "ascend", "fake"],
|
||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-bootstrap-port",
|
||||
type=int,
|
||||
default=8998,
|
||||
help="Bootstrap server port on the prefill server. Default is 8998.",
|
||||
)
|
||||
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert to ServerArgs with gRPC host and port
|
||||
server_args = ServerArgs(
|
||||
model_path=args.model_path,
|
||||
tokenizer_path=args.tokenizer_path or args.model_path,
|
||||
context_length=args.context_length,
|
||||
tp_size=args.tp_size,
|
||||
dp_size=args.dp_size,
|
||||
max_running_requests=args.max_running_requests,
|
||||
max_total_tokens=args.max_total_tokens,
|
||||
max_prefill_tokens=args.max_prefill_tokens,
|
||||
attention_backend=args.attention_backend,
|
||||
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
|
||||
log_level=args.log_level,
|
||||
disaggregation_mode=args.disaggregation_mode,
|
||||
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
|
||||
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
|
||||
# Run server
|
||||
asyncio.run(
|
||||
|
||||
@@ -31,6 +31,18 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
/// Get the worker's connection mode (HTTP or gRPC)
|
||||
fn connection_mode(&self) -> ConnectionMode;
|
||||
|
||||
/// Get the bootstrap hostname for PD mode
|
||||
/// Returns cached hostname parsed from URL at construction time
|
||||
fn bootstrap_host(&self) -> &str {
|
||||
&self.metadata().bootstrap_host
|
||||
}
|
||||
|
||||
/// Get the bootstrap port for PD mode
|
||||
/// Returns cached port from WorkerType::Prefill
|
||||
fn bootstrap_port(&self) -> Option<u16> {
|
||||
self.metadata().bootstrap_port
|
||||
}
|
||||
|
||||
/// Check if the worker is currently healthy
|
||||
fn is_healthy(&self) -> bool;
|
||||
|
||||
@@ -147,21 +159,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
true
|
||||
}
|
||||
|
||||
// TODO: - Enhanced Worker Discovery
|
||||
// The Worker trait should handle async discovery of metadata from the worker itself
|
||||
// rather than having service discovery or other components query /get_server_info.
|
||||
// This keeps service discovery decoupled from worker-specific APIs.
|
||||
//
|
||||
// Proposed additions:
|
||||
// - async fn discover_metadata(&mut self) -> Result<(), Error>
|
||||
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
|
||||
// - async fn validate_configuration(&self) -> Result<(), Error>
|
||||
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
|
||||
// - Make worker creation async to allow metadata discovery during initialization
|
||||
//
|
||||
// This way service discovery just calls router.add_worker() and the worker
|
||||
// handles its own metadata discovery internally.
|
||||
|
||||
/// Get the model ID this worker serves
|
||||
fn model_id(&self) -> &str {
|
||||
self.metadata()
|
||||
@@ -325,6 +322,10 @@ pub struct WorkerMetadata {
|
||||
pub health_config: HealthConfig,
|
||||
/// API key
|
||||
pub api_key: Option<String>,
|
||||
/// Cached bootstrap hostname (parsed from URL at construction time)
|
||||
pub bootstrap_host: String,
|
||||
/// Cached bootstrap port (from WorkerType::Prefill)
|
||||
pub bootstrap_port: Option<u16>,
|
||||
}
|
||||
|
||||
/// Basic worker implementation
|
||||
|
||||
@@ -96,12 +96,29 @@ impl BasicWorkerBuilder {
|
||||
|
||||
/// Build the BasicWorker instance
|
||||
pub fn build(self) -> BasicWorker {
|
||||
use std::borrow::Cow;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicUsize},
|
||||
Arc,
|
||||
};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
|
||||
let url_to_parse = if self.url.contains("://") {
|
||||
Cow::from(&self.url)
|
||||
} else {
|
||||
Cow::from(format!("http://{}", self.url))
|
||||
};
|
||||
|
||||
let bootstrap_host = match url::Url::parse(&url_to_parse) {
|
||||
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
|
||||
Err(_) => "localhost".to_string(),
|
||||
};
|
||||
|
||||
let bootstrap_port = match self.worker_type {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let metadata = WorkerMetadata {
|
||||
url: self.url.clone(),
|
||||
api_key: self.api_key,
|
||||
@@ -109,6 +126,8 @@ impl BasicWorkerBuilder {
|
||||
connection_mode: self.connection_mode,
|
||||
labels: self.labels,
|
||||
health_config: self.health_config,
|
||||
bootstrap_host,
|
||||
bootstrap_port,
|
||||
};
|
||||
|
||||
let grpc_client = Arc::new(RwLock::new(
|
||||
|
||||
@@ -342,6 +342,12 @@ impl SglangSchedulerClient {
|
||||
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
|
||||
}
|
||||
|
||||
// Handle n with conversion
|
||||
if let Some(n) = p.n {
|
||||
sampling.n = i32::try_from(n)
|
||||
.map_err(|_| "n must fit into a 32-bit signed integer".to_string())?;
|
||||
}
|
||||
|
||||
// Handle constraints (exactly one allowed)
|
||||
sampling.constraint = Self::build_single_constraint_from_plain(p)?;
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_json::{to_value, Map, Number, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// Default model value when not specified
|
||||
fn default_model() -> String {
|
||||
"unknown".to_string()
|
||||
}
|
||||
|
||||
// # Protocol Specifications
|
||||
//
|
||||
// This module contains all protocol definitions for OpenAI and SGLang APIs.
|
||||
@@ -169,6 +174,7 @@ pub struct ChatCompletionRequest {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
|
||||
/// ID of the model to use
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||
|
||||
@@ -1,4 +1,16 @@
|
||||
//! gRPC router implementations
|
||||
|
||||
use crate::grpc_client::proto;
|
||||
use crate::protocols::spec::StringOrArray;
|
||||
|
||||
pub mod pd_router;
|
||||
pub mod router;
|
||||
pub mod utils;
|
||||
|
||||
/// Processed chat messages ready for gRPC generation
|
||||
#[derive(Debug)]
|
||||
pub struct ProcessedMessages {
|
||||
pub text: String,
|
||||
pub multimodal_inputs: Option<proto::MultimodalInputs>,
|
||||
pub stop_sequences: Option<StringOrArray>,
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
843
sgl-router/src/routers/grpc/utils.rs
Normal file
843
sgl-router/src/routers/grpc/utils.rs
Normal file
@@ -0,0 +1,843 @@
|
||||
//! Shared utilities for gRPC routers
|
||||
|
||||
use super::ProcessedMessages;
|
||||
use crate::core::Worker;
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, FunctionCallResponse, StringOrArray, Tool, ToolCall,
|
||||
ToolChoice, ToolChoiceValue,
|
||||
};
|
||||
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tokenizer::HuggingFaceTokenizer;
|
||||
pub use crate::tokenizer::StopSequenceDecoder;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures::StreamExt;
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tonic::codec::Streaming;
|
||||
use tracing::{debug, error};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Get gRPC client from worker, returning appropriate error response on failure
|
||||
pub async fn get_grpc_client_from_worker(
|
||||
worker: &Arc<dyn Worker>,
|
||||
) -> Result<SglangSchedulerClient, Response> {
|
||||
let client_arc = worker
|
||||
.get_grpc_client()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to get gRPC client from worker: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to get gRPC client: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
error!("Selected worker is not a gRPC worker");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Selected worker is not configured for gRPC",
|
||||
)
|
||||
.into_response()
|
||||
})?;
|
||||
|
||||
let client = client_arc.lock().await.clone();
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Process tool call arguments in messages
|
||||
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
|
||||
pub fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
|
||||
for msg in messages {
|
||||
// Early return if not assistant message
|
||||
let role = msg.get("role").and_then(|v| v.as_str());
|
||||
if role != Some("assistant") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Early return if no tool_calls
|
||||
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Process each tool call's arguments
|
||||
for call in tool_calls {
|
||||
let Some(function) = call.get_mut("function") else {
|
||||
continue;
|
||||
};
|
||||
let Some(args) = function.get_mut("arguments") else {
|
||||
continue;
|
||||
};
|
||||
let Some(args_str) = args.as_str() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
// Parse JSON string to object (like Python json.loads)
|
||||
match serde_json::from_str::<Value>(args_str) {
|
||||
Ok(parsed) => *args = parsed,
|
||||
Err(e) => {
|
||||
return Err(format!(
|
||||
"Failed to parse tool call arguments as JSON: '{}'. Error: {}",
|
||||
args_str, e
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process messages based on content format for ANY message type
|
||||
pub fn process_content_format(
|
||||
messages: &[ChatMessage],
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) -> Result<Vec<Value>, String> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
let mut message_json = serde_json::to_value(message)
|
||||
.map_err(|e| format!("Failed to serialize message: {}", e))?;
|
||||
|
||||
if let Some(obj) = message_json.as_object_mut() {
|
||||
if let Some(content_value) = obj.get_mut("content") {
|
||||
transform_content_field(content_value, content_format);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(message_json)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Transform a single content field based on content format
|
||||
pub fn transform_content_field(
|
||||
content_value: &mut Value,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) {
|
||||
let Some(content_array) = content_value.as_array() else {
|
||||
return; // Not multimodal, keep as-is
|
||||
};
|
||||
|
||||
match content_format {
|
||||
ChatTemplateContentFormat::String => {
|
||||
// Extract and join text parts only
|
||||
let text_parts: Vec<String> = content_array
|
||||
.iter()
|
||||
.filter_map(|part| {
|
||||
part.as_object()?
|
||||
.get("type")?
|
||||
.as_str()
|
||||
.filter(|&t| t == "text")
|
||||
.and_then(|_| part.as_object()?.get("text")?.as_str())
|
||||
.map(String::from)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if !text_parts.is_empty() {
|
||||
*content_value = Value::String(text_parts.join(" "));
|
||||
}
|
||||
}
|
||||
ChatTemplateContentFormat::OpenAI => {
|
||||
// Replace media URLs with simple type placeholders
|
||||
let processed_parts: Vec<Value> = content_array
|
||||
.iter()
|
||||
.map(|part| {
|
||||
part.as_object()
|
||||
.and_then(|obj| obj.get("type")?.as_str())
|
||||
.and_then(|type_str| match type_str {
|
||||
"image_url" => Some(json!({"type": "image"})),
|
||||
"video_url" => Some(json!({"type": "video"})),
|
||||
"audio_url" => Some(json!({"type": "audio"})),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or_else(|| part.clone())
|
||||
})
|
||||
.collect();
|
||||
|
||||
*content_value = Value::Array(processed_parts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate tool constraints for structured generation
|
||||
/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
|
||||
pub fn generate_tool_constraints(
|
||||
tools: &[Tool],
|
||||
tool_choice: &Option<ToolChoice>,
|
||||
_model: &str,
|
||||
) -> Option<(String, String)> {
|
||||
let choice = tool_choice.as_ref()?;
|
||||
|
||||
match choice {
|
||||
// Specific function: Return parameters schema directly
|
||||
// tools should already be filtered to contain only the specific function
|
||||
ToolChoice::Function { .. } => {
|
||||
if tools.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let tool = &tools[0];
|
||||
|
||||
// Return the tool's parameters schema directly (not wrapped in array)
|
||||
let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
|
||||
Some(("json_schema".to_string(), params_schema))
|
||||
}
|
||||
|
||||
// Required: Array of tool calls with minItems: 1
|
||||
ToolChoice::Value(ToolChoiceValue::Required) => {
|
||||
let schema = build_required_array_schema(tools)?;
|
||||
Some(("json_schema".to_string(), schema))
|
||||
}
|
||||
|
||||
// AllowedTools with required mode: tools are already filtered
|
||||
ToolChoice::AllowedTools { mode, .. } => {
|
||||
if mode == "required" {
|
||||
if tools.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let schema = build_required_array_schema(tools)?;
|
||||
Some(("json_schema".to_string(), schema))
|
||||
} else {
|
||||
// "auto" mode - no constraint needed
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// "auto" or "none" - no constraint
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build JSON schema for required tool calls (array with minItems: 1)
|
||||
/// Includes $defs consolidation from all tools (matching Python's behavior)
|
||||
pub fn build_required_array_schema(tools: &[Tool]) -> Option<String> {
|
||||
// Build anyOf schemas for each tool
|
||||
let mut any_of_schemas = Vec::new();
|
||||
for tool in tools {
|
||||
let tool_schema = json!({
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"enum": [tool.function.name]
|
||||
},
|
||||
"parameters": tool.function.parameters
|
||||
},
|
||||
"required": ["name", "parameters"]
|
||||
});
|
||||
any_of_schemas.push(tool_schema);
|
||||
}
|
||||
|
||||
// Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
|
||||
let mut all_defs: HashMap<String, Value> = HashMap::new();
|
||||
for tool in tools {
|
||||
if let Value::Object(params) = &tool.function.parameters {
|
||||
if let Some(Value::Object(defs)) = params.get("$defs") {
|
||||
for (def_name, def_schema) in defs {
|
||||
if let Some(existing) = all_defs.get(def_name) {
|
||||
// Check for conflicts
|
||||
if existing != def_schema {
|
||||
error!(
|
||||
"Tool definition '{}' has multiple schemas, which is not supported",
|
||||
def_name
|
||||
);
|
||||
return None;
|
||||
}
|
||||
} else {
|
||||
all_defs.insert(def_name.clone(), def_schema.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the full array schema
|
||||
let mut array_schema = json!({
|
||||
"type": "array",
|
||||
"minItems": 1,
|
||||
"items": {
|
||||
"type": "object",
|
||||
"anyOf": any_of_schemas
|
||||
}
|
||||
});
|
||||
|
||||
// Add $defs if any were found (matching Python's behavior)
|
||||
if !all_defs.is_empty() {
|
||||
if let Value::Object(ref mut schema_obj) = array_schema {
|
||||
let defs_value = Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
|
||||
schema_obj.insert("$defs".to_string(), defs_value);
|
||||
}
|
||||
}
|
||||
|
||||
serde_json::to_string(&array_schema).ok()
|
||||
}
|
||||
|
||||
/// Filter tools based on tool_choice (shared by both routers)
|
||||
/// Returns a reference to the original body if no filtering needed,
|
||||
/// otherwise returns a cloned and filtered body
|
||||
pub fn filter_tools_for_request(
|
||||
body: &ChatCompletionRequest,
|
||||
) -> std::borrow::Cow<'_, ChatCompletionRequest> {
|
||||
match &body.tool_choice {
|
||||
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
|
||||
let mut filtered_body = body.clone();
|
||||
let all_tools = filtered_body.tools.as_ref().unwrap();
|
||||
let allowed_names: std::collections::HashSet<&str> =
|
||||
allowed.iter().map(|t| t.name.as_str()).collect();
|
||||
let filtered_tools: Vec<Tool> = all_tools
|
||||
.iter()
|
||||
.filter(|t| allowed_names.contains(t.function.name.as_str()))
|
||||
.cloned()
|
||||
.collect();
|
||||
filtered_body.tools = Some(filtered_tools);
|
||||
std::borrow::Cow::Owned(filtered_body)
|
||||
}
|
||||
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
|
||||
let mut filtered_body = body.clone();
|
||||
let all_tools = filtered_body.tools.as_ref().unwrap();
|
||||
let filtered_tools: Vec<Tool> = all_tools
|
||||
.iter()
|
||||
.filter(|t| t.function.name == function.name)
|
||||
.cloned()
|
||||
.collect();
|
||||
filtered_body.tools = Some(filtered_tools);
|
||||
std::borrow::Cow::Owned(filtered_body)
|
||||
}
|
||||
_ => std::borrow::Cow::Borrowed(body), // No filtering needed, use original
|
||||
}
|
||||
}
|
||||
|
||||
/// Process chat messages and apply template (shared by both routers)
|
||||
/// Requires HuggingFace tokenizer with chat template support
|
||||
pub fn process_chat_messages(
|
||||
request: &ChatCompletionRequest,
|
||||
tokenizer: &dyn Tokenizer,
|
||||
) -> Result<ProcessedMessages, String> {
|
||||
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
|
||||
let formatted_text = if let Some(hf_tokenizer) =
|
||||
tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>()
|
||||
{
|
||||
// Get content format and transform messages accordingly
|
||||
let content_format = hf_tokenizer.chat_template_content_format();
|
||||
let mut transformed_messages = process_content_format(&request.messages, content_format)?;
|
||||
|
||||
// Process tool call arguments in assistant messages
|
||||
process_tool_call_arguments(&mut transformed_messages)?;
|
||||
|
||||
// Convert tools to JSON values for template processing
|
||||
let tools_json: Option<Vec<Value>> = request
|
||||
.tools
|
||||
.as_ref()
|
||||
.map(|tools| {
|
||||
tools
|
||||
.iter()
|
||||
.map(serde_json::to_value)
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
})
|
||||
.transpose()
|
||||
.map_err(|e| format!("Failed to serialize tools: {}", e))?;
|
||||
|
||||
// Build template kwargs, merging reasoning_effort if present
|
||||
let mut combined_template_kwargs = HashMap::new();
|
||||
|
||||
// Add reasoning_effort if present (like Python does)
|
||||
if let Some(reasoning_effort) = &request.reasoning_effort {
|
||||
combined_template_kwargs.insert(
|
||||
"reasoning_effort".to_string(),
|
||||
Value::String(reasoning_effort.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
// Add any additional template kwargs from request
|
||||
if let Some(template_kwargs) = &request.chat_template_kwargs {
|
||||
for (key, value) in template_kwargs {
|
||||
combined_template_kwargs.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let final_template_kwargs = if combined_template_kwargs.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(&combined_template_kwargs)
|
||||
};
|
||||
|
||||
let params = ChatTemplateParams {
|
||||
add_generation_prompt: true,
|
||||
continue_final_message: request.continue_final_message,
|
||||
tools: tools_json.as_deref(),
|
||||
template_kwargs: final_template_kwargs,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Handle assistant prefix for continue_final_message
|
||||
let assistant_prefix = if request.continue_final_message
|
||||
&& !transformed_messages.is_empty()
|
||||
&& transformed_messages
|
||||
.last()
|
||||
.and_then(|msg| msg.get("role"))
|
||||
.and_then(|v| v.as_str())
|
||||
== Some("assistant")
|
||||
{
|
||||
// Pop the last message to handle it separately
|
||||
let last_msg = transformed_messages.pop().unwrap();
|
||||
last_msg
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Apply chat template with the (now possibly shorter) list of messages
|
||||
let rendered = hf_tokenizer
|
||||
.apply_chat_template(&transformed_messages, params)
|
||||
.map_err(|e| format!("Failed to apply chat template: {}", e))?;
|
||||
|
||||
// Append assistant prefix if we have one
|
||||
if let Some(prefix) = assistant_prefix {
|
||||
format!("{}{}", rendered, prefix)
|
||||
} else {
|
||||
rendered
|
||||
}
|
||||
} else {
|
||||
return Err(
|
||||
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
|
||||
);
|
||||
};
|
||||
|
||||
// Placeholder for multimodal inputs
|
||||
let multimodal_inputs = None;
|
||||
|
||||
Ok(ProcessedMessages {
|
||||
text: formatted_text,
|
||||
multimodal_inputs,
|
||||
stop_sequences: request.stop.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Error response helpers (shared between regular and PD routers)
|
||||
pub fn internal_error_static(msg: &'static str) -> Response {
|
||||
error!("{}", msg);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
|
||||
}
|
||||
|
||||
pub fn internal_error_message(message: String) -> Response {
|
||||
error!("{}", message);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
|
||||
}
|
||||
|
||||
/// Create a StopSequenceDecoder from stop parameters
|
||||
pub fn create_stop_decoder(
|
||||
tokenizer: &Arc<dyn Tokenizer>,
|
||||
stop: Option<&StringOrArray>,
|
||||
stop_token_ids: Option<&Vec<u32>>,
|
||||
skip_special_tokens: bool,
|
||||
no_stop_trim: bool,
|
||||
) -> StopSequenceDecoder {
|
||||
use crate::tokenizer::stop::StopSequenceDecoderBuilder;
|
||||
|
||||
// Extract stop sequences
|
||||
let stop_sequences: Vec<String> = match stop {
|
||||
Some(StringOrArray::String(s)) => vec![s.clone()],
|
||||
Some(StringOrArray::Array(arr)) => arr.clone(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
// Build stop sequence decoder
|
||||
let mut builder =
|
||||
StopSequenceDecoderBuilder::new(tokenizer.clone()).skip_special_tokens(skip_special_tokens);
|
||||
|
||||
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
|
||||
for seq in stop_sequences {
|
||||
builder = if no_stop_trim {
|
||||
builder.visible_stop_sequence(seq)
|
||||
} else {
|
||||
builder.stop_sequence(seq)
|
||||
};
|
||||
}
|
||||
|
||||
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
|
||||
if let Some(token_ids) = stop_token_ids {
|
||||
for &token_id in token_ids {
|
||||
builder = if no_stop_trim {
|
||||
builder.visible_stop_token(token_id)
|
||||
} else {
|
||||
builder.stop_token(token_id)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
builder.build()
|
||||
}
|
||||
|
||||
/// Parse tool calls from JSON schema constrained response
|
||||
pub fn parse_json_schema_response(
|
||||
processed_text: &str,
|
||||
tool_choice: &Option<ToolChoice>,
|
||||
) -> (Option<Vec<ToolCall>>, String) {
|
||||
match tool_choice {
|
||||
Some(ToolChoice::Function { function, .. }) => {
|
||||
// Specific function: Parse parameters directly
|
||||
match serde_json::from_str::<Value>(processed_text) {
|
||||
Ok(params) => {
|
||||
let tool_call = ToolCall {
|
||||
id: format!("call_{}", uuid::Uuid::new_v4()),
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name: function.name.clone(),
|
||||
arguments: Some(
|
||||
serde_json::to_string(¶ms).unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
};
|
||||
(Some(vec![tool_call]), String::new())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse specific function parameters: {}", e);
|
||||
(None, processed_text.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(ToolChoice::Value(ToolChoiceValue::Required))
|
||||
| Some(ToolChoice::AllowedTools { .. }) => {
|
||||
// Required mode: Parse array of tool calls
|
||||
match serde_json::from_str::<Vec<Value>>(processed_text) {
|
||||
Ok(parsed_array) => {
|
||||
let spec_tool_calls: Vec<ToolCall> = parsed_array
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, item)| {
|
||||
let obj = item.as_object()?;
|
||||
let name = obj.get("name")?.as_str()?.to_string();
|
||||
let parameters = obj.get("parameters")?;
|
||||
|
||||
Some(ToolCall {
|
||||
id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
|
||||
tool_type: "function".to_string(),
|
||||
function: FunctionCallResponse {
|
||||
name,
|
||||
arguments: Some(
|
||||
serde_json::to_string(parameters)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
),
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
(Some(spec_tool_calls), String::new())
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse required tool call array: {}", e);
|
||||
(None, processed_text.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => (None, processed_text.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect responses from a gRPC stream
|
||||
///
|
||||
/// This helper processes a gRPC GenerateResponse stream and collects all Complete responses.
|
||||
/// Used by both regular and PD routers for non-streaming requests.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `stream` - The gRPC response stream to consume
|
||||
/// * `worker_name` - Name for logging (e.g., "Prefill", "Decode", "Worker")
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
|
||||
/// * `Err(Response)` - Error response if the stream fails or returns an error
|
||||
pub async fn collect_stream_responses(
|
||||
mut stream: Streaming<proto::GenerateResponse>,
|
||||
worker_name: &str,
|
||||
) -> Result<Vec<proto::GenerateComplete>, Response> {
|
||||
use proto::generate_response::Response::*;
|
||||
|
||||
let mut all_responses = Vec::new();
|
||||
|
||||
while let Some(response) = stream.next().await {
|
||||
match response {
|
||||
Ok(gen_response) => {
|
||||
match gen_response.response {
|
||||
Some(Complete(complete)) => {
|
||||
debug!(
|
||||
"{} completed: prompt_tokens={}, completion_tokens={}, finish_reason={}",
|
||||
worker_name, complete.prompt_tokens, complete.completion_tokens, complete.finish_reason
|
||||
);
|
||||
all_responses.push(complete);
|
||||
}
|
||||
Some(Error(err)) => {
|
||||
error!("{} error: {}", worker_name, err.message);
|
||||
return Err(internal_error_message(format!(
|
||||
"{} generation failed: {}",
|
||||
worker_name, err.message
|
||||
)));
|
||||
}
|
||||
Some(Chunk(chunk)) => {
|
||||
debug!("{} chunk: {} tokens", worker_name, chunk.token_ids.len());
|
||||
}
|
||||
None => {
|
||||
debug!("{}: empty response", worker_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("{} stream error: {:?}", worker_name, e);
|
||||
return Err(internal_error_message(format!(
|
||||
"{} stream failed: {}",
|
||||
worker_name, e
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!("{} stream closed", worker_name);
|
||||
Ok(all_responses)
|
||||
}
|
||||
|
||||
/// Count the number of tool calls in the request message history
|
||||
/// This is used for KimiK2 format which needs globally unique indices
|
||||
pub fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
|
||||
request
|
||||
.messages
|
||||
.iter()
|
||||
.filter_map(|msg| {
|
||||
if let ChatMessage::Assistant { tool_calls, .. } = msg {
|
||||
tool_calls.as_ref().map(|calls| calls.len())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Generate a tool call ID based on model format
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model` - Model name to determine ID format
|
||||
/// * `tool_name` - Name of the tool being called
|
||||
/// * `tool_index` - Index of this tool call within the current message
|
||||
/// * `history_count` - Number of tool calls in previous messages
|
||||
///
|
||||
/// # Returns
|
||||
/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
|
||||
pub fn generate_tool_call_id(
|
||||
model: &str,
|
||||
tool_name: &str,
|
||||
tool_index: usize,
|
||||
history_count: usize,
|
||||
) -> String {
|
||||
if model.to_lowercase().contains("kimi") {
|
||||
// KimiK2 format: functions.{name}:{global_index}
|
||||
format!("functions.{}:{}", tool_name, history_count + tool_index)
|
||||
} else {
|
||||
// Standard OpenAI format: call_{24-char-uuid}
|
||||
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Unit tests for `process_content_format`: transforming multimodal chat
    //! messages according to the chat template's expected content format
    //! (flattened string vs. OpenAI-style parts with media placeholders).
    use super::*;
    use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent};
    use crate::tokenizer::chat_template::ChatTemplateContentFormat;
    use serde_json::json;

    // String format: multimodal parts are flattened to the text parts only,
    // joined with a single space; media parts are dropped.
    #[test]
    fn test_transform_messages_string_format() {
        let messages = vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Parts(vec![
                ContentPart::Text {
                    text: "Hello".to_string(),
                },
                ContentPart::ImageUrl {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: None,
                    },
                },
                ContentPart::Text {
                    text: "World".to_string(),
                },
            ]),
            name: None,
        }];

        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();

        assert_eq!(result.len(), 1);
        let transformed_message = &result[0];

        // Should flatten multimodal content to text only
        assert_eq!(
            transformed_message["content"].as_str().unwrap(),
            "Hello World"
        );
        assert_eq!(transformed_message["role"].as_str().unwrap(), "user");
    }

    // OpenAI format: text parts pass through; media URLs are replaced with
    // bare `{"type": "image"}` placeholders (URL and detail stripped).
    #[test]
    fn test_transform_messages_openai_format() {
        let messages = vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Parts(vec![
                ContentPart::Text {
                    text: "Describe this image:".to_string(),
                },
                ContentPart::ImageUrl {
                    image_url: ImageUrl {
                        url: "https://example.com/image.jpg".to_string(),
                        detail: Some("high".to_string()),
                    },
                },
            ]),
            name: None,
        }];

        let result = process_content_format(&messages, ChatTemplateContentFormat::OpenAI).unwrap();

        assert_eq!(result.len(), 1);
        let transformed_message = &result[0];

        // Should replace media URLs with simple type placeholders
        let content_array = transformed_message["content"].as_array().unwrap();
        assert_eq!(content_array.len(), 2);

        // Text part should remain unchanged
        assert_eq!(content_array[0]["type"], "text");
        assert_eq!(content_array[0]["text"], "Describe this image:");

        // Image part should be replaced with simple type placeholder
        assert_eq!(content_array[1], json!({"type": "image"}));
    }

    // Plain-string content is not multimodal and must pass through untouched.
    #[test]
    fn test_transform_messages_simple_string_content() {
        let messages = vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Simple text message".to_string()),
            name: None,
        }];

        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();

        assert_eq!(result.len(), 1);
        let transformed_message = &result[0];

        // Simple string content should remain unchanged
        assert_eq!(
            transformed_message["content"].as_str().unwrap(),
            "Simple text message"
        );
    }

    // Mixed roles: system messages pass through; user multimodal content is
    // flattened per the String format rules; ordering is preserved.
    #[test]
    fn test_transform_messages_multiple_messages() {
        let messages = vec![
            ChatMessage::System {
                role: "system".to_string(),
                content: "System prompt".to_string(),
                name: None,
            },
            ChatMessage::User {
                role: "user".to_string(),
                content: UserMessageContent::Parts(vec![
                    ContentPart::Text {
                        text: "User message".to_string(),
                    },
                    ContentPart::ImageUrl {
                        image_url: ImageUrl {
                            url: "https://example.com/image.jpg".to_string(),
                            detail: None,
                        },
                    },
                ]),
                name: None,
            },
        ];

        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();

        assert_eq!(result.len(), 2);

        // System message should remain unchanged
        assert_eq!(result[0]["role"].as_str().unwrap(), "system");
        assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt");

        // User message should be flattened to text only
        assert_eq!(result[1]["role"].as_str().unwrap(), "user");
        assert_eq!(result[1]["content"].as_str().unwrap(), "User message");
    }

    // Edge case: a parts list with no text parts at all. Flattening would
    // produce an empty string, so the original array is kept instead.
    #[test]
    fn test_transform_messages_empty_text_parts() {
        let messages = vec![ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
                image_url: ImageUrl {
                    url: "https://example.com/image.jpg".to_string(),
                    detail: None,
                },
            }]),
            name: None,
        }];

        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();

        assert_eq!(result.len(), 1);
        let transformed_message = &result[0];

        // Should keep original multimodal content when no text parts exist
        assert!(transformed_message["content"].is_array());
    }

    // Same input run through both formats: String flattens the multimodal
    // message; OpenAI keeps parts but swaps the image for a placeholder.
    #[test]
    fn test_transform_messages_mixed_content_types() {
        let messages = vec![
            ChatMessage::User {
                role: "user".to_string(),
                content: UserMessageContent::Text("Plain text".to_string()),
                name: None,
            },
            ChatMessage::User {
                role: "user".to_string(),
                content: UserMessageContent::Parts(vec![
                    ContentPart::Text {
                        text: "With image".to_string(),
                    },
                    ContentPart::ImageUrl {
                        image_url: ImageUrl {
                            url: "https://example.com/image.jpg".to_string(),
                            detail: Some("low".to_string()),
                        },
                    },
                ]),
                name: None,
            },
        ];

        let result_string =
            process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();

        assert_eq!(result_string.len(), 2);
        assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
        assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");

        let result_openai =
            process_content_format(&messages, ChatTemplateContentFormat::OpenAI).unwrap();

        assert_eq!(result_openai.len(), 2);
        assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");

        let content_array = result_openai[1]["content"].as_array().unwrap();
        assert_eq!(content_array.len(), 2);
        assert_eq!(content_array[0]["type"], "text");
        assert_eq!(content_array[1], json!({"type": "image"}));
    }
}
|
||||
Reference in New Issue
Block a user