[Feat] Update sgl-kernel flashinfer to latest main version (#5500)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -5,8 +5,8 @@ import sgl_kernel
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||
@pytest.mark.parametrize("p", [0.1, 0.5])
|
||||
def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
||||
torch.manual_seed(42)
|
||||
@@ -16,14 +16,13 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
||||
k = int(vocab_size * 0.1)
|
||||
else:
|
||||
raise ValueError("p not recognized")
|
||||
max_top_k_trails = 32
|
||||
eps = 1e-4
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||
# top-p mask
|
||||
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
||||
cdf = torch.cumsum(sorted_prob, dim=-1)
|
||||
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
|
||||
mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
|
||||
mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int())
|
||||
# top-k mask
|
||||
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
|
||||
@@ -31,40 +30,35 @@ def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p):
|
||||
mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int()
|
||||
# overall mask
|
||||
mask = torch.minimum(mask_top_p, mask_top_k)
|
||||
uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to(
|
||||
0
|
||||
)
|
||||
top_p_tensor = torch.full((batch_size,), p).to(0)
|
||||
top_k_tensor = torch.full((batch_size,), k).to(0)
|
||||
top_p_tensor = torch.full((batch_size,), p, device="cuda:0")
|
||||
top_k_tensor = torch.full((batch_size,), k, device="cuda:0")
|
||||
|
||||
num_trails = 1000
|
||||
for _ in range(num_trails):
|
||||
uniform_samples.uniform_()
|
||||
samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||
samples = sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||
normalized_prob,
|
||||
uniform_samples,
|
||||
top_k_tensor,
|
||||
top_p_tensor,
|
||||
filter_apply_order="joint",
|
||||
)
|
||||
assert torch.all(success)
|
||||
assert torch.all(samples < vocab_size) and torch.all(samples >= 0)
|
||||
assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[
|
||||
torch.arange(batch_size), samples
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||
@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
|
||||
def test_top_p_renorm_probs(batch_size, vocab_size, p):
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
||||
torch.manual_seed(42)
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
||||
cdf = torch.cumsum(sorted_prob, dim=-1)
|
||||
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
|
||||
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
|
||||
mask.scatter_add_(1, indices, (cdf >= (1 - p)).int())
|
||||
renorm_prob_ground_truth = normalized_prob
|
||||
renorm_prob_ground_truth = normalized_prob.clone()
|
||||
renorm_prob_ground_truth[mask == 0] = 0
|
||||
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
|
||||
dim=-1, keepdim=True
|
||||
@@ -79,56 +73,54 @@ def test_top_p_renorm_probs(batch_size, vocab_size, p):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||
@pytest.mark.parametrize("k", [10, 100, 500])
|
||||
def test_top_k_renorm_probs(batch_size, vocab_size, k):
|
||||
if k > vocab_size:
|
||||
pytest.skip("k should be less than vocab_size")
|
||||
torch.manual_seed(42)
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||
sorted_prob, _ = torch.sort(normalized_prob, descending=True)
|
||||
pivot = sorted_prob[:, k - 1]
|
||||
mask = (normalized_prob >= pivot.unsqueeze(-1)).int()
|
||||
renorm_prob_ground_truth = normalized_prob
|
||||
renorm_prob_ground_truth = normalized_prob.clone()
|
||||
renorm_prob_ground_truth[mask == 0] = 0
|
||||
renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k)
|
||||
torch.testing.assert_close(
|
||||
renorm_prob_ground_truth,
|
||||
renorm_prob,
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
for i in range(batch_size):
|
||||
torch.testing.assert_close(
|
||||
renorm_prob_ground_truth[i],
|
||||
renorm_prob[i],
|
||||
rtol=1e-3,
|
||||
atol=1e-3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256])
|
||||
@pytest.mark.parametrize("batch_size", [1, 99, 989])
|
||||
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
|
||||
@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1])
|
||||
def test_min_p_sampling(batch_size, vocab_size, p):
|
||||
torch.manual_seed(42)
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
|
||||
pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
|
||||
normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
|
||||
sorted_prob, indices = torch.sort(normalized_prob, descending=False)
|
||||
# scale min-p
|
||||
top_probs = sorted_prob[:, -1].unsqueeze(-1)
|
||||
scaled_p = p * top_probs
|
||||
# min-p mask
|
||||
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0)
|
||||
mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32, device="cuda:0")
|
||||
mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int())
|
||||
uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0)
|
||||
min_p_tensor = torch.full((batch_size,), p).to(0)
|
||||
min_p_tensor = torch.full((batch_size,), p, device="cuda:0")
|
||||
|
||||
num_trails = 1000
|
||||
for _ in range(num_trails):
|
||||
uniform_samples.uniform_()
|
||||
samples = sgl_kernel.min_p_sampling_from_probs(
|
||||
normalized_prob,
|
||||
uniform_samples,
|
||||
min_p_tensor,
|
||||
)
|
||||
|
||||
@@ -136,6 +128,10 @@ def test_min_p_sampling(batch_size, vocab_size, p):
|
||||
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
|
||||
]
|
||||
|
||||
assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[
|
||||
torch.nonzero(mask[torch.arange(batch_size), samples] == 0)
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user