Fix sampler nan check when calling top_k_top_p_sampling_from_probs (#5546)

This commit is contained in:
Yubo Wang
2025-04-19 21:47:23 -07:00
committed by GitHub
parent 613b197e57
commit 20f1c8e374
2 changed files with 8 additions and 9 deletions

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
@@ -109,7 +109,7 @@ def _top_p_sampling_from_probs_internal(
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_p_arr = (
@@ -135,7 +135,7 @@ def top_p_sampling_from_probs(
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
@@ -194,7 +194,7 @@ def _top_k_top_p_sampling_from_probs_internal(
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
@@ -225,7 +225,7 @@ def top_k_top_p_sampling_from_probs(
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,