[router] update generate spec to align with sgl io struct (#11591)
This commit is contained in:
@@ -356,7 +356,7 @@ pub struct ChatCompletionRequest {
|
||||
|
||||
/// Path to LoRA adapter(s) for model customization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_path: Option<LoRAPath>,
|
||||
pub lora_path: Option<String>,
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -905,7 +905,7 @@ pub struct CompletionRequest {
|
||||
|
||||
/// Path to LoRA adapter(s) for model customization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_path: Option<LoRAPath>,
|
||||
pub lora_path: Option<String>,
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -2309,10 +2309,6 @@ fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::Va
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Validate)]
|
||||
#[validate(schema(function = "validate_generate_request"))]
|
||||
pub struct GenerateRequest {
|
||||
/// The prompt to generate from (OpenAI style)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<StringOrArray>,
|
||||
|
||||
/// Text input - SGLang native format
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
@@ -2321,31 +2317,144 @@ pub struct GenerateRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input_ids: Option<InputIds>,
|
||||
|
||||
/// Input embeddings for direct embedding input
|
||||
/// Can be a 2D array (single request) or 3D array (batch of requests)
|
||||
/// Placeholder for future use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input_embeds: Option<Value>,
|
||||
|
||||
/// Image input data
|
||||
/// Can be an image instance, file name, URL, or base64 encoded string
|
||||
/// Supports single images, lists of images, or nested lists for batch processing
|
||||
/// Placeholder for future use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub image_data: Option<Value>,
|
||||
|
||||
/// Video input data
|
||||
/// Can be a file name, URL, or base64 encoded string
|
||||
/// Supports single videos, lists of videos, or nested lists for batch processing
|
||||
/// Placeholder for future use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub video_data: Option<Value>,
|
||||
|
||||
/// Audio input data
|
||||
/// Can be a file name, URL, or base64 encoded string
|
||||
/// Supports single audio files, lists of audio, or nested lists for batch processing
|
||||
/// Placeholder for future use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audio_data: Option<Value>,
|
||||
|
||||
/// Sampling parameters (sglang style)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sampling_params: Option<SamplingParams>,
|
||||
|
||||
/// Whether to return logprobs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub return_logprob: Option<bool>,
|
||||
|
||||
/// If return logprobs, the start location in the prompt for returning logprobs.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprob_start_len: Option<i32>,
|
||||
|
||||
/// If return logprobs, the number of top logprobs to return at each position.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_logprobs_num: Option<i32>,
|
||||
|
||||
/// If return logprobs, the token ids to return logprob for.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub token_ids_logprob: Option<Vec<u32>>,
|
||||
|
||||
/// Whether to detokenize tokens in text in the returned logprobs.
|
||||
#[serde(default)]
|
||||
pub return_text_in_logprobs: bool,
|
||||
|
||||
/// Whether to stream the response
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Whether to return logprobs
|
||||
#[serde(default)]
|
||||
pub return_logprob: bool,
|
||||
|
||||
/// Path to LoRA adapter(s) for model customization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_path: Option<LoRAPath>,
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_params: Option<HashMap<String, Value>>,
|
||||
/// Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
||||
#[serde(default = "default_true")]
|
||||
pub log_metrics: bool,
|
||||
|
||||
/// Return model hidden states
|
||||
#[serde(default)]
|
||||
pub return_hidden_states: bool,
|
||||
|
||||
/// Request ID for tracking
|
||||
/// The modalities of the image data [image, multi-images, video]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub modalities: Option<Vec<String>>,
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_params: Option<HashMap<String, Value>>,
|
||||
|
||||
/// Path to LoRA adapter(s) for model customization
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_path: Option<String>,
|
||||
|
||||
/// LoRA adapter ID (if pre-loaded)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lora_id: Option<String>,
|
||||
|
||||
/// Custom logit processor for advanced sampling control. Must be a serialized instance
|
||||
/// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
|
||||
/// Use the processor's `to_str()` method to generate the serialized string.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub custom_logit_processor: Option<String>,
|
||||
|
||||
/// For disaggregated inference
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bootstrap_host: Option<String>,
|
||||
|
||||
/// For disaggregated inference
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bootstrap_port: Option<i32>,
|
||||
|
||||
/// For disaggregated inference
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bootstrap_room: Option<i32>,
|
||||
|
||||
/// For disaggregated inference
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bootstrap_pair_key: Option<String>,
|
||||
|
||||
/// Data parallel rank routing
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data_parallel_rank: Option<i32>,
|
||||
|
||||
/// Background response
|
||||
#[serde(default)]
|
||||
pub background: bool,
|
||||
|
||||
/// Conversation ID for tracking
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub conversation_id: Option<String>,
|
||||
|
||||
/// Priority for the request
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<i32>,
|
||||
|
||||
/// Extra key for classifying the request (e.g. cache_salt)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub extra_key: Option<String>,
|
||||
|
||||
/// Whether to disallow logging for this request (e.g. due to ZDR)
|
||||
#[serde(default)]
|
||||
pub no_logs: bool,
|
||||
|
||||
/// Custom metric labels
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub custom_labels: Option<HashMap<String, String>>,
|
||||
|
||||
/// Whether to return bytes for image generation
|
||||
#[serde(default)]
|
||||
pub return_bytes: bool,
|
||||
|
||||
/// Whether to return entropy
|
||||
#[serde(default)]
|
||||
pub return_entropy: bool,
|
||||
|
||||
/// Request ID for tracking (inherited from BaseReq in Python)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rid: Option<String>,
|
||||
}
|
||||
@@ -2358,7 +2467,7 @@ impl Normalizable for GenerateRequest {
|
||||
fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> {
|
||||
// Exactly one of text or input_ids must be provided
|
||||
// Note: input_embeds not yet supported in Rust implementation
|
||||
let has_text = req.text.is_some() || req.prompt.is_some();
|
||||
let has_text = req.text.is_some();
|
||||
let has_input_ids = req.input_ids.is_some();
|
||||
|
||||
let count = [has_text, has_input_ids].iter().filter(|&&x| x).count();
|
||||
@@ -2389,18 +2498,11 @@ impl GenerationRequest for GenerateRequest {
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Check fields in priority order: text, prompt, inputs
|
||||
// Check fields in priority order: text, input_ids
|
||||
if let Some(ref text) = self.text {
|
||||
return text.clone();
|
||||
}
|
||||
|
||||
if let Some(ref prompt) = self.prompt {
|
||||
return match prompt {
|
||||
StringOrArray::String(s) => s.clone(),
|
||||
StringOrArray::Array(v) => v.join(" "),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(ref input_ids) = self.input_ids {
|
||||
return match input_ids {
|
||||
InputIds::Single(ids) => ids
|
||||
|
||||
Reference in New Issue
Block a user