[Generative Score API] Optimization to Remove Decode. (#8840)

This commit is contained in:
Sundara Raman Ramachandran
2025-08-13 14:12:24 -07:00
committed by GitHub
parent 9e426466af
commit a027a9b4b3
6 changed files with 843 additions and 20 deletions

View File

@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
1.0, sum(score_list), 6, "Scores should sum to 1"
)
def test_score_request_construction(self):
    """Test that scoring requests are constructed to avoid decode phase.

    The test wraps ``self.engine.tokenizer_manager.generate_request`` so that
    the request object handed to it during ``self.engine.score`` is recorded
    (the wrapped call is still forwarded to the original method). It then
    asserts that the single recorded request:

    1. explicitly sets ``max_new_tokens`` to 0 in every ``sampling_params``
       entry (a missing value is treated as a failure, not as 0),
    2. carries ``label_token_ids`` in ``token_ids_logprob``, either directly
       or once per batch item,
    3. has ``return_logprob`` truthy and ``stream`` falsy.
    """
    from unittest.mock import patch

    # Capture the internal request to verify optimization
    captured_requests = []
    original_gen = self.engine.tokenizer_manager.generate_request

    async def mock_generate_request(req, request=None):
        # Record the request, then delegate so scoring still produces results.
        captured_requests.append(req)
        async for result in original_gen(req, request):
            yield result

    def read_max_new_tokens(params):
        # Return None (not 0) when the value is absent, so a request that
        # never sets max_new_tokens cannot satisfy the "== 0" assertion.
        if isinstance(params, dict):
            return params.get("max_new_tokens")
        return getattr(params, "max_new_tokens", None)

    query = "What is the capital of"
    items = ["France", "Germany"]
    label_token_ids = [1, 2, 3]

    # Patch only for the duration of the scoring call; the assertions below
    # inspect the captured request object afterwards.
    with patch.object(
        self.engine.tokenizer_manager,
        "generate_request",
        side_effect=mock_generate_request,
    ):
        scores = self.engine.score(
            query=query,
            items=items,
            label_token_ids=label_token_ids,
            apply_softmax=True,
        )

    # Verify we got results
    self.assertEqual(len(scores), len(items))

    # Verify the captured request has decode-avoiding properties
    self.assertEqual(len(captured_requests), 1)
    request = captured_requests[0]

    # Key assertions for decode phase avoidance:
    # 1. max_new_tokens should be 0 (prevents token generation).
    # sampling_params may be a single dict/object or a list with one entry
    # per batch item; normalise to a list and check every entry.
    sampling_params = request.sampling_params
    if isinstance(sampling_params, list):
        per_item_params = sampling_params
    else:
        per_item_params = [sampling_params]
    self.assertTrue(
        per_item_params, "sampling_params should not be an empty list"
    )
    for params in per_item_params:
        self.assertEqual(
            read_max_new_tokens(params),
            0,
            "max_new_tokens should be 0 to avoid decode phase",
        )

    # 2. Should have token_ids_logprob for scoring.
    token_ids_logprob = request.token_ids_logprob
    is_batch = (
        isinstance(token_ids_logprob, list)
        and len(token_ids_logprob) > 0
        and isinstance(token_ids_logprob[0], list)
    )
    if is_batch:
        # Batch case: token_ids_logprob is a list of lists, and each item
        # in the batch should have the same label_token_ids.
        for item_token_ids in token_ids_logprob:
            self.assertEqual(
                item_token_ids,
                label_token_ids,
                "Each batch item should have label_token_ids for scoring",
            )
    else:
        # Single request case
        self.assertEqual(
            token_ids_logprob,
            label_token_ids,
            "Should have label_token_ids for scoring",
        )

    # 3. Should request logprobs but not stream
    self.assertTrue(
        request.return_logprob, "Should request logprobs for scoring"
    )
    self.assertFalse(request.stream, "Scoring requests should not stream")
# Run this module's tests with unittest's command-line runner when the file
# is executed directly (the guard is false when the module is imported).
if __name__ == "__main__":
    unittest.main()