support fp32 in sampling_scaling_penalties kernel (#3121)
@@ -2,10 +2,14 @@ import pytest
 import torch
 
 from sgl_kernel import sampling_scaling_penalties
 
+batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
+vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
+dtypes = [torch.float32, torch.half, torch.bfloat16]
+
-@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
-@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@pytest.mark.parametrize("batch_size", batch_sizes)
+@pytest.mark.parametrize("vocab_size", vocab_sizes)
+@pytest.mark.parametrize("dtype", dtypes)
 def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
     device = torch.device("cuda")
     rtol = 1e-3
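The hunk is truncated at rtol = 1e-3, so the test body itself is not shown. For context, here is a minimal sketch of how such a test typically continues. It assumes the kernel applies the usual repetition-penalty scaling rule (positive logits divided by the penalty, negative logits multiplied by it) and accepts (logits, scaling_penalties) in that order; the call signature, tensor shapes, penalty range, atol value, and reference formula are assumptions for illustration, not taken from this diff:

import torch
from sgl_kernel import sampling_scaling_penalties

def reference_scaling_penalties(logits, penalties):
    # Assumed reference: standard repetition-penalty scaling.
    return torch.where(logits > 0, logits / penalties, logits * penalties)

def check_once(batch_size=4, vocab_size=2048, dtype=torch.float32,
               rtol=1e-3, atol=1e-3):
    device = torch.device("cuda")
    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    # Keep penalties in [0.5, 1.5) so the division stays well-conditioned.
    penalties = torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    out = sampling_scaling_penalties(logits, penalties)  # assumed call signature
    ref = reference_scaling_penalties(logits, penalties)
    torch.testing.assert_close(out, ref, rtol=rtol, atol=atol)

With dtypes now including torch.float32, the same tolerance pair would be exercised across all three dtypes; fp32 comfortably meets a 1e-3 rtol, so no per-dtype tolerance branching is needed.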