[Enhancement] Custom Logit Processor Improvement (#2998)

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
Hongpeng Guo
2025-01-20 02:00:35 -08:00
committed by GitHub
parent 2584f6d944
commit 583697cd71
6 changed files with 79 additions and 28 deletions

View File

@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()

View File

@@ -132,6 +132,11 @@ class Sampler(nn.Module):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""
assert logits.shape[0] == len(sampling_batch_info), (
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
for _, (
processor,
batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
assert batch_mask.shape[0] == len(sampling_batch_info), (
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],

View File

@@ -595,6 +595,9 @@ class ScheduleBatch:
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
# Enable custom logit processor
enable_custom_logit_processor: bool = False
@classmethod
def init_new(
cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
model_config: ModelConfig,
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
):
return cls(
reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
)
def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
enable_custom_logit_processor=self.enable_custom_logit_processor,
)
def __str__(self):

View File

@@ -966,6 +966,7 @@ class Scheduler:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ class Scheduler:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
idle_batch.prepare_for_idle()
return idle_batch

View File

@@ -89,7 +89,10 @@ class SamplingBatchInfo:
).to(device, non_blocking=True)
# Check if any request has custom logit processor
has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
self, unfinished_indices: List[int], new_indices: torch.Tensor
):
"""Filter the custom logit processor and custom params"""
if not self.custom_logit_processor:
return
self.custom_logit_processor = {
k: (p, mask[new_indices])
for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
}
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
if len(self) == 0:
# If the custom logit processor is an empty dict, set the flag to False,
# and set the custom logit processor and custom params to None.
if len(self.custom_logit_processor) == 0:
self.custom_logit_processor = None
self.custom_params = None
self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ class SamplingBatchInfo:
@staticmethod
def merge_custom_logit_processor(
lhs: Optional[Dict[str, torch.Tensor]],
rhs: Optional[Dict[str, torch.Tensor]],
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
bs1: int,
bs2: int,
device: str,
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
)
merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
assert merged_dict[k][1].shape[0] == bs1 + bs2, (
f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
f"\n{lhs=}\n{rhs=}"
)
return merged_dict
def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
# Merge the logit bias tensor
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
# Set the flag to True if any of the two has custom logit processor
self.has_custom_logit_processor = True
# Note: because the __len__() operator is defined on the temperatures tensor,
# please make sure any merge operation with len(self) or len(other) is done before
# the merge operation of the temperatures tensor below.
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
def apply_logits_bias(self, logits: torch.Tensor):
# Apply logit_bias
if self.logit_bias is not None: