fix sgl-kernel unit tests (#5666)

This commit is contained in:
Yineng Zhang
2025-04-23 01:18:30 -07:00
committed by GitHub
parent e62c49557d
commit 15fabcc07f
9 changed files with 313 additions and 0 deletions

View File

@@ -47,6 +47,16 @@ def baseline_scaled_mm(
).to(out_dtype)
def is_sm100_supported(device=None) -> bool:
    """Return True when the SM100 kernels can run on ``device``.

    Two conditions must hold: the device's compute-capability major version
    is 10, and the CUDA toolkit PyTorch was built with is at least 12.8.

    This helper is evaluated inside a ``pytest.mark.skipif`` decorator, i.e.
    at import/collection time, so it must never raise: any environment in
    which the answer cannot be determined is reported as "not supported".

    Args:
        device: Device to query, forwarded unchanged to
            ``torch.cuda.get_device_capability``. ``None`` selects the
            current device.

    Returns:
        ``True`` if both conditions above hold, ``False`` otherwise
        (including when CUDA is unavailable or the version is unparsable).
    """
    cuda_version = torch.version.cuda
    # torch.version.cuda is None on builds without CUDA; and querying the
    # device capability without a usable CUDA device raises.
    if cuda_version is None or not torch.cuda.is_available():
        return False

    # Compare numerically: as strings, "12.10" >= "12.8" is False.
    try:
        toolkit = tuple(int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        return False

    capability_major = torch.cuda.get_device_capability(device)[0]
    return capability_major == 10 and toolkit >= (12, 8)
@pytest.mark.skipif(
not is_sm100_supported(),
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
)
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):