[Feature] Add sampler custom logits processor (#2396)

Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
Author: Hongpeng Guo
Date: 2025-01-19 14:46:53 -08:00
Committed by: GitHub
Parent: 3bcf5ecea7
Commit: e403d23757
12 changed files with 302 additions and 4 deletions


@@ -1,11 +1,12 @@
 import logging
-from typing import List
+from typing import Dict, List

 import torch
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
@@ -35,6 +36,10 @@ class Sampler(nn.Module):
     ):
         logits = logits_output.next_token_logits

+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
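
For orientation, the sketch below summarizes the SamplingBatchInfo fields this new branch relies on. The field names come from the call sites in this diff (has_custom_logit_processor, custom_logit_processor, custom_params); the types are inferred from how _apply_custom_logit_processor consumes them in the next hunk, not copied from sampling_batch_info.py.

# Simplified, illustrative view of the fields used above; inferred from the
# call sites in this commit, not taken verbatim from sampling_batch_info.py.
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

import torch


@dataclass
class SamplingBatchInfoSketch:
    # Set when at least one request in the batch registered a custom processor.
    has_custom_logit_processor: bool = False
    # Maps a processor key to (processor, boolean mask over the batch rows
    # that this processor should be applied to).
    custom_logit_processor: Dict[str, Tuple[Callable, torch.Tensor]] = field(
        default_factory=dict
    )
    # One optional params dict per request in the batch, indexed by row.
    custom_params: List[Optional[Dict]] = field(default_factory=list)
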
@@ -121,6 +126,29 @@ class Sampler(nn.Module):
         return batch_next_token_ids

+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
+

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
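
As a rough illustration of the calling convention that _apply_custom_logit_processor establishes: a registered processor receives only the logits rows selected by its batch mask, shaped (num_selected, vocab_size), plus one custom-params entry per selected row, and returns a tensor of the same shape, which is written back via logits[batch_mask] = .... The sketch below is an assumption-laden example; the class name, the "blocked_token_ids" parameter key, and the exact base-class signature are illustrative, not part of this commit.

# Hypothetical custom processor matching the call site above; the class name,
# the "blocked_token_ids" parameter key, and the __call__ signature are
# illustrative assumptions, not part of this commit.
from typing import Dict, List, Optional

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BlocklistLogitProcessor(CustomLogitProcessor):
    """Masks out a per-request list of token ids before sampling."""

    def __call__(
        self, logits: torch.Tensor, custom_params: List[Optional[Dict]]
    ) -> torch.Tensor:
        # Row i of `logits` corresponds to custom_params[i], i.e. the i-th
        # request selected by this processor's batch mask.
        for i, params in enumerate(custom_params):
            if params and "blocked_token_ids" in params:
                logits[i, params["blocked_token_ids"]] = float("-inf")
        return logits

Because the result is assigned back in-place to the masked rows, a processor must return a tensor with the same shape as the slice it was given rather than reshaping or reallocating the batch.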