[Feature] Add Logit Bias (#6579)

Co-authored-by: Cinjon Resnick <cinjon.resnick@gmail.com>
This commit is contained in:
Brayden Zhong
2025-06-10 18:39:25 -04:00
committed by GitHub
parent 344adb00ec
commit ca9291181d
5 changed files with 183 additions and 0 deletions

View File

@@ -504,6 +504,122 @@ class TestSRTEndpoint(CustomTestCase):
version = response_json["version"]
self.assertIsInstance(version, str)
def test_logit_bias(self):
    """Verify that a large positive logit bias forces every sampled token."""
    # Token 60704 is " Paris" for meta-llama/Llama-3.2-1B-Instruct
    # (DEFAULT_SMALL_MODEL_NAME_FOR_TEST); a +100 bias should dominate sampling.
    target_token_id = 60704
    payload = {
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 1.0,  # high temperature: only the bias should pin the output
            "max_new_tokens": 4,
            "logit_bias": {str(target_token_id): 100.0},
        },
        "return_logprob": True,
    }
    response_json = requests.post(
        self.base_url + "/generate", json=payload
    ).json()
    # Each logprob entry is (logprob, token_id, ...); collect the token ids.
    sampled_tokens = [
        entry[1] for entry in response_json["meta_info"]["output_token_logprobs"]
    ]
    # With the overwhelming bias, every sampled token must be the target.
    self.assertTrue(
        all(tok == target_token_id for tok in sampled_tokens),
        f"Expected all tokens to be {target_token_id}, but got {sampled_tokens}",
    )
def test_forbidden_token(self):
    """Verify that a large negative logit bias keeps a token out of the output."""
    # Token 23994 is "rice" for meta-llama/Llama-3.2-1B-Instruct
    # (DEFAULT_SMALL_MODEL_NAME_FOR_TEST); a -100 bias should exclude it even
    # when the prompt explicitly asks for it.
    forbidden_token_id = 23994
    payload = {
        "text": "Only output 'rice' exactly like this, in lowercase ONLY: rice",
        "sampling_params": {
            "temperature": 1.0,  # encourage diverse output
            "max_new_tokens": 50,  # enough tokens that the word would normally appear
            "logit_bias": {str(forbidden_token_id): -100.0},
        },
        "return_logprob": True,
    }
    response_json = requests.post(
        self.base_url + "/generate", json=payload
    ).json()
    # Each logprob entry is (logprob, token_id, ...); collect the token ids.
    sampled_tokens = [
        entry[1] for entry in response_json["meta_info"]["output_token_logprobs"]
    ]
    # The forbidden token must never have been sampled.
    self.assertNotIn(
        forbidden_token_id,
        sampled_tokens,
        f"Expected forbidden token {forbidden_token_id} not to be present, but it was found",
    )
def test_logit_bias_isolation(self):
    """Verify that a logit_bias on one request does not leak into another request."""
    # Token 60704 is " Paris" for meta-llama/Llama-3.2-1B-Instruct
    # (DEFAULT_SMALL_MODEL_NAME_FOR_TEST).
    biased_token_id = 60704

    def _sampled_ids(sampling_params):
        # Issue a single /generate call and return the sampled token ids.
        resp = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": sampling_params,
                "return_logprob": True,
            },
        )
        return [
            entry[1] for entry in resp.json()["meta_info"]["output_token_logprobs"]
        ]

    # First request carries a strong positive bias; second is unbiased.
    biased_tokens = _sampled_ids(
        {
            "temperature": 1.0,
            "max_new_tokens": 4,
            "logit_bias": {str(biased_token_id): 100.0},
        }
    )
    unbiased_tokens = _sampled_ids(
        {
            "temperature": 1.0,
            "max_new_tokens": 4,
        }
    )

    # The biased request must emit only the biased token.
    self.assertTrue(
        all(tok == biased_token_id for tok in biased_tokens),
        f"Expected all tokens to be {biased_token_id} in first response, but got {biased_tokens}",
    )
    # The unbiased request's exact output is model-dependent, so only require
    # that it is not stuck on the biased token — i.e. the bias did not leak.
    self.assertTrue(
        any(tok != biased_token_id for tok in unbiased_tokens),
        f"Expected some tokens to be different from {biased_token_id} in second response, but got {unbiased_tokens}",
    )
def test_get_server_info_concurrent(self):
"""Make sure the concurrent get_server_info doesn't crash the server."""
tp = ThreadPoolExecutor(max_workers=30)