[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-08-15 01:56:36 +08:00
committed by GitHub
parent 432f2053dd
commit 53dcc750b6
6 changed files with 349 additions and 5 deletions

View File

@@ -5,6 +5,54 @@ import sgl_kernel
import torch
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment(
batch_size, vocab_size, k, p
):
torch.manual_seed(42)
logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
generator_logits = torch.Generator("cuda:0")
generator_probs = generator_logits.clone_state()
samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
logits, k, p, filter_apply_order="top_k_first", generator=generator_logits
)
samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
torch.softmax(logits, dim=-1),
k,
p,
filter_apply_order="top_k_first",
generator=generator_probs,
)
assert torch.all(samples == samples_ref)
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_joint_alignment(
batch_size, vocab_size, k, p
):
torch.manual_seed(42)
logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
generator_logits = torch.Generator("cuda:0")
generator_probs = generator_logits.clone_state()
samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
logits, k, p, filter_apply_order="joint", generator=generator_logits
)
samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
torch.softmax(logits, dim=-1),
k,
p,
filter_apply_order="joint",
generator=generator_probs,
)
assert torch.all(samples == samples_ref)
@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])