diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index e0f434a19..4b3910bda 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -100,17 +100,16 @@ class Sampler(nn.Module):
                     probs, sampling_info.min_ps
                 )
             else:
+                # Checking NaN will throw an exception; only check when crash_on_warnings is True
+                check_nan = self.use_nan_detection and crash_on_warnings()
                 batch_next_token_ids = top_k_top_p_sampling_from_probs(
                     probs,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
+                    check_nan=check_nan,
                 )
 
-                if self.use_nan_detection:
-                    logger.warning("Detected errors during sampling!")
-                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
diff --git a/sgl-kernel/python/sgl_kernel/sampling.py b/sgl-kernel/python/sgl_kernel/sampling.py
index 5bc0be6c3..59bc8c351 100644
--- a/sgl-kernel/python/sgl_kernel/sampling.py
+++ b/sgl-kernel/python/sgl_kernel/sampling.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 import torch
 from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
@@ -109,7 +109,7 @@ def _top_p_sampling_from_probs_internal(
     top_p_val: float,
     deterministic: bool,
     generator: Optional[torch.Generator],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     with probs.device as device:
         probs = probs.float()
         maybe_top_p_arr = (
@@ -135,7 +135,7 @@ def top_p_sampling_from_probs(
     deterministic: bool = True,
     generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
     Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
     this operator implements GPU-based rejection sampling without explicit sorting.
@@ -194,7 +194,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     top_p_val: float,
     deterministic: bool,
     generator: Optional[torch.Generator],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     with probs.device as device:
         probs = probs.float()
         maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
@@ -225,7 +225,7 @@ def top_k_top_p_sampling_from_probs(
    deterministic: bool = True,
     generator: Optional[torch.Generator] = None,
     check_nan: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> torch.Tensor:
     r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
     Fused GPU kernel for top-k and top-p sampling from probabilities,