[Enhancement] Custom Logit Processor Improvement (#2998)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
|
||||
"""Apply custom logit processors to the logits.
|
||||
This function will modify the logits in-place."""
|
||||
|
||||
assert logits.shape[0] == len(sampling_batch_info), (
|
||||
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
|
||||
f"sampling_batch_info ({len(sampling_batch_info)})"
|
||||
)
|
||||
|
||||
for _, (
|
||||
processor,
|
||||
batch_mask,
|
||||
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
|
||||
# Get the batch indices that need to be processed
|
||||
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
|
||||
|
||||
assert batch_mask.shape[0] == len(sampling_batch_info), (
|
||||
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
|
||||
f"sampling_batch_info ({len(sampling_batch_info)})"
|
||||
)
|
||||
|
||||
# Apply the processor to the logits
|
||||
logits[batch_mask] = processor(
|
||||
logits[batch_mask],
|
||||
|
||||
Reference in New Issue
Block a user