[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
+        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
         """Apply custom logit processors to the logits.
         This function will modify the logits in-place."""

+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
         for _, (
             processor,
             batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
             # Get the batch indices that need to be processed
             batch_indices = batch_mask.nonzero(as_tuple=True)[0]

+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+                f"sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
             # Apply the processor to the logits
             logits[batch_mask] = processor(
                 logits[batch_mask],
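To make the masked-apply pattern in the two Sampler hunks above concrete, here is a minimal, self-contained sketch. It is not the SGLang implementation: the BanTokenProcessor class and the custom_params layout are illustrative assumptions; only the gather/scatter via a boolean batch mask and the shape invariant mirror the diff.

# Illustrative sketch of applying a custom logit processor to the masked rows
# of a logits tensor. BanTokenProcessor and the params layout are assumptions.
import torch

class BanTokenProcessor:
    """Toy processor: push one token id per row to -inf."""

    def __call__(self, logits: torch.Tensor, custom_params: list) -> torch.Tensor:
        for row, params in enumerate(custom_params):
            logits[row, params["token_id"]] = float("-inf")
        return logits

logits = torch.randn(4, 8)  # (batch_size, vocab_size)
batch_mask = torch.tensor([True, False, True, False])

# Same invariant as the new assert: one mask entry per batch row.
assert batch_mask.shape[0] == logits.shape[0]

processor = BanTokenProcessor()
# Only rows 0 and 2 are gathered, processed, and scattered back in place.
logits[batch_mask] = processor(logits[batch_mask], [{"token_id": 3}, {"token_id": 5}])
assert torch.isinf(logits[0, 3]) and torch.isinf(logits[2, 5])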
@@ -595,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None

+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )

     def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )

     def __str__(self):
@@ -966,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
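The two Scheduler hunks above plumb the server-level switch into every batch the scheduler builds. A stripped-down analogue of that plumbing, as a sketch: the ServerArgs and ScheduleBatch stand-ins below are simplified, and the CLI spelling --enable-custom-logit-processor is an assumption based on the field name and SGLang's usual underscore-to-dash convention, not confirmed by this diff.

# Hypothetical launch (flag name assumed from the field name above):
#   python -m sglang.launch_server --model-path <model> --enable-custom-logit-processor
from dataclasses import dataclass

@dataclass
class ServerArgs:
    enable_custom_logit_processor: bool = False  # off by default

@dataclass
class ScheduleBatch:
    enable_custom_logit_processor: bool = False

def build_batch(server_args: ServerArgs) -> ScheduleBatch:
    # The scheduler forwards the server-level switch into each batch, so
    # downstream code (SamplingBatchInfo) never reaches back to ServerArgs.
    return ScheduleBatch(
        enable_custom_logit_processor=server_args.enable_custom_logit_processor
    )

batch = build_batch(ServerArgs(enable_custom_logit_processor=True))
assert batch.enable_custom_logit_processor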
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)

         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )

         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
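This hunk is the behavioral core of the change: per-request processors are now honored only when the server-level flag is on. A minimal sketch of the two-level gate; FakeReq and FakeBatch are illustrative stand-ins, not SGLang classes.

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeReq:
    custom_logit_processor: Optional[str] = None  # serialized processor, if any

@dataclass
class FakeBatch:
    enable_custom_logit_processor: bool = False

def has_custom(batch: FakeBatch, reqs: List[FakeReq]) -> bool:
    # Server flag first, then the per-request check (same order as the diff).
    return batch.enable_custom_logit_processor and any(
        r.custom_logit_processor for r in reqs
    )

reqs = [FakeReq(custom_logit_processor="ban-token-3"), FakeReq()]
# Requests carry a processor, but the server flag is off -> feature stays inert.
assert not has_custom(FakeBatch(enable_custom_logit_processor=False), reqs)
assert has_custom(FakeBatch(enable_custom_logit_processor=True), reqs)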
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-        if not self.custom_logit_processor:
-            return
+
         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]

-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
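The two filter hunks above narrow each processor's mask to the surviving rows and reset the feature once nothing is left. A sketch of that step, assuming each entry maps an int key to a (processor, boolean mask) pair; the drop-empty condition inside the comprehension is inferred from the reset branch and is not shown verbatim in this diff.

import torch

custom_logit_processor = {
    1: ("proc_a", torch.tensor([True, False, True])),
    2: ("proc_b", torch.tensor([False, True, False])),
}
new_indices = torch.tensor([0, 2])  # requests 0 and 2 survive; request 1 finished

custom_logit_processor = {
    k: (p, mask[new_indices])
    for k, (p, mask) in custom_logit_processor.items()
    if any(mask[new_indices])  # drop processors that no longer apply to any row
}

assert list(custom_logit_processor.keys()) == [1]  # proc_b's mask became all-False
if len(custom_logit_processor) == 0:
    custom_logit_processor = None  # mirrors the reset branch in the diff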
@@ -290,8 +294,8 @@ class SamplingBatchInfo:

     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
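The corrected signature types each side as a dict from an int identifier to a (CustomLogitProcessor, mask) pair. One plausible shape of such an entry, sketched below; keying by a hash of the serialized processor is an assumption, not something this diff confirms.

import torch

processor_str = "ban-token-3"        # stand-in for a serialized processor
key = hash(processor_str)            # illustrative int key (assumed hashing scheme)
mask = torch.tensor([True, False])   # applies to row 0 of a 2-request batch
lhs = {key: (processor_str, mask)}   # Dict[int, Tuple[processor, mask]]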
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
             )
             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))

+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
         return merged_dict

     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
-
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
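The new assert pins down what the merge must produce: every merged mask covers exactly bs1 + bs2 rows. A hedged reconstruction of that contract, padding the missing side of each key with an all-False mask; this is a sketch satisfying the invariant, not the SGLang implementation.

import torch

def merge_custom_logit_processor(lhs, rhs, bs1, bs2):
    merged = {}
    for k in set(lhs) | set(rhs):
        # A side that lacks the key contributes an all-False mask of its batch size.
        proc_l, left = lhs.get(k, (None, torch.zeros(bs1, dtype=torch.bool)))
        proc_r, right = rhs.get(k, (None, torch.zeros(bs2, dtype=torch.bool)))
        merged[k] = (proc_l or proc_r, torch.cat([left, right]))
        assert merged[k][1].shape[0] == bs1 + bs2  # same invariant as the diff
    return merged

lhs = {1: ("proc_a", torch.tensor([True, False]))}
rhs = {2: ("proc_b", torch.tensor([True]))}
merged = merge_custom_logit_processor(lhs, rhs, bs1=2, bs2=1)
assert merged[1][1].tolist() == [True, False, False]  # padded on the right
assert merged[2][1].tolist() == [False, False, True]  # padded on the left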
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True

+        # Note: because the __len__() operator is defined on the temperatures tensor,
+        # make sure any merge operation that uses len(self) or len(other) is done
+        # before the temperatures tensor is merged below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
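The ordering note matters because __len__() reads the already merged temperatures tensor, so any merge that interprets len(self) as "rows contributed by self" must run first. A toy illustration of the bug the reordering guards against; the Info class is a stand-in, not SamplingBatchInfo itself.

import torch

class Info:
    def __init__(self, temperatures: torch.Tensor):
        self.temperatures = temperatures

    def __len__(self):
        # Batch size is derived from the temperatures tensor, as in the diff.
        return self.temperatures.shape[0]

a, b = Info(torch.ones(2)), Info(torch.ones(3))

# Correct: capture lengths (or run any merge that needs them) first ...
bs1, bs2 = len(a), len(b)  # 2 and 3
# ... then concatenate the tensor that backs __len__().
a.temperatures = torch.cat([a.temperatures, b.temperatures])
assert (bs1, bs2) == (2, 3) and len(a) == 5
# Had we concatenated first, len(a) would already be 5, and masks sized from
# it (e.g. in merge_bias_tensor or merge_custom_logit_processor) would be wrong.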