Add more refactored openai test & in CI (#7284)

This commit is contained in:
Jinn
2025-06-18 15:52:55 -05:00
committed by GitHub
parent 09ae5b20f3
commit ffd1a26e09
8 changed files with 566 additions and 1049 deletions

View File

@@ -15,9 +15,9 @@
import json
import time
import unittest
from typing import Dict, List, Optional
import pytest
from pydantic import ValidationError
from sglang.srt.entrypoints.openai.protocol import (
@@ -64,18 +64,18 @@ from sglang.srt.entrypoints.openai.protocol import (
)
class TestModelCard:
class TestModelCard(unittest.TestCase):
"""Test ModelCard protocol model"""
def test_basic_model_card_creation(self):
"""Test basic model card creation with required fields"""
card = ModelCard(id="test-model")
assert card.id == "test-model"
assert card.object == "model"
assert card.owned_by == "sglang"
assert isinstance(card.created, int)
assert card.root is None
assert card.max_model_len is None
self.assertEqual(card.id, "test-model")
self.assertEqual(card.object, "model")
self.assertEqual(card.owned_by, "sglang")
self.assertIsInstance(card.created, int)
self.assertIsNone(card.root)
self.assertIsNone(card.max_model_len)
def test_model_card_with_optional_fields(self):
"""Test model card with optional fields"""
@@ -85,28 +85,28 @@ class TestModelCard:
max_model_len=2048,
created=1234567890,
)
assert card.id == "test-model"
assert card.root == "/path/to/model"
assert card.max_model_len == 2048
assert card.created == 1234567890
self.assertEqual(card.id, "test-model")
self.assertEqual(card.root, "/path/to/model")
self.assertEqual(card.max_model_len, 2048)
self.assertEqual(card.created, 1234567890)
def test_model_card_serialization(self):
"""Test model card JSON serialization"""
card = ModelCard(id="test-model", max_model_len=4096)
data = card.model_dump()
assert data["id"] == "test-model"
assert data["object"] == "model"
assert data["max_model_len"] == 4096
self.assertEqual(data["id"], "test-model")
self.assertEqual(data["object"], "model")
self.assertEqual(data["max_model_len"], 4096)
class TestModelList:
class TestModelList(unittest.TestCase):
"""Test ModelList protocol model"""
def test_empty_model_list(self):
"""Test empty model list creation"""
model_list = ModelList()
assert model_list.object == "list"
assert len(model_list.data) == 0
self.assertEqual(model_list.object, "list")
self.assertEqual(len(model_list.data), 0)
def test_model_list_with_cards(self):
"""Test model list with model cards"""
@@ -115,12 +115,12 @@ class TestModelList:
ModelCard(id="model-2", max_model_len=2048),
]
model_list = ModelList(data=cards)
assert len(model_list.data) == 2
assert model_list.data[0].id == "model-1"
assert model_list.data[1].id == "model-2"
self.assertEqual(len(model_list.data), 2)
self.assertEqual(model_list.data[0].id, "model-1")
self.assertEqual(model_list.data[1].id, "model-2")
class TestErrorResponse:
class TestErrorResponse(unittest.TestCase):
"""Test ErrorResponse protocol model"""
def test_basic_error_response(self):
@@ -128,11 +128,11 @@ class TestErrorResponse:
error = ErrorResponse(
message="Invalid request", type="BadRequestError", code=400
)
assert error.object == "error"
assert error.message == "Invalid request"
assert error.type == "BadRequestError"
assert error.code == 400
assert error.param is None
self.assertEqual(error.object, "error")
self.assertEqual(error.message, "Invalid request")
self.assertEqual(error.type, "BadRequestError")
self.assertEqual(error.code, 400)
self.assertIsNone(error.param)
def test_error_response_with_param(self):
"""Test error response with parameter"""
@@ -142,19 +142,19 @@ class TestErrorResponse:
code=422,
param="temperature",
)
assert error.param == "temperature"
self.assertEqual(error.param, "temperature")
class TestUsageInfo:
class TestUsageInfo(unittest.TestCase):
"""Test UsageInfo protocol model"""
def test_basic_usage_info(self):
"""Test basic usage info creation"""
usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30)
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30
assert usage.prompt_tokens_details is None
self.assertEqual(usage.prompt_tokens, 10)
self.assertEqual(usage.completion_tokens, 20)
self.assertEqual(usage.total_tokens, 30)
self.assertIsNone(usage.prompt_tokens_details)
def test_usage_info_with_cache_details(self):
"""Test usage info with cache details"""
@@ -164,22 +164,22 @@ class TestUsageInfo:
total_tokens=30,
prompt_tokens_details={"cached_tokens": 5},
)
assert usage.prompt_tokens_details == {"cached_tokens": 5}
self.assertEqual(usage.prompt_tokens_details, {"cached_tokens": 5})
class TestCompletionRequest:
class TestCompletionRequest(unittest.TestCase):
"""Test CompletionRequest protocol model"""
def test_basic_completion_request(self):
"""Test basic completion request"""
request = CompletionRequest(model="test-model", prompt="Hello world")
assert request.model == "test-model"
assert request.prompt == "Hello world"
assert request.max_tokens == 16 # default
assert request.temperature == 1.0 # default
assert request.n == 1 # default
assert not request.stream # default
assert not request.echo # default
self.assertEqual(request.model, "test-model")
self.assertEqual(request.prompt, "Hello world")
self.assertEqual(request.max_tokens, 16) # default
self.assertEqual(request.temperature, 1.0) # default
self.assertEqual(request.n, 1) # default
self.assertFalse(request.stream) # default
self.assertFalse(request.echo) # default
def test_completion_request_with_options(self):
"""Test completion request with various options"""
@@ -195,15 +195,15 @@ class TestCompletionRequest:
stop=[".", "!"],
logprobs=5,
)
assert request.prompt == ["Hello", "world"]
assert request.max_tokens == 100
assert request.temperature == 0.7
assert request.top_p == 0.9
assert request.n == 2
assert request.stream
assert request.echo
assert request.stop == [".", "!"]
assert request.logprobs == 5
self.assertEqual(request.prompt, ["Hello", "world"])
self.assertEqual(request.max_tokens, 100)
self.assertEqual(request.temperature, 0.7)
self.assertEqual(request.top_p, 0.9)
self.assertEqual(request.n, 2)
self.assertTrue(request.stream)
self.assertTrue(request.echo)
self.assertEqual(request.stop, [".", "!"])
self.assertEqual(request.logprobs, 5)
def test_completion_request_sglang_extensions(self):
"""Test completion request with SGLang-specific extensions"""
@@ -217,23 +217,23 @@ class TestCompletionRequest:
json_schema='{"type": "object"}',
lora_path="/path/to/lora",
)
assert request.top_k == 50
assert request.min_p == 0.1
assert request.repetition_penalty == 1.1
assert request.regex == r"\d+"
assert request.json_schema == '{"type": "object"}'
assert request.lora_path == "/path/to/lora"
self.assertEqual(request.top_k, 50)
self.assertEqual(request.min_p, 0.1)
self.assertEqual(request.repetition_penalty, 1.1)
self.assertEqual(request.regex, r"\d+")
self.assertEqual(request.json_schema, '{"type": "object"}')
self.assertEqual(request.lora_path, "/path/to/lora")
def test_completion_request_validation_errors(self):
"""Test completion request validation errors"""
with pytest.raises(ValidationError):
with self.assertRaises(ValidationError):
CompletionRequest() # missing required fields
with pytest.raises(ValidationError):
with self.assertRaises(ValidationError):
CompletionRequest(model="test-model") # missing prompt
class TestCompletionResponse:
class TestCompletionResponse(unittest.TestCase):
"""Test CompletionResponse protocol model"""
def test_basic_completion_response(self):
@@ -245,28 +245,28 @@ class TestCompletionResponse:
response = CompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
assert response.id == "test-id"
assert response.object == "text_completion"
assert response.model == "test-model"
assert len(response.choices) == 1
assert response.choices[0].text == "Hello world!"
assert response.usage.total_tokens == 5
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "text_completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].text, "Hello world!")
self.assertEqual(response.usage.total_tokens, 5)
class TestChatCompletionRequest:
class TestChatCompletionRequest(unittest.TestCase):
"""Test ChatCompletionRequest protocol model"""
def test_basic_chat_completion_request(self):
"""Test basic chat completion request"""
messages = [{"role": "user", "content": "Hello"}]
request = ChatCompletionRequest(model="test-model", messages=messages)
assert request.model == "test-model"
assert len(request.messages) == 1
assert request.messages[0].role == "user"
assert request.messages[0].content == "Hello"
assert request.temperature == 0.7 # default
assert not request.stream # default
assert request.tool_choice == "none" # default when no tools
self.assertEqual(request.model, "test-model")
self.assertEqual(len(request.messages), 1)
self.assertEqual(request.messages[0].role, "user")
self.assertEqual(request.messages[0].content, "Hello")
self.assertEqual(request.temperature, 0.7) # default
self.assertFalse(request.stream) # default
self.assertEqual(request.tool_choice, "none") # default when no tools
def test_chat_completion_with_multimodal_content(self):
"""Test chat completion with multimodal content"""
@@ -283,9 +283,9 @@ class TestChatCompletionRequest:
}
]
request = ChatCompletionRequest(model="test-model", messages=messages)
assert len(request.messages[0].content) == 2
assert request.messages[0].content[0].type == "text"
assert request.messages[0].content[1].type == "image_url"
self.assertEqual(len(request.messages[0].content), 2)
self.assertEqual(request.messages[0].content[0].type, "text")
self.assertEqual(request.messages[0].content[1].type, "image_url")
def test_chat_completion_with_tools(self):
"""Test chat completion with tools"""
@@ -306,9 +306,9 @@ class TestChatCompletionRequest:
request = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
assert len(request.tools) == 1
assert request.tools[0].function.name == "get_weather"
assert request.tool_choice == "auto" # default when tools present
self.assertEqual(len(request.tools), 1)
self.assertEqual(request.tools[0].function.name, "get_weather")
self.assertEqual(request.tool_choice, "auto") # default when tools present
def test_chat_completion_tool_choice_validation(self):
"""Test tool choice validation logic"""
@@ -316,7 +316,7 @@ class TestChatCompletionRequest:
# No tools, tool_choice should default to "none"
request1 = ChatCompletionRequest(model="test-model", messages=messages)
assert request1.tool_choice == "none"
self.assertEqual(request1.tool_choice, "none")
# With tools, tool_choice should default to "auto"
tools = [
@@ -328,7 +328,7 @@ class TestChatCompletionRequest:
request2 = ChatCompletionRequest(
model="test-model", messages=messages, tools=tools
)
assert request2.tool_choice == "auto"
self.assertEqual(request2.tool_choice, "auto")
def test_chat_completion_sglang_extensions(self):
"""Test chat completion with SGLang extensions"""
@@ -342,14 +342,14 @@ class TestChatCompletionRequest:
stream_reasoning=False,
chat_template_kwargs={"custom_param": "value"},
)
assert request.top_k == 40
assert request.min_p == 0.05
assert not request.separate_reasoning
assert not request.stream_reasoning
assert request.chat_template_kwargs == {"custom_param": "value"}
self.assertEqual(request.top_k, 40)
self.assertEqual(request.min_p, 0.05)
self.assertFalse(request.separate_reasoning)
self.assertFalse(request.stream_reasoning)
self.assertEqual(request.chat_template_kwargs, {"custom_param": "value"})
class TestChatCompletionResponse:
class TestChatCompletionResponse(unittest.TestCase):
"""Test ChatCompletionResponse protocol model"""
def test_basic_chat_completion_response(self):
@@ -362,11 +362,11 @@ class TestChatCompletionResponse:
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
assert response.id == "test-id"
assert response.object == "chat.completion"
assert response.model == "test-model"
assert len(response.choices) == 1
assert response.choices[0].message.content == "Hello there!"
self.assertEqual(response.id, "test-id")
self.assertEqual(response.object, "chat.completion")
self.assertEqual(response.model, "test-model")
self.assertEqual(len(response.choices), 1)
self.assertEqual(response.choices[0].message.content, "Hello there!")
def test_chat_completion_response_with_tool_calls(self):
"""Test chat completion response with tool calls"""
@@ -384,28 +384,30 @@ class TestChatCompletionResponse:
response = ChatCompletionResponse(
id="test-id", model="test-model", choices=[choice], usage=usage
)
assert response.choices[0].message.tool_calls[0].function.name == "get_weather"
assert response.choices[0].finish_reason == "tool_calls"
self.assertEqual(
response.choices[0].message.tool_calls[0].function.name, "get_weather"
)
self.assertEqual(response.choices[0].finish_reason, "tool_calls")
class TestEmbeddingRequest:
class TestEmbeddingRequest(unittest.TestCase):
"""Test EmbeddingRequest protocol model"""
def test_basic_embedding_request(self):
"""Test basic embedding request"""
request = EmbeddingRequest(model="test-model", input="Hello world")
assert request.model == "test-model"
assert request.input == "Hello world"
assert request.encoding_format == "float" # default
assert request.dimensions is None # default
self.assertEqual(request.model, "test-model")
self.assertEqual(request.input, "Hello world")
self.assertEqual(request.encoding_format, "float") # default
self.assertIsNone(request.dimensions) # default
def test_embedding_request_with_list_input(self):
"""Test embedding request with list input"""
request = EmbeddingRequest(
model="test-model", input=["Hello", "world"], dimensions=512
)
assert request.input == ["Hello", "world"]
assert request.dimensions == 512
self.assertEqual(request.input, ["Hello", "world"])
self.assertEqual(request.dimensions, 512)
def test_multimodal_embedding_request(self):
"""Test multimodal embedding request"""
@@ -414,14 +416,14 @@ class TestEmbeddingRequest:
MultimodalEmbeddingInput(text="World", image=None),
]
request = EmbeddingRequest(model="test-model", input=multimodal_input)
assert len(request.input) == 2
assert request.input[0].text == "Hello"
assert request.input[0].image == "base64_image_data"
assert request.input[1].text == "World"
assert request.input[1].image is None
self.assertEqual(len(request.input), 2)
self.assertEqual(request.input[0].text, "Hello")
self.assertEqual(request.input[0].image, "base64_image_data")
self.assertEqual(request.input[1].text, "World")
self.assertIsNone(request.input[1].image)
class TestEmbeddingResponse:
class TestEmbeddingResponse(unittest.TestCase):
"""Test EmbeddingResponse protocol model"""
def test_basic_embedding_response(self):
@@ -431,14 +433,14 @@ class TestEmbeddingResponse:
response = EmbeddingResponse(
data=[embedding_obj], model="test-model", usage=usage
)
assert response.object == "list"
assert len(response.data) == 1
assert response.data[0].embedding == [0.1, 0.2, 0.3]
assert response.data[0].index == 0
assert response.usage.prompt_tokens == 3
self.assertEqual(response.object, "list")
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data[0].embedding, [0.1, 0.2, 0.3])
self.assertEqual(response.data[0].index, 0)
self.assertEqual(response.usage.prompt_tokens, 3)
class TestScoringRequest:
class TestScoringRequest(unittest.TestCase):
"""Test ScoringRequest protocol model"""
def test_basic_scoring_request(self):
@@ -446,11 +448,11 @@ class TestScoringRequest:
request = ScoringRequest(
model="test-model", query="Hello", items=["World", "Earth"]
)
assert request.model == "test-model"
assert request.query == "Hello"
assert request.items == ["World", "Earth"]
assert not request.apply_softmax # default
assert not request.item_first # default
self.assertEqual(request.model, "test-model")
self.assertEqual(request.query, "Hello")
self.assertEqual(request.items, ["World", "Earth"])
self.assertFalse(request.apply_softmax) # default
self.assertFalse(request.item_first) # default
def test_scoring_request_with_token_ids(self):
"""Test scoring request with token IDs"""
@@ -462,34 +464,34 @@ class TestScoringRequest:
apply_softmax=True,
item_first=True,
)
assert request.query == [1, 2, 3]
assert request.items == [[4, 5], [6, 7]]
assert request.label_token_ids == [8, 9]
assert request.apply_softmax
assert request.item_first
self.assertEqual(request.query, [1, 2, 3])
self.assertEqual(request.items, [[4, 5], [6, 7]])
self.assertEqual(request.label_token_ids, [8, 9])
self.assertTrue(request.apply_softmax)
self.assertTrue(request.item_first)
class TestScoringResponse:
class TestScoringResponse(unittest.TestCase):
"""Test ScoringResponse protocol model"""
def test_basic_scoring_response(self):
"""Test basic scoring response"""
response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model")
assert response.object == "scoring"
assert response.scores == [[0.1, 0.9], [0.3, 0.7]]
assert response.model == "test-model"
assert response.usage is None # default
self.assertEqual(response.object, "scoring")
self.assertEqual(response.scores, [[0.1, 0.9], [0.3, 0.7]])
self.assertEqual(response.model, "test-model")
self.assertIsNone(response.usage) # default
class TestFileOperations:
class TestFileOperations(unittest.TestCase):
"""Test file operation protocol models"""
def test_file_request(self):
"""Test file request model"""
file_data = b"test file content"
request = FileRequest(file=file_data, purpose="batch")
assert request.file == file_data
assert request.purpose == "batch"
self.assertEqual(request.file, file_data)
self.assertEqual(request.purpose, "batch")
def test_file_response(self):
"""Test file response model"""
@@ -500,20 +502,20 @@ class TestFileOperations:
filename="test.jsonl",
purpose="batch",
)
assert response.id == "file-123"
assert response.object == "file"
assert response.bytes == 1024
assert response.filename == "test.jsonl"
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertEqual(response.bytes, 1024)
self.assertEqual(response.filename, "test.jsonl")
def test_file_delete_response(self):
"""Test file delete response model"""
response = FileDeleteResponse(id="file-123", deleted=True)
assert response.id == "file-123"
assert response.object == "file"
assert response.deleted
self.assertEqual(response.id, "file-123")
self.assertEqual(response.object, "file")
self.assertTrue(response.deleted)
class TestBatchOperations:
class TestBatchOperations(unittest.TestCase):
"""Test batch operation protocol models"""
def test_batch_request(self):
@@ -524,10 +526,10 @@ class TestBatchOperations:
completion_window="24h",
metadata={"custom": "value"},
)
assert request.input_file_id == "file-123"
assert request.endpoint == "/v1/chat/completions"
assert request.completion_window == "24h"
assert request.metadata == {"custom": "value"}
self.assertEqual(request.input_file_id, "file-123")
self.assertEqual(request.endpoint, "/v1/chat/completions")
self.assertEqual(request.completion_window, "24h")
self.assertEqual(request.metadata, {"custom": "value"})
def test_batch_response(self):
"""Test batch response model"""
@@ -538,20 +540,20 @@ class TestBatchOperations:
completion_window="24h",
created_at=1234567890,
)
assert response.id == "batch-123"
assert response.object == "batch"
assert response.status == "validating" # default
assert response.endpoint == "/v1/chat/completions"
self.assertEqual(response.id, "batch-123")
self.assertEqual(response.object, "batch")
self.assertEqual(response.status, "validating") # default
self.assertEqual(response.endpoint, "/v1/chat/completions")
class TestResponseFormats:
class TestResponseFormats(unittest.TestCase):
"""Test response format protocol models"""
def test_basic_response_format(self):
"""Test basic response format"""
format_obj = ResponseFormat(type="json_object")
assert format_obj.type == "json_object"
assert format_obj.json_schema is None
self.assertEqual(format_obj.type, "json_object")
self.assertIsNone(format_obj.json_schema)
def test_json_schema_response_format(self):
"""Test JSON schema response format"""
@@ -560,9 +562,9 @@ class TestResponseFormats:
name="person_schema", description="Person schema", schema=schema
)
format_obj = ResponseFormat(type="json_schema", json_schema=json_schema)
assert format_obj.type == "json_schema"
assert format_obj.json_schema.name == "person_schema"
assert format_obj.json_schema.schema_ == schema
self.assertEqual(format_obj.type, "json_schema")
self.assertEqual(format_obj.json_schema.name, "person_schema")
self.assertEqual(format_obj.json_schema.schema_, schema)
def test_structural_tag_response_format(self):
"""Test structural tag response format"""
@@ -576,12 +578,12 @@ class TestResponseFormats:
format_obj = StructuralTagResponseFormat(
type="structural_tag", structures=structures, triggers=["think"]
)
assert format_obj.type == "structural_tag"
assert len(format_obj.structures) == 1
assert format_obj.triggers == ["think"]
self.assertEqual(format_obj.type, "structural_tag")
self.assertEqual(len(format_obj.structures), 1)
self.assertEqual(format_obj.triggers, ["think"])
class TestLogProbs:
class TestLogProbs(unittest.TestCase):
"""Test LogProbs protocol models"""
def test_basic_logprobs(self):
@@ -592,9 +594,9 @@ class TestLogProbs:
tokens=["Hello", " ", "world"],
top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}],
)
assert len(logprobs.tokens) == 3
assert logprobs.tokens == ["Hello", " ", "world"]
assert logprobs.token_logprobs == [-0.1, -0.2, -0.3]
self.assertEqual(len(logprobs.tokens), 3)
self.assertEqual(logprobs.tokens, ["Hello", " ", "world"])
self.assertEqual(logprobs.token_logprobs, [-0.1, -0.2, -0.3])
def test_choice_logprobs(self):
"""Test ChoiceLogprobs model"""
@@ -607,17 +609,17 @@ class TestLogProbs:
],
)
choice_logprobs = ChoiceLogprobs(content=[token_logprob])
assert len(choice_logprobs.content) == 1
assert choice_logprobs.content[0].token == "Hello"
self.assertEqual(len(choice_logprobs.content), 1)
self.assertEqual(choice_logprobs.content[0].token, "Hello")
class TestStreamingModels:
class TestStreamingModels(unittest.TestCase):
"""Test streaming response models"""
def test_stream_options(self):
"""Test StreamOptions model"""
options = StreamOptions(include_usage=True)
assert options.include_usage
self.assertTrue(options.include_usage)
def test_chat_completion_stream_response(self):
"""Test ChatCompletionStreamResponse model"""
@@ -626,29 +628,29 @@ class TestStreamingModels:
response = ChatCompletionStreamResponse(
id="test-id", model="test-model", choices=[choice]
)
assert response.object == "chat.completion.chunk"
assert response.choices[0].delta.content == "Hello"
self.assertEqual(response.object, "chat.completion.chunk")
self.assertEqual(response.choices[0].delta.content, "Hello")
class TestValidationEdgeCases:
class TestValidationEdgeCases(unittest.TestCase):
"""Test edge cases and validation scenarios"""
def test_empty_messages_validation(self):
"""Test validation with empty messages"""
with pytest.raises(ValidationError):
with self.assertRaises(ValidationError):
ChatCompletionRequest(model="test-model", messages=[])
def test_invalid_tool_choice_type(self):
"""Test invalid tool choice type"""
messages = [{"role": "user", "content": "Hello"}]
with pytest.raises(ValidationError):
with self.assertRaises(ValidationError):
ChatCompletionRequest(
model="test-model", messages=messages, tool_choice=123
)
def test_negative_token_limits(self):
"""Test negative token limits"""
with pytest.raises(ValidationError):
with self.assertRaises(ValidationError):
CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1)
def test_invalid_temperature_range(self):
@@ -656,7 +658,7 @@ class TestValidationEdgeCases:
# Note: The current protocol doesn't enforce temperature range,
# but this test documents expected behavior
request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0)
assert request.temperature == 5.0 # Currently allowed
self.assertEqual(request.temperature, 5.0) # Currently allowed
def test_model_serialization_roundtrip(self):
"""Test that models can be serialized and deserialized"""
@@ -673,11 +675,11 @@ class TestValidationEdgeCases:
# Deserialize back
restored_request = ChatCompletionRequest(**data)
assert restored_request.model == original_request.model
assert restored_request.temperature == original_request.temperature
assert restored_request.max_tokens == original_request.max_tokens
assert len(restored_request.messages) == len(original_request.messages)
self.assertEqual(restored_request.model, original_request.model)
self.assertEqual(restored_request.temperature, original_request.temperature)
self.assertEqual(restored_request.max_tokens, original_request.max_tokens)
self.assertEqual(len(restored_request.messages), len(original_request.messages))
if __name__ == "__main__":
pytest.main([__file__])
unittest.main(verbosity=2)