From 583697cd71faa65a2e132a014743f5ff5c63890a Mon Sep 17 00:00:00 2001
From: Hongpeng Guo
Date: Mon, 20 Jan 2025 02:00:35 -0800
Subject: [PATCH] [Enhancement] Custom Logit Processor Improvement (#2998)

Signed-off-by: Hongpeng Guo
---
 python/sglang/bench_one_batch.py             |  1 +
 python/sglang/srt/layers/sampler.py          | 10 ++++
 python/sglang/srt/managers/schedule_batch.py |  6 +++
 python/sglang/srt/managers/scheduler.py      |  2 +
 .../srt/sampling/sampling_batch_info.py      | 53 ++++++++++++-------
 test/srt/test_srt_endpoint.py                | 35 ++++++++----
 6 files changed, 79 insertions(+), 28 deletions(-)

diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 473f478ad..e01919399 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
+        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index e8b25da07..ebaa1aa0e 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
         """Apply custom logit processors to the logits.
         This function will modify the logits in-place."""
 
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
         for _, (
             processor,
             batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
             # Get the batch indices that need to be processed
             batch_indices = batch_mask.nonzero(as_tuple=True)[0]
 
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The length of the batch mask ({batch_mask.shape[0]}) does not match the "
+                f"batch size of sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
             # Apply the processor to the logits
             logits[batch_mask] = processor(
                 logits[batch_mask],
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a09810a38..040afe3d3 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -595,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 416abe21c..fba8a67ec 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -966,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index d4c5c3238..a27ff1ad2 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)
 
         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
 
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-        if not self.custom_logit_processor:
-            return
+
         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]
 
-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ class SamplingBatchInfo:
 
     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
             )
             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
 
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of the merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
         return merged_dict
 
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
-
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
             self.custom_logit_processor = (
                 SamplingBatchInfo.merge_custom_logit_processor(
                     self.custom_logit_processor,
                     other.custom_logit_processor,
                     len(self),
                     len(other),
                     self.device,
                 )
             )
             # Merge the custom params lists
             self.custom_params = self.custom_params or [None] * len(self)
             other.custom_params = other.custom_params or [None] * len(other)
             self.custom_params.extend(other.custom_params)
 
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True
 
+        # Note: because the __len__() method is defined on the temperatures tensor,
+        # please make sure any merge operation that uses len(self) or len(other) is done
+        # before the temperatures tensor is concatenated below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index cddd75fa6..7c57c13e2 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 """
 
 import json
+import random
 import unittest
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 import numpy as np
 import requests
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
 
         self.assertTrue(all(x is not None for x in logprobs))
 
-    def run_custom_logit_processor(self, target_token_id: int):
-        """Test custom logit processor with custom params."""
+    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
+        """Test custom logit processor with custom params.
+
+        If target_token_id is None, the custom logit processor won't be passed in.
+        """
 
         custom_params = {"token_id": target_token_id}
 
@@ -285,8 +290,12 @@ class TestSRTEndpoint(unittest.TestCase):
 
         # Custom json data with custom logit processor and params.
         custom_json = base_json.copy()
-        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
-        custom_json["sampling_params"]["custom_params"] = custom_params
+        # Only set the custom logit processor if target_token_id is not None.
+        if target_token_id is not None:
+            custom_json["custom_logit_processor"] = (
+                DeterministicLogitProcessor().to_str()
+            )
+            custom_json["sampling_params"]["custom_params"] = custom_params
 
         custom_response = requests.post(
             self.base_url + "/generate",
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
         sampled_tokens = [x[1] for x in output_token_logprobs]
 
         # The logit processor should always sample the given token as the logits is deterministic.
-        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+        if target_token_id is not None:
+            self.assertTrue(
+                all(x == custom_params["token_id"] for x in sampled_tokens),
+                # Print the detailed test case info if the test fails.
+                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+            )
 
     def test_custom_logit_processor(self):
         """Test custom logit processor with a single request."""
-        # Temporarily skipped due to buggy implementation
-        return
         self.run_custom_logit_processor(target_token_id=5)
 
     def test_custom_logit_processor_batch(self):
         """Test custom logit processor with a batch of requests."""
-        # Temporarily skipped due to buggy implementation
-        return
         target_token_ids = list(range(32))
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
+    def test_custom_logit_processor_batch_mixed(self):
+        """Test a batch that mixes requests with and without a custom logit processor."""
+        target_token_ids = list(range(32)) + [None] * 16
+        random.shuffle(target_token_ids)
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()