fix sgl-kernel unit tests (#5666)
This commit is contained in:
23
sgl-kernel/tests/test_apply_token_bitmask_inplace.py
Normal file
23
sgl-kernel/tests/test_apply_token_bitmask_inplace.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import apply_token_bitmask_inplace_cuda
|
||||
|
||||
|
||||
def test_apply_token_bitmask_inplace_kernel():
    """Check that apply_token_bitmask_inplace_cuda masks disallowed tokens.

    Bit i of the int32 bitmask selects whether token i keeps its logit
    (bit set) or is overwritten in place with -inf (bit clear). The
    expected result is built from an equivalent boolean mask on CPU and
    compared against the kernel's in-place output on GPU.
    """
    # The kernel under test is CUDA-only; skip instead of crashing on
    # CPU-only machines (the original unconditionally called .to("cuda")).
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for apply_token_bitmask_inplace_cuda")

    neginf = float("-inf")
    # Tokens at odd indices are allowed; even indices must be masked to -inf.
    bool_mask = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool)
    logits = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], dtype=torch.float32
    )
    expected = torch.where(bool_mask, logits, neginf)

    logits_gpu = logits.to("cuda")
    # 0b1010101010 sets exactly bits 1, 3, 5, 7, 9 — the same positions
    # that are True in bool_mask above.
    bitmask = torch.tensor([0b1010101010], dtype=torch.int32).to("cuda")
    apply_token_bitmask_inplace_cuda(logits_gpu, bitmask)
    torch.cuda.synchronize()
    torch.testing.assert_close(logits_gpu, expected.to("cuda"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Delegate entirely to pytest. The original called
    # test_apply_token_bitmask_inplace_kernel() directly first, which ran
    # the test twice and bypassed pytest's skip/report handling (a bare
    # call crashes outright on machines without CUDA).
    pytest.main([__file__])
|
||||
@@ -47,6 +47,16 @@ def baseline_scaled_mm(
|
||||
).to(out_dtype)
|
||||
|
||||
|
||||
def is_sm100_supported(device=None) -> bool:
    """Return True when `device` has compute capability 10.x (SM100 /
    Blackwell) and the torch build's CUDA toolkit is at least 12.8.

    Args:
        device: Optional device index/object forwarded to
            ``torch.cuda.get_device_capability``; defaults to the
            current CUDA device.
    """
    # CPU-only torch builds report torch.version.cuda as None, and
    # get_device_capability raises without a visible GPU — return False
    # instead of crashing so @pytest.mark.skipif can evaluate cleanly.
    if torch.version.cuda is None or not torch.cuda.is_available():
        return False
    # Compare numerically: the original string comparison
    # (torch.version.cuda >= "12.8") ranks "12.10" below "12.8".
    major, minor = (int(part) for part in torch.version.cuda.split(".")[:2])
    if (major, minor) < (12, 8):
        return False
    return torch.cuda.get_device_capability(device)[0] == 10
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_sm100_supported(),
|
||||
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
|
||||
)
|
||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
|
||||
@@ -48,6 +48,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
|
||||
topk_group=topk_group,
|
||||
compiled=False,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=2.5,
|
||||
)
|
||||
|
||||
# When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
|
||||
|
||||
Reference in New Issue
Block a user