[router] grpc router generate endpoint support (#11070)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
use std::convert::TryFrom;
|
||||
use std::time::Duration;
|
||||
use tonic::{transport::Channel, Request};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, GenerateRequest, ResponseFormat,
|
||||
SamplingParams as GenerateSamplingParams, StringOrArray,
|
||||
};
|
||||
|
||||
// Include the generated protobuf code
|
||||
pub mod proto {
|
||||
@@ -112,6 +116,37 @@ impl SglangSchedulerClient {
|
||||
Ok(grpc_request)
|
||||
}
|
||||
|
||||
/// Build a basic GenerateRequest from the SGLang spec GenerateRequest
|
||||
pub fn build_plain_generate_request(
|
||||
&self,
|
||||
request_id: String,
|
||||
body: &GenerateRequest,
|
||||
original_text: Option<String>,
|
||||
token_ids: Vec<u32>,
|
||||
) -> Result<proto::GenerateRequest, String> {
|
||||
let sampling_params =
|
||||
Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?;
|
||||
|
||||
let grpc_request = proto::GenerateRequest {
|
||||
request_id,
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: original_text.unwrap_or_default(),
|
||||
input_ids: token_ids,
|
||||
}),
|
||||
sampling_params: Some(sampling_params),
|
||||
return_logprob: body.return_logprob,
|
||||
logprob_start_len: -1,
|
||||
top_logprobs_num: 0,
|
||||
token_ids_logprob: vec![],
|
||||
return_hidden_states: body.return_hidden_states,
|
||||
stream: body.stream,
|
||||
log_metrics: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Ok(grpc_request)
|
||||
}
|
||||
|
||||
/// Build gRPC SamplingParams from OpenAI request
|
||||
fn build_grpc_sampling_params(
|
||||
&self,
|
||||
@@ -165,8 +200,8 @@ impl SglangSchedulerClient {
|
||||
/// Extract stop strings from request
|
||||
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
|
||||
match &request.stop {
|
||||
Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()],
|
||||
Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(),
|
||||
Some(StringOrArray::String(s)) => vec![s.clone()],
|
||||
Some(StringOrArray::Array(arr)) => arr.clone(),
|
||||
None => vec![],
|
||||
}
|
||||
}
|
||||
@@ -218,6 +253,100 @@ impl SglangSchedulerClient {
|
||||
_ => Err("Multiple constraints are not allowed.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_single_constraint_from_plain(
|
||||
params: &GenerateSamplingParams,
|
||||
) -> Result<Option<proto::sampling_params::Constraint>, String> {
|
||||
let mut constraints = Vec::new();
|
||||
if let Some(json_schema) = ¶ms.json_schema {
|
||||
constraints.push(proto::sampling_params::Constraint::JsonSchema(
|
||||
json_schema.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(regex) = ¶ms.regex {
|
||||
constraints.push(proto::sampling_params::Constraint::Regex(regex.clone()));
|
||||
}
|
||||
if let Some(ebnf) = ¶ms.ebnf {
|
||||
constraints.push(proto::sampling_params::Constraint::EbnfGrammar(
|
||||
ebnf.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
match constraints.len() {
|
||||
0 => Ok(None),
|
||||
1 => Ok(constraints.pop()),
|
||||
_ => Err("Multiple structured constraints are not allowed".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_sampling_params_from_plain(
|
||||
params: Option<&GenerateSamplingParams>,
|
||||
) -> Result<proto::SamplingParams, String> {
|
||||
let mut sampling = proto::SamplingParams {
|
||||
temperature: 1.0,
|
||||
top_p: 1.0,
|
||||
top_k: -1,
|
||||
repetition_penalty: 1.0,
|
||||
n: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let Some(p) = params else {
|
||||
return Ok(sampling);
|
||||
};
|
||||
|
||||
// Simple field mappings using a macro
|
||||
macro_rules! map_field {
|
||||
($field:ident) => {
|
||||
if let Some(val) = p.$field {
|
||||
sampling.$field = val;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
map_field!(temperature);
|
||||
map_field!(top_p);
|
||||
map_field!(top_k);
|
||||
map_field!(frequency_penalty);
|
||||
map_field!(presence_penalty);
|
||||
map_field!(repetition_penalty);
|
||||
map_field!(min_p);
|
||||
map_field!(ignore_eos);
|
||||
map_field!(skip_special_tokens);
|
||||
map_field!(no_stop_trim);
|
||||
|
||||
// Handle stop sequences
|
||||
if let Some(stop) = &p.stop {
|
||||
match stop {
|
||||
StringOrArray::String(s) => sampling.stop.push(s.clone()),
|
||||
StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
// Handle stop token IDs
|
||||
if let Some(stop_token_ids) = &p.stop_token_ids {
|
||||
sampling.stop_token_ids = stop_token_ids.clone();
|
||||
}
|
||||
|
||||
// Handle max_new_tokens with conversion
|
||||
if let Some(max_new_tokens) = p.max_new_tokens {
|
||||
sampling.max_new_tokens =
|
||||
Some(i32::try_from(max_new_tokens).map_err(|_| {
|
||||
"max_new_tokens must fit into a 32-bit signed integer".to_string()
|
||||
})?);
|
||||
}
|
||||
|
||||
// Handle min_tokens with conversion
|
||||
if let Some(min_tokens) = p.min_tokens {
|
||||
sampling.min_new_tokens = i32::try_from(min_tokens)
|
||||
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
|
||||
}
|
||||
|
||||
// Handle constraints (exactly one allowed)
|
||||
sampling.constraint = Self::build_single_constraint_from_plain(p)?;
|
||||
|
||||
Ok(sampling)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user