[Enhancement] Custom Logit Processor Improvement (#2998)

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
Hongpeng Guo
2025-01-20 02:00:35 -08:00
committed by GitHub
parent 2584f6d944
commit 583697cd71
6 changed files with 79 additions and 28 deletions

View File

@@ -132,6 +132,11 @@ class Sampler(nn.Module):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""
assert logits.shape[0] == len(sampling_batch_info), (
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
for _, (
processor,
batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
assert batch_mask.shape[0] == len(sampling_batch_info), (
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],