diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index aad9de93e..b37e8d13a 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -582,6 +582,7 @@ def v1_generate_request(
                 "no_stop_trim": request.no_stop_trim,
                 "ignore_eos": request.ignore_eos,
                 "skip_special_tokens": request.skip_special_tokens,
+                "logit_bias": request.logit_bias,
             }
         )
         return_logprobs.append(request.logprobs is not None)
@@ -1219,6 +1220,7 @@ def v1_chat_generate_request(
             "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
             "skip_special_tokens": request.skip_special_tokens,
+            "logit_bias": request.logit_bias,
         }

         if request.response_format and request.response_format.type == "json_schema":
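
Note: with the adapter change above, a logit_bias supplied by an OpenAI-compatible client is forwarded into the internal sampling params on both the completions and chat-completions paths. A minimal client-side sketch (the server URL, port, and model name are assumptions for illustration, not part of this patch):

    import openai

    client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
    resp = client.completions.create(
        model="meta-llama/Llama-3.2-1B-Instruct",
        prompt="The capital of France is",
        max_tokens=4,
        logit_bias={"60704": 100.0},  # token IDs as string keys, per the OpenAI API
    )
    print(resp.choices[0].text)
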
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 238f3a7cd..efacf37ad 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -10,6 +10,7 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
+from sglang.srt.utils import merge_bias_tensor

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
     # Device
     device: str = "cuda"

+    # Handle logit bias
+    logit_bias: Optional[torch.Tensor] = None
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        logit_bias = None
+        if any(r.sampling_params.logit_bias is not None for r in reqs):
+            logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
+            for i, r in enumerate(reqs):
+                if r.sampling_params.logit_bias is not None:
+                    for key, value in r.sampling_params.logit_bias.items():
+                        logit_bias[i, int(key)] = value
+
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
             batch.enable_custom_logit_processor  # check the flag first.
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
             custom_params=custom_params,
             custom_logit_processor=merged_custom_logit_processor,
             device=device,
+            logit_bias=logit_bias,
         )
         return ret

@@ -206,6 +219,9 @@ class SamplingBatchInfo:
         if self.vocab_mask is not None:
             self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)

+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
     def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
         self.penalizer_orchestrator.filter(keep_indices_device)

@@ -221,6 +237,9 @@ class SamplingBatchInfo:
             value = getattr(self, item, None)
             setattr(self, item, value[keep_indices_device])

+        if self.logit_bias is not None:
+            self.logit_bias = self.logit_bias[keep_indices_device]
+
     def _filter_batch_custom_logit_processor(
         self, keep_indices: List[int], keep_indices_device: torch.Tensor
     ):
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
         self.need_top_p_sampling |= other.need_top_p_sampling
         self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
+
+        # Merge logit bias
+        self.logit_bias = merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
+        )
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index 87436f86d..e0a88107f 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -52,6 +52,7 @@ class SamplingParams:
         no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
         stream_interval: Optional[int] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
     ) -> None:
         self.max_new_tokens = max_new_tokens
         self.stop_strs = stop
@@ -78,6 +79,7 @@ class SamplingParams:
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params
         self.stream_interval = stream_interval
+        self.logit_bias = logit_bias

         # Process some special cases
         if 0 <= self.temperature < _SAMPLING_EPS:
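
Note: the new SamplingParams.logit_bias field above carries the raw dict from the adapter, and SamplingBatchInfo.from_schedule_batch expands each request's sparse {token_id: bias} dict into one dense row of a (batch_size, vocab_size) tensor, which is later added to the logits in place via logits.add_. A standalone sketch of the same expansion, with made-up sizes:

    import torch

    vocab_size = 8
    per_req_bias = [{"5": 100.0}, None]  # one dict per request; None = no bias

    logit_bias = torch.zeros(len(per_req_bias), vocab_size)
    for i, bias in enumerate(per_req_bias):
        if bias is not None:
            for key, value in bias.items():
                logit_bias[i, int(key)] = value

    # Row 0 adds +100 to token 5; row 1 stays all zeros, so the in-place
    # logits.add_(logit_bias) leaves the second request unchanged.
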
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 18171b9c9..779843981 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -2210,6 +2210,45 @@ class Withable(Generic[T]):
         self._value = None


+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value used to fill in for a missing side
+
+    Returns:
+        Merged tensor, or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
+
+
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
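
Note: merge_bias_tensor covers the case where only one of two batches being merged carries a bias tensor: the missing side is materialized and filled with the given default before concatenation. A quick check of these semantics (shapes made up; the import assumes this patch is applied):

    import torch
    from sglang.srt.utils import merge_bias_tensor

    lhs = None              # first batch carried no logit_bias
    rhs = torch.ones(2, 4)  # second batch biases every token by +1
    merged = merge_bias_tensor(lhs, rhs, bs1=3, bs2=2, device="cpu", default=0.0)
    assert merged.shape == (3 + 2, 4)
    assert (merged[:3] == 0.0).all()  # lhs rows were filled with the default
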
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index a2fb1bff9..089da355d 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -504,6 +504,122 @@ class TestSRTEndpoint(CustomTestCase):
         version = response_json["version"]
         self.assertIsInstance(version, str)

+    def test_logit_bias(self):
+        """Test that a very high logit bias forces sampling of a specific token."""
+        # Token ID to bias toward: 60704 is "Paris" for
+        # meta-llama/Llama-3.2-1B-Instruct (DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        target_token_id = 60704
+        logit_bias = {str(target_token_id): 100.0}  # Very high positive bias
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 1.0,  # Use high temperature to encourage exploration
+                    "max_new_tokens": 4,
+                    "logit_bias": logit_bias,
+                },
+                "return_logprob": True,
+            },
+        )
+        response_json = response.json()
+
+        # Extract the sampled token IDs from the output
+        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+
+        # Verify that all sampled tokens are the target token
+        self.assertTrue(
+            all(x == target_token_id for x in sampled_tokens),
+            f"Expected all tokens to be {target_token_id}, but got {sampled_tokens}",
+        )
+
+    def test_forbidden_token(self):
+        """Test that a forbidden token (very negative logit bias) doesn't appear in the output."""
+        # Token ID to forbid: 23994 is "rice" for
+        # meta-llama/Llama-3.2-1B-Instruct (DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        forbidden_token_id = 23994
+        logit_bias = {
+            str(forbidden_token_id): -100.0
+        }  # Very negative bias to forbid the token
+
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "Only output 'rice' exactly like this, in lowercase ONLY: rice",
+                "sampling_params": {
+                    "temperature": 1.0,  # Use high temperature to encourage diverse output
+                    "max_new_tokens": 50,  # Generate enough tokens for "rice" to be likely
+                    "logit_bias": logit_bias,
+                },
+                "return_logprob": True,
+            },
+        )
+        response_json = response.json()
+
+        # Extract the sampled token IDs from the output
+        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+
+        # Verify that the forbidden token doesn't appear in the output
+        self.assertNotIn(
+            forbidden_token_id,
+            sampled_tokens,
+            f"Expected forbidden token {forbidden_token_id} not to be present, but it was found",
+        )
+
+    def test_logit_bias_isolation(self):
+        """Test that logit_bias applied to one request doesn't affect other requests in a batch."""
+        # Token ID to bias in the first request only: 60704 is "Paris" for
+        # meta-llama/Llama-3.2-1B-Instruct (DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+        biased_token_id = 60704
+
+        # Prepare batch requests - one with logit_bias and one without
+        requests_data = [
+            {
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 1.0,
+                    "max_new_tokens": 4,
+                    "logit_bias": {str(biased_token_id): 100.0},  # Strong bias
+                },
+                "return_logprob": True,
+            },
+            {
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 1.0,
+                    "max_new_tokens": 4,
+                },
+                "return_logprob": True,
+            },
+        ]
+
+        # Send both requests
+        responses = []
+        for req in requests_data:
+            response = requests.post(self.base_url + "/generate", json=req)
+            responses.append(response.json())
+
+        # Extract token IDs from each response
+        biased_tokens = [
+            x[1] for x in responses[0]["meta_info"]["output_token_logprobs"]
+        ]
+        unbiased_tokens = [
+            x[1] for x in responses[1]["meta_info"]["output_token_logprobs"]
+        ]
+
+        # Verify the first response contains only the biased token
+        self.assertTrue(
+            all(x == biased_token_id for x in biased_tokens),
+            f"Expected all tokens to be {biased_token_id} in first response, but got {biased_tokens}",
+        )
+
+        # Verify the second response contains at least some different tokens
+        # (we can't guarantee what is generated, but it shouldn't all be the biased token)
+        self.assertTrue(
+            any(x != biased_token_id for x in unbiased_tokens),
+            f"Expected some tokens to differ from {biased_token_id} in second response, but got {unbiased_tokens}",
+        )
+
     def test_get_server_info_concurrent(self):
         """Make sure the concurrent get_server_info doesn't crash the server."""
         tp = ThreadPoolExecutor(max_workers=30)
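
Note: the new tests are regular unittest cases, so once a test server environment is configured they should be runnable individually; a likely invocation (exact runner depends on the repo's test harness) is:

    python3 -m unittest test.srt.test_srt_endpoint.TestSRTEndpoint.test_logit_bias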