support fp32 in sampling_scaling_penalties kernel (#3121)

This commit is contained in:
Xiaoyu Zhang
2025-01-25 16:52:17 +08:00
committed by GitHub
parent 665e5e85f6
commit 5d9d15e70f
3 changed files with 26 additions and 5 deletions

View File

@@ -2,10 +2,14 @@ import pytest
import torch
from sgl_kernel import sampling_scaling_penalties
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
dtypes = [torch.float32, torch.half, torch.bfloat16]
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("vocab_size", vocab_sizes)
@pytest.mark.parametrize("dtype", dtypes)
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
device = torch.device("cuda")
rtol = 1e-3