[Generative Score API] Optimization to Remove Decode. (#8840)
parent 9e426466af
commit a027a9b4b3
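Context for the diff below: the optimization builds scoring requests as prefill-only work (max_new_tokens=0 with return_logprob set), so the engine never enters the decode loop; the logprobs for the caller-supplied label_token_ids are read directly from the prefill output. A minimal sketch of the public call the new test exercises, assuming the offline sglang Engine (the model path here is illustrative, not part of this commit):

# Minimal sketch of the score call exercised by the test below.
# The model path is illustrative; any small generative model works.
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")

# Score each candidate continuation against the label token ids.
# With the decode-removal optimization this runs prefill only.
scores = engine.score(
    query="What is the capital of",
    items=["France", "Germany"],
    label_token_ids=[1, 2, 3],  # token ids whose logprobs are compared
    apply_softmax=True,         # normalize scores over the label set
)
print(scores)  # one score per item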
@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
            1.0, sum(score_list), 6, "Scores should sum to 1"
        )

    def test_score_request_construction(self):
        """Test that scoring requests are constructed to avoid the decode phase."""
        from unittest.mock import patch

        # Capture the internal request to verify the optimization
        captured_requests = []
        original_gen = self.engine.tokenizer_manager.generate_request

        async def mock_generate_request(req, request=None):
            captured_requests.append(req)
            async for result in original_gen(req, request):
                yield result

        # Patch the generate_request method
        with patch.object(
            self.engine.tokenizer_manager,
            "generate_request",
            side_effect=mock_generate_request,
        ):
            # Run a scoring request
            query = "What is the capital of"
            items = ["France", "Germany"]
            label_token_ids = [1, 2, 3]

            scores = self.engine.score(
                query=query,
                items=items,
                label_token_ids=label_token_ids,
                apply_softmax=True,
            )

            # Verify we got results
            self.assertEqual(len(scores), len(items))

            # Verify the captured request has decode-avoiding properties
            self.assertEqual(len(captured_requests), 1)
            request = captured_requests[0]

            # Key assertions for decode-phase avoidance:
            # 1. max_new_tokens should be 0 (prevents token generation).
            #    Handle both single and batch request cases.
            if isinstance(request.sampling_params, dict):
                max_new_tokens = request.sampling_params.get("max_new_tokens", 0)
            elif isinstance(request.sampling_params, list):
                # For batch requests, check the first item
                max_new_tokens = request.sampling_params[0].get("max_new_tokens", 0)
            else:
                max_new_tokens = getattr(request.sampling_params, "max_new_tokens", 0)

            self.assertEqual(
                max_new_tokens, 0, "max_new_tokens should be 0 to avoid decode phase"
            )

            # 2. Should carry token_ids_logprob for scoring.
            #    Handle both single and batch request cases.
            if (
                isinstance(request.token_ids_logprob, list)
                and len(request.token_ids_logprob) > 0
                and isinstance(request.token_ids_logprob[0], list)
            ):
                # Batch case: token_ids_logprob is a list of lists.
                # Each item in the batch should have the same label_token_ids.
                for item_token_ids in request.token_ids_logprob:
                    self.assertEqual(
                        item_token_ids,
                        label_token_ids,
                        "Each batch item should have label_token_ids for scoring",
                    )
            else:
                # Single-request case
                self.assertEqual(
                    request.token_ids_logprob,
                    label_token_ids,
                    "Should have label_token_ids for scoring",
                )

            # 3. Should request logprobs but not stream
            self.assertTrue(
                request.return_logprob, "Should request logprobs for scoring"
            )
            self.assertFalse(request.stream, "Scoring requests should not stream")


if __name__ == "__main__":
    unittest.main()
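The assertions above pin down the shape of the optimized internal request. Illustratively, the captured request should look roughly like the following (field names are taken from the assertions in the test and the literal values from its inputs; this is not the real constructor call):

# Illustrative shape of the captured internal request; field names come
# from the assertions above, the literal values from the test inputs.
captured = dict(
    sampling_params={"max_new_tokens": 0},  # prefill-only: no decode steps
    token_ids_logprob=[1, 2, 3],            # the label_token_ids to score
    return_logprob=True,                    # logprobs carry the scores
    stream=False,                           # scoring is a one-shot request
)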
test/srt/test_tokenizer_batch_encode.py (new file, 121 lines)
@@ -0,0 +1,121 @@
"""
Unit tests for the enable_tokenizer_batch_encode feature.

This tests the batch tokenization functionality, which allows processing
multiple text inputs in a single batch for improved performance.

Usage:
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_batch_validation_constraints
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeUnit.test_batch_tokenize_and_process_logic
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncodeLogic.test_batch_processing_path
"""

import asyncio
import unittest
from typing import List
from unittest.mock import AsyncMock, Mock, call, patch

from sglang.srt.managers.io_struct import GenerateReqInput, TokenizedGenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


class TestTokenizerBatchEncode(unittest.TestCase):
    """Test cases for tokenizer batch-encoding validation and setup."""

    def setUp(self):
        """Set up test fixtures."""
        self.server_args = ServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            enable_tokenizer_batch_encode=True,
        )
        self.port_args = PortArgs.init_new(self.server_args)

        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_batch_encode_enabled(self):
        """Test that batch encoding is enabled when configured."""
        self.assertTrue(self.server_args.enable_tokenizer_batch_encode)

    def test_batch_encode_disabled(self):
        """Test that batch encoding can be disabled."""
        server_args_disabled = ServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            enable_tokenizer_batch_encode=False,
        )
        self.assertFalse(server_args_disabled.enable_tokenizer_batch_encode)

    def test_multimodal_input_validation(self):
        """Test that multimodal inputs are rejected in batch mode."""
        req = GenerateReqInput(text="test", image_data=["dummy"])
        req.contains_mm_input = Mock(return_value=True)

        batch_obj = Mock()
        batch_obj.__getitem__ = lambda self, i: req

        self.tokenizer_manager.is_generation = True

        with self.assertRaises(ValueError) as cm:
            self.tokenizer_manager._validate_batch_tokenization_constraints(
                1, batch_obj
            )

        self.assertIn("multimodal", str(cm.exception))

    def test_pretokenized_input_validation(self):
        """Test that pre-tokenized inputs are rejected in batch mode."""
        req = GenerateReqInput(input_ids=[1, 2, 3])

        batch_obj = Mock()
        batch_obj.__getitem__ = lambda self, i: req

        with self.assertRaises(ValueError) as cm:
            self.tokenizer_manager._validate_batch_tokenization_constraints(
                1, batch_obj
            )

        self.assertIn("pre-tokenized", str(cm.exception))

    def test_input_embeds_validation(self):
        """Test that input embeds are rejected in batch mode."""
        req = GenerateReqInput(input_embeds=[0.1, 0.2])

        batch_obj = Mock()
        batch_obj.__getitem__ = lambda self, i: req

        with self.assertRaises(ValueError) as cm:
            self.tokenizer_manager._validate_batch_tokenization_constraints(
                1, batch_obj
            )

        self.assertIn("input_embeds", str(cm.exception))

    def test_valid_text_only_requests_pass_validation(self):
        """Test that valid text-only requests pass validation."""
        # Create valid requests (text-only)
        requests = []
        for i in range(3):
            req = GenerateReqInput(text=f"test text {i}")
            req.contains_mm_input = Mock(return_value=False)
            requests.append(req)

        batch_obj = Mock()
        batch_obj.__getitem__ = Mock(side_effect=lambda i: requests[i])

        # Should not raise any exception
        try:
            self.tokenizer_manager._validate_batch_tokenization_constraints(
                3, batch_obj
            )
        except Exception as e:
            self.fail(f"Validation failed for valid text-only requests: {e}")


if __name__ == "__main__":
    unittest.main(verbosity=2)
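The three rejection tests above all route through _validate_batch_tokenization_constraints. A hypothetical sketch of the check they exercise (the real implementation lives in sglang.srt.managers.tokenizer_manager.TokenizerManager; the error wording here only mirrors the substrings the tests assert on):

# Hypothetical sketch, not the actual implementation: batch tokenization
# only handles plain text, so anything else is rejected up front.
def _validate_batch_tokenization_constraints(self, batch_size, batch_obj):
    for i in range(batch_size):
        req = batch_obj[i]
        if self.is_generation and req.contains_mm_input():
            raise ValueError("Batch tokenization does not support multimodal inputs")
        if req.input_ids is not None:
            raise ValueError("Batch tokenization does not support pre-tokenized input_ids")
        if req.input_embeds is not None:
            raise ValueError("Batch tokenization does not support input_embeds")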