[router] Move all protocols to spec.rs file (#9519)
This commit is contained in:
@@ -4,6 +4,11 @@ use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
|
||||
// Import types from spec module
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, ResponseFormat, StringOrArray, UserMessageContent,
|
||||
};
|
||||
|
||||
/// Validation constants for OpenAI API parameters
|
||||
pub mod constants {
|
||||
/// Temperature range: 0.0 to 2.0 (OpenAI spec)
|
||||
@@ -257,7 +262,7 @@ pub mod utils {
|
||||
) -> Result<(), ValidationError> {
|
||||
if let Some(stop) = request.get_stop_sequences() {
|
||||
match stop {
|
||||
crate::protocols::common::StringOrArray::String(s) => {
|
||||
StringOrArray::String(s) => {
|
||||
if s.is_empty() {
|
||||
return Err(ValidationError::InvalidValue {
|
||||
parameter: "stop".to_string(),
|
||||
@@ -266,7 +271,7 @@ pub mod utils {
|
||||
});
|
||||
}
|
||||
}
|
||||
crate::protocols::common::StringOrArray::Array(arr) => {
|
||||
StringOrArray::Array(arr) => {
|
||||
validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?;
|
||||
for (i, s) in arr.iter().enumerate() {
|
||||
if s.is_empty() {
|
||||
@@ -469,7 +474,7 @@ pub trait SamplingOptionsProvider {
|
||||
/// Trait for validating stop conditions
|
||||
pub trait StopConditionsProvider {
|
||||
/// Get stop sequences
|
||||
fn get_stop_sequences(&self) -> Option<&crate::protocols::common::StringOrArray>;
|
||||
fn get_stop_sequences(&self) -> Option<&StringOrArray>;
|
||||
}
|
||||
|
||||
/// Trait for validating token limits
|
||||
@@ -532,25 +537,237 @@ pub trait ValidatableRequest:
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI CHAT COMPLETION VALIDATION =
|
||||
// ==================================================================
|
||||
|
||||
impl SamplingOptionsProvider for ChatCompletionRequest {
|
||||
fn get_temperature(&self) -> Option<f32> {
|
||||
self.temperature
|
||||
}
|
||||
fn get_top_p(&self) -> Option<f32> {
|
||||
self.top_p
|
||||
}
|
||||
fn get_frequency_penalty(&self) -> Option<f32> {
|
||||
self.frequency_penalty
|
||||
}
|
||||
fn get_presence_penalty(&self) -> Option<f32> {
|
||||
self.presence_penalty
|
||||
}
|
||||
}
|
||||
|
||||
impl StopConditionsProvider for ChatCompletionRequest {
|
||||
fn get_stop_sequences(&self) -> Option<&StringOrArray> {
|
||||
self.stop.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenLimitsProvider for ChatCompletionRequest {
|
||||
fn get_max_tokens(&self) -> Option<u32> {
|
||||
// Prefer max_completion_tokens over max_tokens if both are set
|
||||
self.max_completion_tokens.or(self.max_tokens)
|
||||
}
|
||||
|
||||
fn get_min_tokens(&self) -> Option<u32> {
|
||||
self.min_tokens
|
||||
}
|
||||
}
|
||||
|
||||
impl LogProbsProvider for ChatCompletionRequest {
|
||||
fn get_logprobs(&self) -> Option<u32> {
|
||||
// For chat API, logprobs is a boolean, return 1 if true for validation purposes
|
||||
if self.logprobs {
|
||||
Some(1)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_top_logprobs(&self) -> Option<u32> {
|
||||
self.top_logprobs
|
||||
}
|
||||
}
|
||||
|
||||
impl SGLangExtensionsProvider for ChatCompletionRequest {
|
||||
fn get_top_k(&self) -> Option<i32> {
|
||||
self.top_k
|
||||
}
|
||||
|
||||
fn get_min_p(&self) -> Option<f32> {
|
||||
self.min_p
|
||||
}
|
||||
|
||||
fn get_repetition_penalty(&self) -> Option<f32> {
|
||||
self.repetition_penalty
|
||||
}
|
||||
}
|
||||
|
||||
impl CompletionCountProvider for ChatCompletionRequest {
|
||||
fn get_n(&self) -> Option<u32> {
|
||||
self.n
|
||||
}
|
||||
}
|
||||
|
||||
impl ChatCompletionRequest {
|
||||
/// Validate message-specific requirements
|
||||
pub fn validate_messages(&self) -> Result<(), ValidationError> {
|
||||
// Ensure messages array is not empty
|
||||
utils::validate_non_empty_array(&self.messages, "messages")?;
|
||||
|
||||
// Validate message content is not empty
|
||||
for (i, msg) in self.messages.iter().enumerate() {
|
||||
if let ChatMessage::User { content, .. } = msg {
|
||||
match content {
|
||||
UserMessageContent::Text(text) if text.is_empty() => {
|
||||
return Err(ValidationError::InvalidValue {
|
||||
parameter: format!("messages[{}].content", i),
|
||||
value: "empty".to_string(),
|
||||
reason: "message content cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
UserMessageContent::Parts(parts) if parts.is_empty() => {
|
||||
return Err(ValidationError::InvalidValue {
|
||||
parameter: format!("messages[{}].content", i),
|
||||
value: "empty array".to_string(),
|
||||
reason: "message content parts cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate response format if specified
|
||||
pub fn validate_response_format(&self) -> Result<(), ValidationError> {
|
||||
if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format {
|
||||
if json_schema.name.is_empty() {
|
||||
return Err(ValidationError::InvalidValue {
|
||||
parameter: "response_format.json_schema.name".to_string(),
|
||||
value: "empty".to_string(),
|
||||
reason: "JSON schema name cannot be empty".to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate chat API specific logprobs requirements
|
||||
pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> {
|
||||
// In chat API, if logprobs=true, top_logprobs must be specified
|
||||
if self.logprobs && self.top_logprobs.is_none() {
|
||||
return Err(ValidationError::MissingRequired {
|
||||
parameter: "top_logprobs".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// If top_logprobs is specified, logprobs should be true
|
||||
if self.top_logprobs.is_some() && !self.logprobs {
|
||||
return Err(ValidationError::InvalidValue {
|
||||
parameter: "logprobs".to_string(),
|
||||
value: "false".to_string(),
|
||||
reason: "must be true when top_logprobs is specified".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate cross-parameter relationships specific to chat completions
|
||||
pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> {
|
||||
// Validate that both max_tokens and max_completion_tokens aren't set
|
||||
utils::validate_conflicting_parameters(
|
||||
"max_tokens",
|
||||
self.max_tokens.is_some(),
|
||||
"max_completion_tokens",
|
||||
self.max_completion_tokens.is_some(),
|
||||
"cannot specify both max_tokens and max_completion_tokens",
|
||||
)?;
|
||||
|
||||
// Validate that tools and functions aren't both specified (deprecated)
|
||||
utils::validate_conflicting_parameters(
|
||||
"tools",
|
||||
self.tools.is_some(),
|
||||
"functions",
|
||||
self.functions.is_some(),
|
||||
"functions is deprecated, use tools instead",
|
||||
)?;
|
||||
|
||||
// Validate structured output constraints don't conflict with JSON response format
|
||||
let has_json_format = matches!(
|
||||
self.response_format,
|
||||
Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
|
||||
);
|
||||
|
||||
utils::validate_conflicting_parameters(
|
||||
"response_format",
|
||||
has_json_format,
|
||||
"regex",
|
||||
self.regex.is_some(),
|
||||
"cannot use regex constraint with JSON response format",
|
||||
)?;
|
||||
|
||||
utils::validate_conflicting_parameters(
|
||||
"response_format",
|
||||
has_json_format,
|
||||
"ebnf",
|
||||
self.ebnf.is_some(),
|
||||
"cannot use EBNF constraint with JSON response format",
|
||||
)?;
|
||||
|
||||
// Only one structured output constraint should be active
|
||||
let structured_constraints = [
|
||||
("regex", self.regex.is_some()),
|
||||
("ebnf", self.ebnf.is_some()),
|
||||
(
|
||||
"json_schema",
|
||||
matches!(
|
||||
self.response_format,
|
||||
Some(ResponseFormat::JsonSchema { .. })
|
||||
),
|
||||
),
|
||||
];
|
||||
|
||||
utils::validate_mutually_exclusive_options(
|
||||
&structured_constraints,
|
||||
"Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time",
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ValidatableRequest for ChatCompletionRequest {
|
||||
fn validate(&self) -> Result<(), ValidationError> {
|
||||
// Call the common validation function from the validation module
|
||||
utils::validate_common_request_params(self)?;
|
||||
|
||||
// Then validate chat-specific parameters
|
||||
self.validate_messages()?;
|
||||
self.validate_response_format()?;
|
||||
self.validate_chat_logprobs()?;
|
||||
self.validate_chat_cross_parameters()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::constants::*;
|
||||
use super::utils::*;
|
||||
use super::*;
|
||||
use crate::protocols::common::StringOrArray;
|
||||
use crate::protocols::spec::StringOrArray;
|
||||
|
||||
// Mock request type for testing validation traits
|
||||
#[derive(Debug, Default)]
|
||||
struct MockRequest {
|
||||
temperature: Option<f32>,
|
||||
top_p: Option<f32>,
|
||||
frequency_penalty: Option<f32>,
|
||||
presence_penalty: Option<f32>,
|
||||
stop: Option<StringOrArray>,
|
||||
max_tokens: Option<u32>,
|
||||
min_tokens: Option<u32>,
|
||||
logprobs: Option<u32>,
|
||||
top_logprobs: Option<u32>,
|
||||
}
|
||||
|
||||
impl SamplingOptionsProvider for MockRequest {
|
||||
@@ -558,13 +775,13 @@ mod tests {
|
||||
self.temperature
|
||||
}
|
||||
fn get_top_p(&self) -> Option<f32> {
|
||||
self.top_p
|
||||
None
|
||||
}
|
||||
fn get_frequency_penalty(&self) -> Option<f32> {
|
||||
self.frequency_penalty
|
||||
None
|
||||
}
|
||||
fn get_presence_penalty(&self) -> Option<f32> {
|
||||
self.presence_penalty
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -585,97 +802,36 @@ mod tests {
|
||||
|
||||
impl LogProbsProvider for MockRequest {
|
||||
fn get_logprobs(&self) -> Option<u32> {
|
||||
self.logprobs
|
||||
None
|
||||
}
|
||||
fn get_top_logprobs(&self) -> Option<u32> {
|
||||
self.top_logprobs
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl SGLangExtensionsProvider for MockRequest {
|
||||
// Default implementations return None, so no custom logic needed
|
||||
}
|
||||
|
||||
impl CompletionCountProvider for MockRequest {
|
||||
// Default implementation returns None, so no custom logic needed
|
||||
}
|
||||
|
||||
impl SGLangExtensionsProvider for MockRequest {}
|
||||
impl CompletionCountProvider for MockRequest {}
|
||||
impl ValidatableRequest for MockRequest {}
|
||||
|
||||
#[test]
|
||||
fn test_validate_range_valid() {
|
||||
let result = validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), 1.5f32);
|
||||
fn test_range_validation() {
|
||||
// Valid range
|
||||
assert!(validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature").is_ok());
|
||||
// Invalid range
|
||||
assert!(validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature").is_err());
|
||||
assert!(validate_range(3.0f32, &TEMPERATURE_RANGE, "temperature").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_range_too_low() {
|
||||
let result = validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature");
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
ValidationError::OutOfRange { parameter, .. } => {
|
||||
assert_eq!(parameter, "temperature");
|
||||
}
|
||||
_ => panic!("Expected OutOfRange error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_positive_valid() {
|
||||
let result = validate_positive(5i32, "max_tokens");
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), 5i32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_max_items_valid() {
|
||||
let items = vec!["stop1", "stop2"];
|
||||
let result = validate_max_items(&items, MAX_STOP_SEQUENCES, "stop");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_top_k() {
|
||||
fn test_sglang_top_k_validation() {
|
||||
assert!(validate_top_k(-1).is_ok()); // Disabled
|
||||
assert!(validate_top_k(50).is_ok()); // Positive
|
||||
assert!(validate_top_k(50).is_ok()); // Valid positive
|
||||
assert!(validate_top_k(0).is_err()); // Invalid
|
||||
assert!(validate_top_k(-5).is_err()); // Invalid
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_request() {
|
||||
let request = MockRequest {
|
||||
temperature: Some(1.0),
|
||||
top_p: Some(0.9),
|
||||
frequency_penalty: Some(0.5),
|
||||
presence_penalty: Some(-0.5),
|
||||
stop: Some(StringOrArray::Array(vec![
|
||||
"stop1".to_string(),
|
||||
"stop2".to_string(),
|
||||
])),
|
||||
max_tokens: Some(100),
|
||||
min_tokens: Some(10),
|
||||
logprobs: Some(3),
|
||||
top_logprobs: Some(15),
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_temperature() {
|
||||
let request = MockRequest {
|
||||
temperature: Some(3.0), // Invalid: too high
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_too_many_stop_sequences() {
|
||||
fn test_stop_sequences_limits() {
|
||||
let request = MockRequest {
|
||||
stop: Some(StringOrArray::Array(vec![
|
||||
"stop1".to_string(),
|
||||
@@ -686,72 +842,322 @@ mod tests {
|
||||
])),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
ValidationError::TooManyItems {
|
||||
parameter,
|
||||
count,
|
||||
max,
|
||||
} => {
|
||||
assert_eq!(parameter, "stop");
|
||||
assert_eq!(count, 5);
|
||||
assert_eq!(max, MAX_STOP_SEQUENCES);
|
||||
}
|
||||
_ => panic!("Expected TooManyItems error"),
|
||||
}
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conflicting_token_limits() {
|
||||
fn test_token_limits_conflict() {
|
||||
let request = MockRequest {
|
||||
min_tokens: Some(100),
|
||||
max_tokens: Some(50), // Invalid: min > max
|
||||
max_tokens: Some(50), // min > max
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = request.validate();
|
||||
assert!(result.is_err());
|
||||
match result.unwrap_err() {
|
||||
ValidationError::ConflictingParameters {
|
||||
parameter1,
|
||||
parameter2,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(parameter1, "min_tokens");
|
||||
assert_eq!(parameter2, "max_tokens");
|
||||
}
|
||||
_ => panic!("Expected ConflictingParameters error"),
|
||||
}
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boundary_values() {
|
||||
fn test_valid_request() {
|
||||
let request = MockRequest {
|
||||
temperature: Some(0.0), // Boundary: minimum
|
||||
top_p: Some(1.0), // Boundary: maximum
|
||||
frequency_penalty: Some(-2.0), // Boundary: minimum
|
||||
presence_penalty: Some(2.0), // Boundary: maximum
|
||||
logprobs: Some(0), // Boundary: minimum
|
||||
top_logprobs: Some(20), // Boundary: maximum
|
||||
..Default::default()
|
||||
temperature: Some(1.0),
|
||||
stop: Some(StringOrArray::Array(vec!["stop".to_string()])),
|
||||
max_tokens: Some(100),
|
||||
min_tokens: Some(10),
|
||||
};
|
||||
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validation_error_display() {
|
||||
let error = ValidationError::OutOfRange {
|
||||
parameter: "temperature".to_string(),
|
||||
value: "3.0".to_string(),
|
||||
min: "0.0".to_string(),
|
||||
max: "2.0".to_string(),
|
||||
};
|
||||
// Chat completion specific tests
|
||||
#[cfg(test)]
|
||||
mod chat_tests {
|
||||
use super::*;
|
||||
|
||||
let message = format!("{}", error);
|
||||
assert!(message.contains("temperature"));
|
||||
assert!(message.contains("3.0"));
|
||||
fn create_valid_chat_request() -> ChatCompletionRequest {
|
||||
ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Hello".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
temperature: Some(1.0),
|
||||
top_p: Some(0.9),
|
||||
n: Some(1),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
stop: None,
|
||||
max_tokens: Some(100),
|
||||
max_completion_tokens: None,
|
||||
presence_penalty: Some(0.0),
|
||||
frequency_penalty: Some(0.0),
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
logprobs: false,
|
||||
top_logprobs: None,
|
||||
response_format: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
functions: None,
|
||||
function_call: None,
|
||||
// SGLang extensions
|
||||
top_k: None,
|
||||
min_p: None,
|
||||
min_tokens: None,
|
||||
repetition_penalty: None,
|
||||
regex: None,
|
||||
ebnf: None,
|
||||
stop_token_ids: None,
|
||||
no_stop_trim: false,
|
||||
ignore_eos: false,
|
||||
continue_final_message: false,
|
||||
skip_special_tokens: true,
|
||||
lora_path: None,
|
||||
session_params: None,
|
||||
separate_reasoning: true,
|
||||
stream_reasoning: true,
|
||||
return_hidden_states: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_validation_basics() {
|
||||
// Valid request
|
||||
assert!(create_valid_chat_request().validate().is_ok());
|
||||
|
||||
// Empty messages
|
||||
let mut request = create_valid_chat_request();
|
||||
request.messages = vec![];
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Invalid temperature
|
||||
let mut request = create_valid_chat_request();
|
||||
request.temperature = Some(3.0);
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_conflicts() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Conflicting max_tokens
|
||||
request.max_tokens = Some(100);
|
||||
request.max_completion_tokens = Some(200);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Logprobs without top_logprobs
|
||||
request.max_tokens = None;
|
||||
request.logprobs = true;
|
||||
request.top_logprobs = None;
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sglang_extensions() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Valid SGLang parameters
|
||||
request.top_k = Some(-1);
|
||||
request.min_p = Some(0.1);
|
||||
request.repetition_penalty = Some(1.2);
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// Invalid parameters
|
||||
request.top_k = Some(0); // Invalid
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_ranges() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Test temperature range (0.0 to 2.0)
|
||||
request.temperature = Some(1.5);
|
||||
assert!(request.validate().is_ok());
|
||||
request.temperature = Some(-0.1);
|
||||
assert!(request.validate().is_err());
|
||||
request.temperature = Some(3.0);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test top_p range (0.0 to 1.0)
|
||||
request.temperature = Some(1.0); // Reset
|
||||
request.top_p = Some(0.9);
|
||||
assert!(request.validate().is_ok());
|
||||
request.top_p = Some(-0.1);
|
||||
assert!(request.validate().is_err());
|
||||
request.top_p = Some(1.5);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test frequency_penalty range (-2.0 to 2.0)
|
||||
request.top_p = Some(0.9); // Reset
|
||||
request.frequency_penalty = Some(1.5);
|
||||
assert!(request.validate().is_ok());
|
||||
request.frequency_penalty = Some(-2.5);
|
||||
assert!(request.validate().is_err());
|
||||
request.frequency_penalty = Some(3.0);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test presence_penalty range (-2.0 to 2.0)
|
||||
request.frequency_penalty = Some(0.0); // Reset
|
||||
request.presence_penalty = Some(-1.5);
|
||||
assert!(request.validate().is_ok());
|
||||
request.presence_penalty = Some(-3.0);
|
||||
assert!(request.validate().is_err());
|
||||
request.presence_penalty = Some(2.5);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test repetition_penalty range (0.0 to 2.0)
|
||||
request.presence_penalty = Some(0.0); // Reset
|
||||
request.repetition_penalty = Some(1.2);
|
||||
assert!(request.validate().is_ok());
|
||||
request.repetition_penalty = Some(-0.1);
|
||||
assert!(request.validate().is_err());
|
||||
request.repetition_penalty = Some(2.1);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Test min_p range (0.0 to 1.0)
|
||||
request.repetition_penalty = Some(1.0); // Reset
|
||||
request.min_p = Some(0.5);
|
||||
assert!(request.validate().is_ok());
|
||||
request.min_p = Some(-0.1);
|
||||
assert!(request.validate().is_err());
|
||||
request.min_p = Some(1.5);
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_structured_output_conflicts() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// JSON response format with regex should conflict
|
||||
request.response_format = Some(ResponseFormat::JsonObject);
|
||||
request.regex = Some(".*".to_string());
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// JSON response format with EBNF should conflict
|
||||
request.regex = None;
|
||||
request.ebnf = Some("grammar".to_string());
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Multiple structured constraints should conflict
|
||||
request.response_format = None;
|
||||
request.regex = Some(".*".to_string());
|
||||
request.ebnf = Some("grammar".to_string());
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Only one constraint should work
|
||||
request.ebnf = None;
|
||||
request.regex = Some(".*".to_string());
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
request.regex = None;
|
||||
request.ebnf = Some("grammar".to_string());
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
request.ebnf = None;
|
||||
request.response_format = Some(ResponseFormat::JsonObject);
|
||||
assert!(request.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stop_sequences_validation() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Valid stop sequences
|
||||
request.stop = Some(StringOrArray::Array(vec![
|
||||
"stop1".to_string(),
|
||||
"stop2".to_string(),
|
||||
]));
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// Too many stop sequences (max 4)
|
||||
request.stop = Some(StringOrArray::Array(vec![
|
||||
"stop1".to_string(),
|
||||
"stop2".to_string(),
|
||||
"stop3".to_string(),
|
||||
"stop4".to_string(),
|
||||
"stop5".to_string(),
|
||||
]));
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Empty stop sequence should fail
|
||||
request.stop = Some(StringOrArray::String("".to_string()));
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Empty string in array should fail
|
||||
request.stop = Some(StringOrArray::Array(vec![
|
||||
"stop1".to_string(),
|
||||
"".to_string(),
|
||||
]));
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_logprobs_validation() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Valid logprobs configuration
|
||||
request.logprobs = true;
|
||||
request.top_logprobs = Some(10);
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// logprobs=true without top_logprobs should fail
|
||||
request.top_logprobs = None;
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// top_logprobs without logprobs=true should fail
|
||||
request.logprobs = false;
|
||||
request.top_logprobs = Some(10);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// top_logprobs out of range (0-20)
|
||||
request.logprobs = true;
|
||||
request.top_logprobs = Some(25);
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_n_parameter_validation() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Valid n values (1-10)
|
||||
request.n = Some(1);
|
||||
assert!(request.validate().is_ok());
|
||||
request.n = Some(5);
|
||||
assert!(request.validate().is_ok());
|
||||
request.n = Some(10);
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// Invalid n values
|
||||
request.n = Some(0);
|
||||
assert!(request.validate().is_err());
|
||||
request.n = Some(15);
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_max_tokens_validation() {
|
||||
let mut request = create_valid_chat_request();
|
||||
|
||||
// Valid token limits
|
||||
request.min_tokens = Some(10);
|
||||
request.max_tokens = Some(100);
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// min_tokens > max_tokens should fail
|
||||
request.min_tokens = Some(150);
|
||||
request.max_tokens = Some(100);
|
||||
assert!(request.validate().is_err());
|
||||
|
||||
// Should work with max_completion_tokens instead
|
||||
request.max_tokens = None;
|
||||
request.max_completion_tokens = Some(200);
|
||||
request.min_tokens = Some(50);
|
||||
assert!(request.validate().is_ok());
|
||||
|
||||
// min_tokens > max_completion_tokens should fail
|
||||
request.min_tokens = Some(250);
|
||||
assert!(request.validate().is_err());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user