Fix sampler NaN check when calling top_k_top_p_sampling_from_probs (#5546)
This commit is contained in:
@@ -100,17 +100,16 @@ class Sampler(nn.Module):
|
|||||||
probs, sampling_info.min_ps
|
probs, sampling_info.min_ps
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# The NaN check throws an exception, so only enable it when crash_on_warnings is True
|
||||||
|
check_nan = self.use_nan_detection and crash_on_warnings()
|
||||||
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
||||||
probs,
|
probs,
|
||||||
sampling_info.top_ks,
|
sampling_info.top_ks,
|
||||||
sampling_info.top_ps,
|
sampling_info.top_ps,
|
||||||
filter_apply_order="joint",
|
filter_apply_order="joint",
|
||||||
|
check_nan=check_nan,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.use_nan_detection:
|
|
||||||
logger.warning("Detected errors during sampling!")
|
|
||||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
|
||||||
|
|
||||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||||
# A slower fallback implementation with torch native operations.
|
# A slower fallback implementation with torch native operations.
|
||||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
||||||
@@ -109,7 +109,7 @@ def _top_p_sampling_from_probs_internal(
|
|||||||
top_p_val: float,
|
top_p_val: float,
|
||||||
deterministic: bool,
|
deterministic: bool,
|
||||||
generator: Optional[torch.Generator],
|
generator: Optional[torch.Generator],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
with probs.device as device:
|
with probs.device as device:
|
||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
maybe_top_p_arr = (
|
maybe_top_p_arr = (
|
||||||
@@ -135,7 +135,7 @@ def top_p_sampling_from_probs(
|
|||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
generator: Optional[torch.Generator] = None,
|
generator: Optional[torch.Generator] = None,
|
||||||
check_nan: bool = False,
|
check_nan: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
|
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
|
||||||
this operator implements GPU-based rejection sampling without explicit sorting.
|
this operator implements GPU-based rejection sampling without explicit sorting.
|
||||||
@@ -194,7 +194,7 @@ def _top_k_top_p_sampling_from_probs_internal(
|
|||||||
top_p_val: float,
|
top_p_val: float,
|
||||||
deterministic: bool,
|
deterministic: bool,
|
||||||
generator: Optional[torch.Generator],
|
generator: Optional[torch.Generator],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
with probs.device as device:
|
with probs.device as device:
|
||||||
probs = probs.float()
|
probs = probs.float()
|
||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||||
@@ -225,7 +225,7 @@ def top_k_top_p_sampling_from_probs(
|
|||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
generator: Optional[torch.Generator] = None,
|
generator: Optional[torch.Generator] = None,
|
||||||
check_nan: bool = False,
|
check_nan: bool = False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
|
||||||
Fused GPU kernel for top-k and top-p sampling from probabilities,
|
Fused GPU kernel for top-k and top-p sampling from probabilities,
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user