diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 184cc08c4..72d188e71 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -128,6 +128,7 @@ ext_modules = [ "3rdparty/flashinfer/csrc/group_gemm_sm90.cu", "3rdparty/flashinfer/csrc/norm.cu", "3rdparty/flashinfer/csrc/sampling.cu", + "3rdparty/flashinfer/csrc/renorm.cu", ], include_dirs=include_dirs, extra_compile_args={ diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 9eaa64e50..c7fcd2742 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -11,12 +11,16 @@ from sgl_kernel.ops import ( init_custom_reduce, int8_scaled_mm, lightning_attention_decode, + min_p_sampling_from_probs, moe_align_block_size, register_graph_buffers, rmsnorm, rotary_embedding, sampling_scaling_penalties, silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, ) __all__ = [ @@ -31,11 +35,15 @@ __all__ = [ "get_graph_buffer_ipc_meta", "init_custom_reduce", "int8_scaled_mm", + "lightning_attention_decode", + "min_p_sampling_from_probs", "moe_align_block_size", "register_graph_buffers", "rmsnorm", "rotary_embedding", "sampling_scaling_penalties", - "lightning_attention_decode", "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index cd5df0789..876d62b7e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -61,6 +61,30 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); +// min p sampling from probs +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor 
samples, + std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic, + int64_t cuda_stream); + +// top k renorm probs +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, + unsigned int top_k_val, int64_t cuda_stream); + +// top p renorm probs +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, + double top_p_val, int64_t cuda_stream); + +// top k top p sampling from probs +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val, + std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + +// top p sampling from probs +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, + std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -94,4 +118,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)"); // bmm fp8 m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)"); + // min p sampling from probs + m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs (CUDA)"); + // top k renorm probs + m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)"); + // top p renorm probs + m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)"); + // top k top p sampling from probs + m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)"); + // top p sampling from probs + m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)"); } diff --git 
a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 0aead260b..cd69eb3c2 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple, Union import torch from sgl_kernel.ops._kernels import all_reduce as _all_reduce @@ -17,6 +17,9 @@ from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm from sgl_kernel.ops._kernels import ( lightning_attention_decode as _lightning_attention_decode, ) +from sgl_kernel.ops._kernels import ( + min_p_sampling_from_probs as _min_p_sampling_from_probs, +) from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm @@ -25,7 +28,19 @@ from sgl_kernel.ops._kernels import ( sampling_scaling_penalties as _sampling_scaling_penalties, ) from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul -from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream +from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs +from sgl_kernel.ops._kernels import ( + top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs, +) +from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs +from sgl_kernel.ops._kernels import ( + top_p_sampling_from_probs as _top_p_sampling_from_probs, +) +from sgl_kernel.ops.utils import ( + _get_cache_buf, + _get_cuda_stream, + _to_tensor_scalar_tuple, +) def init_custom_reduce( @@ -236,3 +251,213 @@ def bmm_fp8( workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale) return out + + +def _top_k_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + with 
probs.device as device: + probs = probs.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + renorm_probs = torch.empty_like(probs) + _top_k_renorm_probs( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k)) + + +top_k_renorm_prob = top_k_renorm_probs + + +def _top_p_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + renorm_probs = torch.empty_like(probs) + _top_p_renorm_probs( + probs, + renorm_probs, + maybe_top_p_arr, + top_p_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p)) + + +top_p_renorm_prob = top_p_renorm_probs + + +def _top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + _top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def 
top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic + ) + + +def _top_k_top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + _top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if filter_apply_order == "top_k_first": + renorm_probs = top_k_renorm_probs(probs, top_k) + return top_p_sampling_from_probs( + renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan + ) + elif filter_apply_order == "joint": + if check_nan: + if torch.any(torch.isnan(probs)): + raise 
ValueError("Input probs contains NaN.") + return _top_k_top_p_sampling_from_probs_internal( + probs, + uniform_samples, + *_to_tensor_scalar_tuple(top_k), + *_to_tensor_scalar_tuple(top_p), + deterministic, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") + + +def _min_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_min_p_arr: Optional[torch.Tensor], + min_p_val: float, + deterministic: bool, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_min_p_arr = ( + maybe_min_p_arr.float() if maybe_min_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + _min_p_sampling_from_probs( + probs, + uniform_samples, + samples, + maybe_min_p_arr, + min_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples + + +def min_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + min_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> torch.Tensor: + if uniform_samples.dim() == 2: + # Take the first row (round) of uniform_samples + uniform_samples = uniform_samples[0] + + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _min_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic + ) diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py index af5fccbb7..31a6bbf99 100644 --- a/sgl-kernel/src/sgl-kernel/ops/utils.py +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -17,3 +17,10 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: buf = torch.empty(bytes, dtype=torch.uint8, device=device) _cache_buf[key] = buf return buf + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return 
(None, x) diff --git a/sgl-kernel/tests/test_sampling.py b/sgl-kernel/tests/test_sampling.py new file mode 100644 index 000000000..7d3bc5059 --- /dev/null +++ b/sgl-kernel/tests/test_sampling.py @@ -0,0 +1,141 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + max_top_k_trails = 32 + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + # top-k mask + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() + # overall mask + mask = torch.minimum(mask_top_p, mask_top_k) + uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to( + 0 + ) + top_p_tensor = torch.full((batch_size,), p).to(0) + top_k_tensor = torch.full((batch_size,), k).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples, success = sgl_kernel.top_k_top_p_sampling_from_probs( + normalized_prob, + uniform_samples, + top_k_tensor, + top_p_tensor, + filter_apply_order="joint", + ) + assert torch.all(success) + assert torch.all(samples < vocab_size) 
and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ + torch.arange(batch_size), samples + ] + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_renorm_probs(batch_size, vocab_size, p): + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = 
sgl_kernel.top_k_renorm_prob(normalized_prob, k) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) +def test_min_p_sampling(batch_size, vocab_size, p): + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + # scale min-p + top_probs = sorted_prob[:, -1].unsqueeze(-1) + scaled_p = p * top_probs + # min-p mask + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int()) + uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0) + min_p_tensor = torch.full((batch_size,), p).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples = sgl_kernel.min_p_sampling_from_probs( + normalized_prob, + uniform_samples, + min_p_tensor, + ) + + assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[ + torch.nonzero(mask[torch.arange(batch_size), samples] == 0) + ] + + +if __name__ == "__main__": + pytest.main([__file__])