[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)

         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )

         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
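For context on the branch above: when a batch does contain requests with custom logit processors, requests that share the same processor are grouped under a single boolean mask over the batch. The sketch below is only an illustration of that idea, not the actual SGLang implementation; the hash-based key, the generic request objects, and the helper name are assumptions, while the Dict[key, (processor, mask)] shape mirrors the type annotation changed later in this diff.

from typing import Dict, List, Optional, Tuple

import torch


def group_custom_logit_processors(
    reqs: List[object], device: str = "cpu"
) -> Optional[Dict[int, Tuple[object, torch.Tensor]]]:
    """Illustrative sketch: one boolean mask per distinct custom logit processor."""
    merged: Dict[int, Tuple[object, torch.Tensor]] = {}
    for i, req in enumerate(reqs):
        processor = getattr(req, "custom_logit_processor", None)
        if processor is None:
            continue
        key = hash(processor)  # assumption: any stable per-processor key would do
        if key not in merged:
            merged[key] = (processor, torch.zeros(len(reqs), dtype=torch.bool, device=device))
        merged[key][1][i] = True  # row i of the batch is handled by this processor
    return merged or None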
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
         if not self.custom_logit_processor:
             return

         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]

-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
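A toy illustration of the mask filtering above (tensors and indices are made up): new_indices holds the positions of the unfinished requests, so indexing each processor's mask with it keeps exactly the surviving rows.

import torch

# Two processors tracked over a batch of four requests (illustrative values).
custom_logit_processor = {
    11: ("proc_a", torch.tensor([True, False, True, False])),
    22: ("proc_b", torch.tensor([False, True, False, False])),
}

# Requests 0 and 2 survive the filter; rows 1 and 3 are dropped.
new_indices = torch.tensor([0, 2])
custom_logit_processor = {
    k: (p, mask[new_indices]) for k, (p, mask) in custom_logit_processor.items()
}
# processor 11 now has mask [True, True]; processor 22 has [False, False]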
@@ -290,8 +294,8 @@ class SamplingBatchInfo:

     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
             )
             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )

         return merged_dict

     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
         # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
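The assertion added above holds whenever each side contributes a mask of its own batch size: a key missing from one side is padded with an all-False mask before concatenation. Below is a minimal sketch of that padding-then-concatenation under those assumptions; the helper name and the generic processor objects are illustrative, not the real function.

from typing import Dict, Optional, Tuple

import torch


def merge_mask_dicts(
    lhs: Optional[Dict[int, Tuple[object, torch.Tensor]]],
    rhs: Optional[Dict[int, Tuple[object, torch.Tensor]]],
    bs1: int,
    bs2: int,
    device: str = "cpu",
) -> Optional[Dict[int, Tuple[object, torch.Tensor]]]:
    """Illustrative sketch: pad the missing side with an all-False mask, then concatenate."""
    if not lhs and not rhs:
        return None
    lhs, rhs = lhs or {}, rhs or {}
    merged_dict = {}
    for k in set(lhs) | set(rhs):
        processor = (lhs.get(k) or rhs.get(k))[0]
        left_mask = lhs[k][1] if k in lhs else torch.zeros(bs1, dtype=torch.bool, device=device)
        right_mask = rhs[k][1] if k in rhs else torch.zeros(bs2, dtype=torch.bool, device=device)
        merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
        # Same invariant as the assert in the diff: total mask length equals bs1 + bs2.
        assert merged_dict[k][1].shape[0] == bs1 + bs2
    return merged_dict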
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True
+
+        # Note: because the __len__() operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
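The ordering note added above matters because len() on the batch is backed by the temperatures tensor. A stripped-down illustration follows, assuming __len__ simply returns the length of temperatures as the note states: the pre-merge sizes must be read (for merge_bias_tensor and the processor masks) before temperatures are concatenated, otherwise len(self) already reflects the merged batch. The TinyBatch class is a made-up stand-in, not SamplingBatchInfo itself.

import torch


class TinyBatch:
    """Stripped-down stand-in for the batch object; only what the ordering note needs."""

    def __init__(self, temperatures: torch.Tensor):
        self.temperatures = temperatures

    def __len__(self) -> int:
        # As the note says: the batch length is defined by the temperatures tensor.
        return len(self.temperatures)

    def merge_batch(self, other: "TinyBatch"):
        bs1, bs2 = len(self), len(other)  # must be read BEFORE temperatures are merged
        # ... merge anything that needs the pre-merge sizes (bias tensors, masks) here ...
        self.temperatures = torch.concat([self.temperatures, other.temperatures])
        assert len(self) == bs1 + bs2


a, b = TinyBatch(torch.ones(3)), TinyBatch(torch.ones(2))
a.merge_batch(b)  # len(a) is now 5; reading len(a) only after the concat would be wrong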