[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
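For context, the new `sgl_kernel.sampling.top_k_top_p_sampling_from_logits` wrapper samples token ids directly from raw logits rather than from softmax-normalized probabilities, which is what the tests in this diff check against. Below is a minimal usage sketch based only on the call pattern exercised by those tests (it assumes a CUDA device, an sgl-kernel build with FlashInfer sampling support, and illustrative parameter values):

    # Usage sketch: sample from raw logits with top-k/top-p filtering.
    # Assumes CUDA is available and sgl-kernel is installed; k/p values are illustrative.
    import torch
    import sgl_kernel

    logits = torch.randn(4, 32000, device="cuda:0")  # (batch_size, vocab_size)
    gen = torch.Generator("cuda:0")

    # filter_apply_order accepts "top_k_first" or "joint", matching the two tests below.
    samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, 100, 0.5, filter_apply_order="top_k_first", generator=gen
    )
    print(samples)  # expected: one sampled token id per batch row

The tests verify that this logits path stays aligned with the existing probability path, `top_k_top_p_sampling_from_probs(torch.softmax(logits, dim=-1), ...)`, when both are driven from the same RNG state.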
@@ -5,6 +5,54 @@ import sgl_kernel
import torch


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_top_k_first_alignment(
    batch_size, vocab_size, k, p
):
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    generator_logits = torch.Generator("cuda:0")
    generator_probs = generator_logits.clone_state()
    samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="top_k_first", generator=generator_logits
    )
    samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="top_k_first",
        generator=generator_probs,
    )
    assert torch.all(samples == samples_ref)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [100])
@pytest.mark.parametrize("p", [0.1, 0.5])
def test_top_k_top_p_sampling_from_probs_logits_joint_alignment(
    batch_size, vocab_size, k, p
):
    torch.manual_seed(42)
    logits = torch.randn(batch_size, vocab_size, device="cuda:0") * 5
    generator_logits = torch.Generator("cuda:0")
    generator_probs = generator_logits.clone_state()
    samples = sgl_kernel.sampling.top_k_top_p_sampling_from_logits(
        logits, k, p, filter_apply_order="joint", generator=generator_logits
    )
    samples_ref = sgl_kernel.sampling.top_k_top_p_sampling_from_probs(
        torch.softmax(logits, dim=-1),
        k,
        p,
        filter_apply_order="joint",
        generator=generator_probs,
    )
    assert torch.all(samples == samples_ref)


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("p", [0.1, 0.5])