refactor: Move grpc/client.rs to grpc_client/sglang_scheduler.rs (#10924)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use super::{CircuitBreaker, WorkerError, WorkerResult};
|
||||
use crate::core::CircuitState;
|
||||
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use async_trait::async_trait;
|
||||
use futures;
|
||||
|
||||
@@ -2,7 +2,7 @@ use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
|
||||
use super::worker::{
|
||||
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
|
||||
};
|
||||
use crate::grpc::client::SglangSchedulerClient;
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Builder for creating BasicWorker instances with fluent API
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
//! gRPC client module for communicating with SGLang scheduler
|
||||
//!
|
||||
//! This module provides a gRPC client implementation for the SGLang router.
|
||||
|
||||
pub mod client;
|
||||
|
||||
// Re-export the client
|
||||
pub use client::{proto, SglangSchedulerClient};
|
||||
3
sgl-router/src/grpc_client/mod.rs
Normal file
3
sgl-router/src/grpc_client/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod sglang_scheduler;
|
||||
|
||||
pub use sglang_scheduler::{proto, SglangSchedulerClient};
|
||||
@@ -2,6 +2,8 @@ use std::time::Duration;
|
||||
use tonic::{transport::Channel, Request};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat};
|
||||
|
||||
// Include the generated protobuf code
|
||||
pub mod proto {
|
||||
tonic::include_proto!("sglang.grpc.scheduler");
|
||||
@@ -75,6 +77,142 @@ impl SglangSchedulerClient {
|
||||
self.client.abort(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build a single SGLang GenerateRequest from OpenAI ChatCompletionRequest
|
||||
pub fn build_generate_request(
|
||||
&self,
|
||||
request_id: String,
|
||||
body: &ChatCompletionRequest,
|
||||
processed_text: String,
|
||||
token_ids: Vec<i32>,
|
||||
multimodal_inputs: Option<proto::MultimodalInputs>,
|
||||
tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value)
|
||||
) -> Result<proto::GenerateRequest, String> {
|
||||
// Build sampling params
|
||||
let sampling_params = self.build_grpc_sampling_params(body, tool_call_constraint)?;
|
||||
|
||||
let grpc_request = proto::GenerateRequest {
|
||||
request_id,
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: processed_text,
|
||||
input_ids: token_ids,
|
||||
}),
|
||||
mm_inputs: multimodal_inputs,
|
||||
sampling_params: Some(sampling_params),
|
||||
return_logprob: body.logprobs,
|
||||
logprob_start_len: -1,
|
||||
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
|
||||
return_hidden_states: body.return_hidden_states,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Ok(grpc_request)
|
||||
}
|
||||
|
||||
/// Build gRPC SamplingParams from OpenAI request
|
||||
fn build_grpc_sampling_params(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
tool_call_constraint: Option<(String, String)>,
|
||||
) -> Result<proto::SamplingParams, String> {
|
||||
let stop_sequences = self.extract_stop_strings(request);
|
||||
|
||||
// Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
|
||||
// If neither is specified, use None to let the backend decide the default
|
||||
#[allow(deprecated)]
|
||||
let max_new_tokens = request
|
||||
.max_completion_tokens
|
||||
.or(request.max_tokens)
|
||||
.map(|v| v as i32);
|
||||
|
||||
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
|
||||
let skip_special_tokens = if request.tools.is_some() {
|
||||
match &request.tool_choice {
|
||||
Some(crate::protocols::spec::ToolChoice::Value(
|
||||
crate::protocols::spec::ToolChoiceValue::None,
|
||||
)) => request.skip_special_tokens,
|
||||
Some(_) => false, // tool_choice is not "none"
|
||||
None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present
|
||||
}
|
||||
} else {
|
||||
request.skip_special_tokens
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
Ok(proto::SamplingParams {
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
top_p: request.top_p.unwrap_or(1.0),
|
||||
top_k: request.top_k.unwrap_or(-1),
|
||||
min_p: request.min_p.unwrap_or(0.0),
|
||||
frequency_penalty: request.frequency_penalty.unwrap_or(0.0),
|
||||
presence_penalty: request.presence_penalty.unwrap_or(0.0),
|
||||
repetition_penalty: request.repetition_penalty.unwrap_or(1.0),
|
||||
max_new_tokens,
|
||||
stop: stop_sequences,
|
||||
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
||||
skip_special_tokens,
|
||||
n: request.n.unwrap_or(1) as i32,
|
||||
constraint: self.build_constraint(request, tool_call_constraint)?,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract stop strings from request
|
||||
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
|
||||
match &request.stop {
|
||||
Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()],
|
||||
Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(),
|
||||
None => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Build constraint for structured generation
|
||||
fn build_constraint(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
tool_call_constraint: Option<(String, String)>,
|
||||
) -> Result<Option<proto::sampling_params::Constraint>, String> {
|
||||
let mut constraints = Vec::new();
|
||||
|
||||
if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format {
|
||||
let schema_str = serde_json::to_string(&json_schema.schema)
|
||||
.map_err(|e| format!("Failed to serialize JSON schema: {}", e))?;
|
||||
constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str));
|
||||
}
|
||||
|
||||
if let Some(ebnf) = &request.ebnf {
|
||||
constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
|
||||
ebnf.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(regex) = &request.regex {
|
||||
constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
|
||||
}
|
||||
|
||||
// Handle tool call constraint
|
||||
if let Some((constraint_type, constraint_value)) = tool_call_constraint {
|
||||
if !constraints.is_empty() {
|
||||
return Err("Constrained decoding is not compatible with tool calls.".to_string());
|
||||
}
|
||||
let tool_constraint = match constraint_type.as_str() {
|
||||
"structural_tag" => {
|
||||
proto::sampling_params::Constraint::StructuralTag(constraint_value)
|
||||
}
|
||||
"json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value),
|
||||
"ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value),
|
||||
"regex" => proto::sampling_params::Constraint::Regex(constraint_value),
|
||||
_ => return Err(format!("Unknown constraint type: {}", constraint_type)),
|
||||
};
|
||||
constraints.push(tool_constraint);
|
||||
}
|
||||
|
||||
match constraints.len() {
|
||||
0 => Ok(None),
|
||||
1 => Ok(constraints.pop()),
|
||||
_ => Err("Multiple constraints are not allowed.".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -230,22 +368,16 @@ mod tests {
|
||||
fn test_generate_stream_chunk() {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_id: 1234,
|
||||
text: " world".to_string(),
|
||||
prompt_tokens: 5,
|
||||
completion_tokens: 2,
|
||||
cached_tokens: 3,
|
||||
generation_time: 0.025,
|
||||
queue_time: 10,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(chunk.token_id, 1234);
|
||||
assert_eq!(chunk.text, " world");
|
||||
assert_eq!(chunk.prompt_tokens, 5);
|
||||
assert_eq!(chunk.completion_tokens, 2);
|
||||
assert_eq!(chunk.cached_tokens, 3);
|
||||
assert_eq!(chunk.generation_time, 0.025);
|
||||
assert_eq!(chunk.queue_time, 10);
|
||||
}
|
||||
|
||||
// TODO: ModelInfo not in current proto - skip test
|
||||
@@ -6,7 +6,7 @@ use std::collections::HashMap;
|
||||
pub mod core;
|
||||
pub mod data_connector;
|
||||
#[cfg(feature = "grpc-client")]
|
||||
pub mod grpc;
|
||||
pub mod grpc_client;
|
||||
pub mod mcp;
|
||||
pub mod metrics;
|
||||
pub mod middleware;
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
|
||||
@@ -17,10 +17,10 @@ use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::grpc::{proto, SglangSchedulerClient};
|
||||
use crate::grpc_client::{proto, SglangSchedulerClient};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat, StringOrArray};
|
||||
use crate::protocols::spec::{ChatCompletionRequest, StringOrArray};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
@@ -247,40 +247,31 @@ impl GrpcRouter {
|
||||
None
|
||||
};
|
||||
|
||||
// Step 6: Build SamplingParams for gRPC
|
||||
let sampling_params = match self.build_grpc_sampling_params(body, tool_call_constraint) {
|
||||
Ok(params) => params,
|
||||
// Step 6: Build the base gRPC request
|
||||
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let base_request = match client.build_generate_request(
|
||||
request_id,
|
||||
body,
|
||||
processed_messages.text.clone(),
|
||||
token_ids.into_iter().map(|id| id as i32).collect(),
|
||||
processed_messages.multimodal_inputs,
|
||||
tool_call_constraint, // Pass the full tuple (type, value)
|
||||
) {
|
||||
Ok(request) => request,
|
||||
Err(e) => {
|
||||
error!("Failed to build sampling parameters: {}", e);
|
||||
error!("Failed to build gRPC request: {}", e);
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid sampling parameters: {}", e),
|
||||
format!("Invalid request parameters: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Step 7: Create GenerateRequest
|
||||
let grpc_request = proto::GenerateRequest {
|
||||
request_id: format!("chatcmpl-{}", Uuid::new_v4()),
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: processed_messages.text.clone(),
|
||||
input_ids: token_ids.into_iter().map(|id| id as i32).collect(),
|
||||
}),
|
||||
mm_inputs: processed_messages.multimodal_inputs,
|
||||
sampling_params: Some(sampling_params),
|
||||
return_logprob: body.logprobs,
|
||||
logprob_start_len: -1,
|
||||
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
|
||||
return_hidden_states: body.return_hidden_states,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Step 8: Handle streaming vs non-streaming
|
||||
if body.stream {
|
||||
self.handle_streaming_chat(client, grpc_request, body).await
|
||||
self.handle_streaming_chat(client, base_request, body).await
|
||||
} else {
|
||||
self.handle_non_streaming_chat(client, grpc_request, body)
|
||||
self.handle_non_streaming_chat(client, base_request, body)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -547,111 +538,6 @@ impl GrpcRouter {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Build gRPC SamplingParams from OpenAI request
|
||||
fn build_grpc_sampling_params(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
tool_call_constraint: Option<(String, String)>,
|
||||
) -> Result<proto::SamplingParams, String> {
|
||||
let stop_sequences = self.extract_stop_strings(request);
|
||||
|
||||
// Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
|
||||
// If neither is specified, use None to let the backend decide the default
|
||||
#[allow(deprecated)]
|
||||
let max_new_tokens = request
|
||||
.max_completion_tokens
|
||||
.or(request.max_tokens)
|
||||
.map(|v| v as i32);
|
||||
|
||||
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
|
||||
let skip_special_tokens = if request.tools.is_some() {
|
||||
match &request.tool_choice {
|
||||
Some(crate::protocols::spec::ToolChoice::Value(
|
||||
crate::protocols::spec::ToolChoiceValue::None,
|
||||
)) => request.skip_special_tokens,
|
||||
Some(_) => false, // tool_choice is not "none"
|
||||
None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present
|
||||
}
|
||||
} else {
|
||||
request.skip_special_tokens
|
||||
};
|
||||
|
||||
#[allow(deprecated)]
|
||||
Ok(proto::SamplingParams {
|
||||
temperature: request.temperature.unwrap_or(1.0),
|
||||
top_p: request.top_p.unwrap_or(1.0),
|
||||
top_k: request.top_k.unwrap_or(-1),
|
||||
min_p: request.min_p.unwrap_or(0.0),
|
||||
frequency_penalty: request.frequency_penalty.unwrap_or(0.0),
|
||||
presence_penalty: request.presence_penalty.unwrap_or(0.0),
|
||||
repetition_penalty: request.repetition_penalty.unwrap_or(1.0),
|
||||
max_new_tokens,
|
||||
stop: stop_sequences,
|
||||
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
|
||||
skip_special_tokens,
|
||||
n: request.n.unwrap_or(1) as i32,
|
||||
constraint: self.build_constraint(request, tool_call_constraint)?,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract stop strings from request
|
||||
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
|
||||
match &request.stop {
|
||||
Some(StringOrArray::String(s)) => vec![s.clone()],
|
||||
Some(StringOrArray::Array(arr)) => arr.clone(),
|
||||
None => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Build constraint for structured generation
|
||||
fn build_constraint(
|
||||
&self,
|
||||
request: &ChatCompletionRequest,
|
||||
tool_call_constraint: Option<(String, String)>,
|
||||
) -> Result<Option<proto::sampling_params::Constraint>, String> {
|
||||
let mut constraints = Vec::new();
|
||||
|
||||
if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format {
|
||||
let schema_str = serde_json::to_string(&json_schema.schema)
|
||||
.map_err(|e| format!("Failed to serialize JSON schema: {}", e))?;
|
||||
constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str));
|
||||
}
|
||||
|
||||
if let Some(ebnf) = &request.ebnf {
|
||||
constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
|
||||
ebnf.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(regex) = &request.regex {
|
||||
constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
|
||||
}
|
||||
|
||||
// Handle tool call constraint
|
||||
if let Some((constraint_type, constraint_value)) = tool_call_constraint {
|
||||
if !constraints.is_empty() {
|
||||
return Err("Constrained decoding is not compatible with tool calls.".to_string());
|
||||
}
|
||||
let tool_constraint = match constraint_type.as_str() {
|
||||
"structural_tag" => {
|
||||
proto::sampling_params::Constraint::StructuralTag(constraint_value)
|
||||
}
|
||||
"json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value),
|
||||
"ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value),
|
||||
"regex" => proto::sampling_params::Constraint::Regex(constraint_value),
|
||||
_ => return Err(format!("Unknown constraint type: {}", constraint_type)),
|
||||
};
|
||||
constraints.push(tool_constraint);
|
||||
}
|
||||
|
||||
match constraints.len() {
|
||||
0 => Ok(None),
|
||||
1 => Ok(constraints.pop()),
|
||||
_ => Err("Multiple constraints are not allowed.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate tool constraints for structured generation
|
||||
fn generate_tool_constraints(
|
||||
&self,
|
||||
|
||||
Reference in New Issue
Block a user