router: Fix constraint proto and build_constraint in grpc router (#10881)

This commit is contained in:
Chang Su
2025-09-25 08:12:06 -07:00
committed by GitHub
parent d511b2d905
commit 916784746b
7 changed files with 172 additions and 141 deletions

View File

@@ -47,24 +47,24 @@ message SamplingParams {
string regex = 13;
string json_schema = 14;
string ebnf_grammar = 15;
string structural_tag = 16;
}
// LoRA adapter
string lora_path = 16;
string lora_path = 17;
// Speculative decoding
int32 n = 17; // Number of samples
int32 n = 18; // Number of samples
// Token healing
bool token_healing = 18;
bool token_healing = 19;
// Additional parameters
int32 min_new_tokens = 19;
bool ignore_eos = 20;
bool no_stop_trim = 21;
int32 stream_interval = 22;
map<string, float> logit_bias = 23;
string structural_tag = 24;
int32 min_new_tokens = 20;
bool ignore_eos = 21;
bool no_stop_trim = 22;
int32 stream_interval = 23;
map<string, float> logit_bias = 24;
// Custom parameters for extensibility
google.protobuf.Struct custom_params = 25;

View File

@@ -241,14 +241,14 @@ impl GrpcRouter {
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 5: Build tool constraints if needed
let structural_tag = if let Some(tools) = &body.tools {
let tool_call_constraint = if let Some(tools) = &body.tools {
self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
} else {
None
};
// Step 6: Build SamplingParams for gRPC
let sampling_params = match self.build_grpc_sampling_params(body, structural_tag) {
let sampling_params = match self.build_grpc_sampling_params(body, tool_call_constraint) {
Ok(params) => params,
Err(e) => {
error!("Failed to build sampling parameters: {}", e);
@@ -286,6 +286,41 @@ impl GrpcRouter {
}
// ============ Helper Methods ============
/// Choose the worker that should serve this request.
///
/// Looks up the regular gRPC workers registered for `model_id`, keeps only the
/// ones reporting themselves as available, and lets the routing policy for that
/// model (or the default policy when no model is given) pick among them.
/// Returns `None` when nothing is available or the policy makes no selection.
fn select_worker_for_request(
    &self,
    model_id: Option<&str>,
    text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
    // The registry is asked for every matching worker (last argument `false`);
    // availability is checked separately below.
    let candidates = self.worker_registry.get_workers_filtered(
        model_id,
        Some(WorkerType::Regular),
        Some(crate::core::ConnectionMode::Grpc { port: None }),
        false,
    );

    // Keep only workers that are currently available (health + circuit breaker).
    let mut usable: Vec<Arc<dyn crate::core::Worker>> = Vec::new();
    for worker in candidates.iter() {
        if worker.is_available() {
            usable.push(Arc::clone(worker));
        }
    }
    if usable.is_empty() {
        return None;
    }

    // A model-specific policy takes precedence; otherwise use the default one.
    let policy = if let Some(model) = model_id {
        self.policy_registry.get_policy_or_default(model)
    } else {
        self.policy_registry.get_default_policy()
    };

    // The policy answers with an index into `usable`, or nothing at all.
    policy
        .select_worker(&usable, text)
        .map(|idx| Arc::clone(&usable[idx]))
}
/// Process chat messages and apply template
fn process_chat_messages(
@@ -516,7 +551,7 @@ impl GrpcRouter {
fn build_grpc_sampling_params(
&self,
request: &ChatCompletionRequest,
structural_tag: Option<String>,
tool_call_constraint: Option<(String, String)>,
) -> Result<proto::SamplingParams, String> {
let stop_sequences = self.extract_stop_strings(request);
@@ -555,8 +590,7 @@ impl GrpcRouter {
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
skip_special_tokens,
n: request.n.unwrap_or(1) as i32,
structural_tag: structural_tag.unwrap_or_default(),
constraint: self.build_constraint(request)?,
constraint: self.build_constraint(request, tool_call_constraint)?,
..Default::default()
})
}
@@ -574,28 +608,48 @@ impl GrpcRouter {
fn build_constraint(
&self,
request: &ChatCompletionRequest,
tool_call_constraint: Option<(String, String)>,
) -> Result<Option<proto::sampling_params::Constraint>, String> {
let mut constraints = Vec::new();
if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format {
let schema_str = serde_json::to_string(&json_schema.schema)
.map_err(|e| format!("Failed to serialize JSON schema: {}", e))?;
return Ok(Some(proto::sampling_params::Constraint::JsonSchema(
schema_str,
)));
constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str));
}
if let Some(ebnf) = &request.ebnf {
return Ok(Some(proto::sampling_params::Constraint::EbnfGrammar(
constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
ebnf.clone(),
)));
));
}
if let Some(regex) = &request.regex {
return Ok(Some(proto::sampling_params::Constraint::Regex(
regex.clone(),
)));
constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
}
Ok(None)
// Handle tool call constraint
if let Some((constraint_type, constraint_value)) = tool_call_constraint {
if !constraints.is_empty() {
return Err("Constrained decoding is not compatible with tool calls.".to_string());
}
let tool_constraint = match constraint_type.as_str() {
"structural_tag" => {
proto::sampling_params::Constraint::StructuralTag(constraint_value)
}
"json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value),
"ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value),
"regex" => proto::sampling_params::Constraint::Regex(constraint_value),
_ => return Err(format!("Unknown constraint type: {}", constraint_type)),
};
constraints.push(tool_constraint);
}
match constraints.len() {
0 => Ok(None),
1 => Ok(constraints.pop()),
_ => Err("Multiple constraints are not allowed.".to_string()),
}
}
/// Generate tool constraints for structured generation
@@ -604,52 +658,19 @@ impl GrpcRouter {
_tools: &[crate::protocols::spec::Tool],
_tool_choice: &Option<crate::protocols::spec::ToolChoice>,
model: &str,
) -> Option<String> {
) -> Option<(String, String)> {
let _parser = self.tool_parser_registry.get_parser(model)?;
// TODO: Implement actual constraint generation logic
// For now, return None as this is a placeholder implementation
None
}
/// Choose the worker that should serve this request.
///
/// Looks up the regular gRPC workers registered for `model_id`, keeps only the
/// ones reporting themselves as available, and lets the routing policy for that
/// model (or the default policy when no model is given) pick among them.
/// Returns `None` when nothing is available or the policy makes no selection.
fn select_worker_for_request(
    &self,
    model_id: Option<&str>,
    text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
    // The registry is asked for every matching worker (last argument `false`);
    // availability is checked separately below.
    let candidates = self.worker_registry.get_workers_filtered(
        model_id,
        Some(WorkerType::Regular),
        Some(crate::core::ConnectionMode::Grpc { port: None }),
        false,
    );

    // Keep only workers that are currently available (health + circuit breaker).
    let mut usable: Vec<Arc<dyn crate::core::Worker>> = Vec::new();
    for worker in candidates.iter() {
        if worker.is_available() {
            usable.push(Arc::clone(worker));
        }
    }
    if usable.is_empty() {
        return None;
    }

    // A model-specific policy takes precedence; otherwise use the default one.
    let policy = if let Some(model) = model_id {
        self.policy_registry.get_policy_or_default(model)
    } else {
        self.policy_registry.get_default_policy()
    };

    // The policy answers with an index into `usable`, or nothing at all.
    policy
        .select_worker(&usable, text)
        .map(|idx| Arc::clone(&usable[idx]))
}
/// Get or create a gRPC client for the worker
async fn get_or_create_grpc_client(
&self,
worker_url: &str,
) -> Result<SglangSchedulerClient, String> {
// TODO: move to worker
debug!("Creating new gRPC client for worker: {}", worker_url);
SglangSchedulerClient::connect(worker_url)
.await