[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
||||
|
||||
@@ -35,6 +36,10 @@ class Sampler(nn.Module):
|
||||
):
|
||||
logits = logits_output.next_token_logits
|
||||
|
||||
# Apply the custom logit processors if registered in the sampling info.
|
||||
if sampling_info.has_custom_logit_processor:
|
||||
self._apply_custom_logit_processor(logits, sampling_info)
|
||||
|
||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
logits = torch.where(
|
||||
@@ -121,6 +126,29 @@ class Sampler(nn.Module):
|
||||
|
||||
return batch_next_token_ids
|
||||
|
||||
def _apply_custom_logit_processor(
|
||||
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
|
||||
):
|
||||
"""Apply custom logit processors to the logits.
|
||||
This function will modify the logits in-place."""
|
||||
|
||||
for _, (
|
||||
processor,
|
||||
batch_mask,
|
||||
) in sampling_batch_info.custom_logit_processor.items():
|
||||
# Get the batch indices that need to be processed
|
||||
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
|
||||
|
||||
# Apply the processor to the logits
|
||||
logits[batch_mask] = processor(
|
||||
logits[batch_mask],
|
||||
[sampling_batch_info.custom_params[i] for i in batch_indices],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Custom logit processor {processor.__class__.__name__} is applied."
|
||||
)
|
||||
|
||||
|
||||
def top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
probs: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user