feat: integrate sampling kernels into sgl-kernel (#3086)
Co-authored-by: Zihao Ye <expye@outlook.com>
@@ -128,6 +128,7 @@ ext_modules = [
        "3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
        "3rdparty/flashinfer/csrc/norm.cu",
        "3rdparty/flashinfer/csrc/sampling.cu",
        "3rdparty/flashinfer/csrc/renorm.cu",
    ],
    include_dirs=include_dirs,
    extra_compile_args={
@@ -11,12 +11,16 @@ from sgl_kernel.ops import (
    init_custom_reduce,
    int8_scaled_mm,
    lightning_attention_decode,
    min_p_sampling_from_probs,
    moe_align_block_size,
    register_graph_buffers,
    rmsnorm,
    rotary_embedding,
    sampling_scaling_penalties,
    silu_and_mul,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)

__all__ = [
@@ -31,11 +35,15 @@ __all__ = [
    "get_graph_buffer_ipc_meta",
    "init_custom_reduce",
    "int8_scaled_mm",
    "lightning_attention_decode",
    "min_p_sampling_from_probs",
    "moe_align_block_size",
    "register_graph_buffers",
    "rmsnorm",
    "rotary_embedding",
    "sampling_scaling_penalties",
    "silu_and_mul",
    "top_k_renorm_prob",
    "top_k_top_p_sampling_from_probs",
    "top_p_renorm_prob",
]

@@ -61,6 +61,30 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);

// min p sampling from probs
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
                               int64_t cuda_stream);

// top k renorm probs
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
                        unsigned int top_k_val, int64_t cuda_stream);

// top p renorm probs
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
                        double top_p_val, int64_t cuda_stream);

// top k top p sampling from probs
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
                                     at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                                     int64_t cuda_stream);

// top p sampling from probs
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
                               std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
                               int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -94,4 +118,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
  // bmm fp8
  m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
  // min p sampling from probs
  m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)");
  // top k renorm probs
  m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)");
  // top p renorm probs
  m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)");
  // top k top p sampling from probs
  m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)");
  // top p sampling from probs
  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)");
}

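For orientation, a minimal usage sketch (illustrative, not part of this diff; it assumes a CUDA device): the bindings above are reached through the Python wrappers added below, which allocate the output tensors and pass the current CUDA stream.

import torch
import sgl_kernel

probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
# One row of uniform draws per rejection-sampling round.
uniform_samples = torch.rand(32, 4, device="cuda")
samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k=20, top_p=0.9, filter_apply_order="joint"
)
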
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple, Union

import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
@@ -17,6 +17,9 @@ from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import (
    lightning_attention_decode as _lightning_attention_decode,
)
from sgl_kernel.ops._kernels import (
    min_p_sampling_from_probs as _min_p_sampling_from_probs,
)
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
@@ -25,7 +28,19 @@ from sgl_kernel.ops._kernels import (
    sampling_scaling_penalties as _sampling_scaling_penalties,
)
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
-from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream
+from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs
+from sgl_kernel.ops._kernels import (
+    top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs,
+)
+from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs
+from sgl_kernel.ops._kernels import (
+    top_p_sampling_from_probs as _top_p_sampling_from_probs,
+)
+from sgl_kernel.ops.utils import (
+    _get_cache_buf,
+    _get_cuda_stream,
+    _to_tensor_scalar_tuple,
+)


def init_custom_reduce(
@@ -236,3 +251,213 @@ def bmm_fp8(
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        renorm_probs = torch.empty_like(probs)
        _top_k_renorm_probs(
            probs,
            renorm_probs,
            maybe_top_k_arr,
            top_k_val,
            _get_cuda_stream(device),
        )
        return renorm_probs


def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


top_k_renorm_prob = top_k_renorm_probs


def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        renorm_probs = torch.empty_like(probs)
        _top_p_renorm_probs(
            probs,
            renorm_probs,
            maybe_top_p_arr,
            top_p_val,
            _get_cuda_stream(device),
        )
        return renorm_probs


def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


top_p_renorm_prob = top_p_renorm_probs

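Both renorm wrappers accept either a plain scalar (one threshold for the whole batch) or a per-row tensor, via _to_tensor_scalar_tuple below. A minimal sketch, assuming a CUDA device (illustrative, not part of this diff):

import torch
import sgl_kernel

probs = torch.softmax(torch.randn(2, 100, device="cuda"), dim=-1)
# Scalar: the same top-p cutoff for every row.
top_p_renormed = sgl_kernel.top_p_renorm_prob(probs, 0.9)
# Tensor: a different top-k per row; the wrapper casts entries to int32.
top_k_renormed = sgl_kernel.top_k_renorm_prob(
    probs, torch.tensor([10, 50], device="cuda")
)
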
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        _top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success


def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
    )

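A note on the contract implied by the signature: each row of uniform_samples appears to drive one rejection-sampling round (the tests below size this dimension as max_top_k_trails), and success flags the rows that drew a valid token within those rounds. A minimal sketch, assuming a CUDA device; in this diff the function is importable from sgl_kernel.ops but is not re-exported in the package __all__ above:

import torch
from sgl_kernel.ops import top_p_sampling_from_probs

probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
uniform_samples = torch.rand(16, 4, device="cuda")  # 16 rounds for a batch of 4
samples, success = top_p_sampling_from_probs(probs, uniform_samples, 0.9)
# Rows where success is False did not land a valid sample within 16 rounds.
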
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        _top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if filter_apply_order == "top_k_first":
        renorm_probs = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    elif filter_apply_order == "joint":
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")

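The two filter_apply_order modes differ in structure, as the wrapper above shows: "top_k_first" launches a top-k renorm kernel and then top-p sampling, while "joint" hands both thresholds to a single fused kernel. A minimal sketch of the two call shapes (illustrative, assuming a CUDA device):

import torch
import sgl_kernel

probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
uniform_samples = torch.rand(32, 4, device="cuda")
# Two kernels: renormalize over the top-k support, then top-p sample.
samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, 20, 0.9, filter_apply_order="top_k_first"
)
# One fused kernel applying both filters during rejection sampling.
samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, 20, 0.9, filter_apply_order="joint"
)
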
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        _min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples


def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    if uniform_samples.dim() == 2:
        # Take the first row (round) of uniform_samples
        uniform_samples = uniform_samples[0]

    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )

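Min-p keeps tokens whose probability is at least min_p times the row's largest probability (this is how the new test constructs its mask), and unlike the top-p and top-k samplers it returns only samples, with no success flags. A minimal sketch, assuming a CUDA device (illustrative, not part of this diff):

import torch
import sgl_kernel

probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
uniform_samples = torch.rand(4, device="cuda")  # one draw per row; 2D input uses row 0
samples = sgl_kernel.min_p_sampling_from_probs(probs, uniform_samples, 0.05)
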
@@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
        _cache_buf[key] = buf
    return buf


def _to_tensor_scalar_tuple(x):
    if isinstance(x, torch.Tensor):
        return (x, 0)
    else:
        return (None, x)
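_to_tensor_scalar_tuple mirrors the kernels' paired (maybe_*_arr, *_val) arguments: a tensor travels in the optional-array slot with a placeholder scalar, while a scalar leaves the array slot empty. Illustratively:

import torch
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple

t = torch.tensor([0.9, 0.8])
_to_tensor_scalar_tuple(t)    # -> (t, 0): per-row thresholds, scalar slot unused
_to_tensor_scalar_tuple(0.9)  # -> (None, 0.9): one threshold for every row
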
sgl-kernel/tests/test_sampling.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py

import pytest
import sgl_kernel
import torch


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")
    max_top_k_trails = 32
    eps = 1e-4
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    # top-p mask
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
    # top-k mask
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # overall mask
    mask = torch.minimum(mask_top_p, mask_top_k)
    uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
        0
    )
    top_p_tensor = torch.full((batch_size,), p).to(0)
    top_k_tensor = torch.full((batch_size,), k).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
        assert torch.all(success)
        assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
        assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
            torch.arange(batch_size), samples
        ]


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
def test_top_p_renorm_probs(batch_size, vocab_size, p):
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    cdf = torch.cumsum(sorted_prob, dim=-1)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
    # Clone so masking the ground truth does not mutate the kernel's input.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
def test_top_k_renorm_probs(batch_size, vocab_size, k):
    if k > vocab_size:
        pytest.skip("k should be less than vocab_size")
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, _ = torch.sort(normalized_prob, descending=True)
    pivot = sorted_prob[:, k - 1]
    mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
    # Clone so masking the ground truth does not mutate the kernel's input.
    renorm_prob_ground_truth = normalized_prob.clone()
    renorm_prob_ground_truth[mask == 0] = 0
    renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
        dim=-1, keepdim=True
    )

    renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
    torch.testing.assert_close(
        renorm_prob_ground_truth,
        renorm_prob,
        rtol=1e-3,
        atol=1e-3,
    )


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
def test_min_p_sampling(batch_size, vocab_size, p):
    torch.manual_seed(42)
    pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    sorted_prob, indices = torch.sort(normalized_prob, descending=False)
    # scale min-p
    top_probs = sorted_prob[:, -1].unsqueeze(-1)
    scaled_p = p * top_probs
    # min-p mask
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
    mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
    uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
    min_p_tensor = torch.full((batch_size,), p).to(0)

    num_trails = 1000
    for _ in range(num_trails):
        uniform_samples.uniform_()
        samples = sgl_kernel.min_p_sampling_from_probs(
            normalized_prob,
            uniform_samples,
            min_p_tensor,
        )

        assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
            torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
        ]


if __name__ == "__main__":
    pytest.main([__file__])