[router] update generate spec to align with sgl io struct (#11591)
@@ -28,15 +28,38 @@ fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
 fn default_generate_request() -> GenerateRequest {
     GenerateRequest {
         text: None,
-        prompt: None,
         input_ids: None,
-        stream: false,
+        input_embeds: None,
+        image_data: None,
+        video_data: None,
+        audio_data: None,
         sampling_params: None,
-        return_logprob: false,
-        // SGLang Extensions
-        lora_path: None,
-        session_params: None,
+        return_logprob: None,
+        logprob_start_len: None,
+        top_logprobs_num: None,
+        token_ids_logprob: None,
+        return_text_in_logprobs: false,
+        stream: false,
+        log_metrics: true,
         return_hidden_states: false,
+        modalities: None,
+        session_params: None,
+        lora_path: None,
+        lora_id: None,
+        custom_logit_processor: None,
+        bootstrap_host: None,
+        bootstrap_port: None,
+        bootstrap_room: None,
+        bootstrap_pair_key: None,
+        data_parallel_rank: None,
+        background: false,
+        conversation_id: None,
+        priority: None,
+        extra_key: None,
+        no_logs: false,
+        custom_labels: None,
+        return_bytes: false,
+        return_entropy: false,
         rid: None,
     }
 }
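With the helper now spelling out every field of the widened struct, individual tests can lean on Rust's struct update syntax and override only what they exercise. A minimal sketch, assuming the helper above and GenerateRequest are in scope:

    // Hypothetical test-local usage: take everything from the default and
    // override only the fields under test.
    fn request_for_stream_test() -> GenerateRequest {
        GenerateRequest {
            text: Some("ping".to_string()),
            stream: true,
            ..default_generate_request()
        }
    }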
@@ -101,6 +124,7 @@ fn create_sample_generate_request() -> GenerateRequest {
     GenerateRequest {
         text: Some("Write a story about artificial intelligence".to_string()),
         sampling_params: Some(SamplingParams {
+            max_new_tokens: Some(100),
             temperature: Some(0.8),
             top_p: Some(0.9),
             top_k: Some(50),
@@ -280,13 +280,13 @@ impl SglangSchedulerClient {
                 input_ids: token_ids,
             }),
             sampling_params: Some(sampling_params),
-            return_logprob: body.return_logprob,
-            logprob_start_len: -1,
-            top_logprobs_num: 0,
-            token_ids_logprob: vec![],
+            return_logprob: body.return_logprob.unwrap_or(false),
+            logprob_start_len: body.logprob_start_len.unwrap_or(-1),
+            top_logprobs_num: body.top_logprobs_num.unwrap_or(0),
+            token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(),
             return_hidden_states: body.return_hidden_states,
             stream: body.stream,
-            log_metrics: true,
+            log_metrics: body.log_metrics,
             ..Default::default()
         };

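The conversion above is now the single place where the request's optional logprob fields collapse to the scheduler's concrete defaults. A minimal sketch of the pattern, with plain local variables standing in for the real request and scheduler types:

    // Option-to-default resolution at the conversion boundary, mirroring the
    // unwrap_or calls above.
    fn main() {
        let return_logprob: Option<bool> = None;
        let logprob_start_len: Option<i32> = None;
        let token_ids_logprob: Option<Vec<u32>> = None;

        assert!(!return_logprob.unwrap_or(false));
        assert_eq!(logprob_start_len.unwrap_or(-1), -1);
        assert_eq!(token_ids_logprob.clone().unwrap_or_default(), Vec::<u32>::new());
    }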
@@ -356,7 +356,7 @@ pub struct ChatCompletionRequest {

     /// Path to LoRA adapter(s) for model customization
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub lora_path: Option<LoRAPath>,
+    pub lora_path: Option<String>,

     /// Session parameters for continual prompting
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -905,7 +905,7 @@ pub struct CompletionRequest {

     /// Path to LoRA adapter(s) for model customization
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub lora_path: Option<LoRAPath>,
+    pub lora_path: Option<String>,

     /// Session parameters for continual prompting
     #[serde(skip_serializing_if = "Option::is_none")]
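Narrowing lora_path from Option<LoRAPath> to Option<String> changes what serde accepts on the wire: a JSON string still parses, but non-string payloads (for example an array of adapter paths, if the LoRAPath type previously allowed one) are now rejected at deserialization time. A small check against a hypothetical one-field miniature, assuming serde and serde_json as dependencies:

    use serde::Deserialize;
    use serde_json::json;

    // Hypothetical miniature carrying only the narrowed field.
    #[derive(Deserialize)]
    struct LoraOnly {
        lora_path: Option<String>,
    }

    fn main() {
        // A plain string still deserializes...
        let ok: LoraOnly = serde_json::from_value(json!({ "lora_path": "adapters/x" })).unwrap();
        assert_eq!(ok.lora_path.as_deref(), Some("adapters/x"));

        // ...but an array of paths no longer does.
        assert!(serde_json::from_value::<LoraOnly>(json!({ "lora_path": ["a", "b"] })).is_err());
    }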
@@ -2309,10 +2309,6 @@ fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::Va
 #[derive(Clone, Debug, Serialize, Deserialize, Validate)]
 #[validate(schema(function = "validate_generate_request"))]
 pub struct GenerateRequest {
-    /// The prompt to generate from (OpenAI style)
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub prompt: Option<StringOrArray>,
-
     /// Text input - SGLang native format
     #[serde(skip_serializing_if = "Option::is_none")]
     pub text: Option<String>,
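Because serde ignores unknown fields by default, a legacy client that still sends prompt will not get a parse error here; the value is simply dropped, and the request then fails the exactly-one-input validation below unless text or input_ids accompanies it. A sketch with a hypothetical one-field miniature:

    use serde::Deserialize;
    use serde_json::json;

    // Hypothetical miniature of the slimmed struct: unknown fields such as
    // the removed `prompt` are silently ignored by serde's default behavior.
    #[derive(Deserialize)]
    struct TextOnly {
        text: Option<String>,
    }

    fn main() {
        let legacy = json!({ "prompt": "old style", "text": null });
        let parsed: TextOnly = serde_json::from_value(legacy).unwrap();
        assert!(parsed.text.is_none()); // `prompt` no longer populates anything
    }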
@@ -2321,31 +2317,144 @@ pub struct GenerateRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub input_ids: Option<InputIds>,

+    /// Input embeddings for direct embedding input
+    /// Can be a 2D array (single request) or 3D array (batch of requests)
+    /// Placeholder for future use
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_embeds: Option<Value>,
+
+    /// Image input data
+    /// Can be an image instance, file name, URL, or base64 encoded string
+    /// Supports single images, lists of images, or nested lists for batch processing
+    /// Placeholder for future use
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub image_data: Option<Value>,
+
+    /// Video input data
+    /// Can be a file name, URL, or base64 encoded string
+    /// Supports single videos, lists of videos, or nested lists for batch processing
+    /// Placeholder for future use
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub video_data: Option<Value>,
+
+    /// Audio input data
+    /// Can be a file name, URL, or base64 encoded string
+    /// Supports single audio files, lists of audio, or nested lists for batch processing
+    /// Placeholder for future use
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub audio_data: Option<Value>,
+
     /// Sampling parameters (sglang style)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub sampling_params: Option<SamplingParams>,

+    /// Whether to return logprobs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub return_logprob: Option<bool>,
+
+    /// If return logprobs, the start location in the prompt for returning logprobs.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprob_start_len: Option<i32>,
+
+    /// If return logprobs, the number of top logprobs to return at each position.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs_num: Option<i32>,
+
+    /// If return logprobs, the token ids to return logprob for.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub token_ids_logprob: Option<Vec<u32>>,
+
+    /// Whether to detokenize tokens in text in the returned logprobs.
+    #[serde(default)]
+    pub return_text_in_logprobs: bool,
+
     /// Whether to stream the response
     #[serde(default)]
     pub stream: bool,

-    /// Whether to return logprobs
-    #[serde(default)]
-    pub return_logprob: bool,
+    /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    #[serde(default = "default_true")]
+    pub log_metrics: bool,

-    /// Path to LoRA adapter(s) for model customization
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub lora_path: Option<LoRAPath>,
-
-    /// Session parameters for continual prompting
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub session_params: Option<HashMap<String, Value>>,
-
     /// Return model hidden states
     #[serde(default)]
     pub return_hidden_states: bool,

-    /// Request ID for tracking
+    /// The modalities of the image data [image, multi-images, video]
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modalities: Option<Vec<String>>,
+
+    /// Session parameters for continual prompting
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session_params: Option<HashMap<String, Value>>,
+
+    /// Path to LoRA adapter(s) for model customization
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub lora_path: Option<String>,
+
+    /// LoRA adapter ID (if pre-loaded)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub lora_id: Option<String>,
+
+    /// Custom logit processor for advanced sampling control. Must be a serialized instance
+    /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+    /// Use the processor's `to_str()` method to generate the serialized string.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub custom_logit_processor: Option<String>,
+
+    /// For disaggregated inference
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bootstrap_host: Option<String>,
+
+    /// For disaggregated inference
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bootstrap_port: Option<i32>,
+
+    /// For disaggregated inference
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bootstrap_room: Option<i32>,
+
+    /// For disaggregated inference
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bootstrap_pair_key: Option<String>,
+
+    /// Data parallel rank routing
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub data_parallel_rank: Option<i32>,
+
+    /// Background response
+    #[serde(default)]
+    pub background: bool,
+
+    /// Conversation ID for tracking
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub conversation_id: Option<String>,
+
+    /// Priority for the request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub priority: Option<i32>,
+
+    /// Extra key for classifying the request (e.g. cache_salt)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub extra_key: Option<String>,
+
+    /// Whether to disallow logging for this request (e.g. due to ZDR)
+    #[serde(default)]
+    pub no_logs: bool,
+
+    /// Custom metric labels
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub custom_labels: Option<HashMap<String, String>>,
+
+    /// Whether to return bytes for image generation
+    #[serde(default)]
+    pub return_bytes: bool,
+
+    /// Whether to return entropy
+    #[serde(default)]
+    pub return_entropy: bool,
+
+    /// Request ID for tracking (inherited from BaseReq in Python)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub rid: Option<String>,
 }
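The log_metrics attribute names a default_true function that is not part of this hunk; serde resolves #[serde(default = "...")] to a free function returning the field's type, so presumably something along these lines lives nearby:

    // Assumed shape of the helper behind `#[serde(default = "default_true")]`
    // (not shown in this diff); serde calls it when `log_metrics` is absent.
    fn default_true() -> bool {
        true
    }

The net effect: a bare payload such as {"text": "hi"} deserializes with log_metrics == true, the plain #[serde(default)] booleans false, and every Option field None.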
@@ -2358,7 +2467,7 @@ impl Normalizable for GenerateRequest {
 fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
     // Exactly one of text or input_ids must be provided
     // Note: input_embeds not yet supported in Rust implementation
-    let has_text = req.text.is_some() || req.prompt.is_some();
+    let has_text = req.text.is_some();
     let has_input_ids = req.input_ids.is_some();

     let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
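With prompt gone, the exactly-one-of check reduces to text versus input_ids. The counting idiom above is self-contained enough to sketch directly:

    // Exactly-one-of check, as in validate_generate_request above.
    fn exactly_one(has_text: bool, has_input_ids: bool) -> bool {
        [has_text, has_input_ids].iter().filter(|&&x| x).count() == 1
    }

    fn main() {
        assert!(exactly_one(true, false));   // text only: valid
        assert!(!exactly_one(false, false)); // neither input: invalid
        assert!(!exactly_one(true, true));   // both inputs: invalid
    }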
@@ -2389,18 +2498,11 @@ impl GenerationRequest for GenerateRequest {
     }

     fn extract_text_for_routing(&self) -> String {
-        // Check fields in priority order: text, prompt, inputs
+        // Check fields in priority order: text, input_ids
         if let Some(ref text) = self.text {
             return text.clone();
         }

-        if let Some(ref prompt) = self.prompt {
-            return match prompt {
-                StringOrArray::String(s) => s.clone(),
-                StringOrArray::Array(v) => v.join(" "),
-            };
-        }
-
         if let Some(ref input_ids) = self.input_ids {
             return match input_ids {
                 InputIds::Single(ids) => ids
@@ -877,7 +877,7 @@ impl ResponseProcessingStage {
         }

         // Non-streaming: Delegate to ResponseProcessor
-        let request_logprobs = ctx.generate_request().return_logprob;
+        let request_logprobs = ctx.generate_request().return_logprob.unwrap_or(false);
         let generate_request = ctx.generate_request_arc();

         let stop_decoder = ctx
@@ -616,7 +616,7 @@ impl StreamingProcessor {
         generate_request: Arc<GenerateRequest>,
         dispatch: context::DispatchMetadata,
     ) -> Response {
-        let return_logprob = generate_request.return_logprob;
+        let return_logprob = generate_request.return_logprob.unwrap_or(false);

         // Create SSE channel
         let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
@@ -150,11 +150,6 @@ impl PDRouter {
 }

 fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
-    if let Some(StringOrArray::Array(arr)) = &req.prompt {
-        if !arr.is_empty() {
-            return Some(arr.len());
-        }
-    }
     if let Some(text) = &req.text {
         if text.contains("[") && text.contains("]") {
             return None;
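Only the text-based heuristic survives in this helper; bracketed text is treated as ambiguous and yields no batch size. A hedged sketch of the visible check (the fall-through is an assumption, since the rest of the helper sits outside this hunk):

    // `None` here means "cannot tell", not "batch of zero". The `Some(1)`
    // fall-through is assumed; the remainder of the real helper is not shown.
    fn batch_size_from_text(text: &str) -> Option<usize> {
        if text.contains('[') && text.contains(']') {
            return None;
        }
        Some(1)
    }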
@@ -1061,18 +1056,10 @@ impl RouterTrait for PDRouter {
         model_id: Option<&str>,
     ) -> Response {
         let is_stream = body.stream;
-        let return_logprob = body.return_logprob;
+        let return_logprob = body.return_logprob.unwrap_or(false);

         let request_text = if self.policies_need_request_text() {
-            body.text
-                .as_deref()
-                .or_else(|| {
-                    body.prompt.as_ref().and_then(|p| match p {
-                        StringOrArray::String(s) => Some(s.as_str()),
-                        StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
-                    })
-                })
-                .map(|s| s.to_string())
+            body.text.as_deref().map(|s| s.to_string())
         } else {
             None
         };
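For an Option<String>, the surviving chain body.text.as_deref().map(|s| s.to_string()) is behaviorally identical to body.text.clone(); the change keeps the original line's shape while deleting the prompt fallback. A quick equivalence check:

    fn main() {
        let text: Option<String> = Some("route me".to_string());
        assert_eq!(text.as_deref().map(|s| s.to_string()), text.clone());

        let none: Option<String> = None;
        assert_eq!(none.as_deref().map(|s| s.to_string()), None);
    }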
@@ -598,15 +598,39 @@ async fn test_unsupported_endpoints() {
         .unwrap();

     let generate_request = GenerateRequest {
-        prompt: None,
         text: Some("Hello world".to_string()),
         input_ids: None,
+        input_embeds: None,
+        image_data: None,
+        video_data: None,
+        audio_data: None,
         sampling_params: None,
         stream: false,
-        return_logprob: false,
-        lora_path: None,
-        session_params: None,
+        return_logprob: Some(false),
+        logprob_start_len: None,
+        top_logprobs_num: None,
+        token_ids_logprob: None,
+        return_text_in_logprobs: false,
+        log_metrics: true,
         return_hidden_states: false,
+        modalities: None,
+        session_params: None,
+        lora_path: None,
+        lora_id: None,
+        custom_logit_processor: None,
+        bootstrap_host: None,
+        bootstrap_port: None,
+        bootstrap_room: None,
+        bootstrap_pair_key: None,
+        data_parallel_rank: None,
+        background: false,
+        conversation_id: None,
+        priority: None,
+        extra_key: None,
+        no_logs: false,
+        custom_labels: None,
+        return_bytes: false,
+        return_entropy: false,
         rid: None,
     };
