fix sgl-kernel unit tests (#5666)
This commit is contained in:
@@ -47,6 +47,16 @@ def baseline_scaled_mm(
|
||||
).to(out_dtype)
|
||||
|
||||
|
||||
def is_sm100_supported(device=None) -> bool:
|
||||
return (torch.cuda.get_device_capability(device)[0] == 10) and (
|
||||
torch.version.cuda >= "12.8"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_sm100_supported(),
|
||||
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100",
|
||||
)
|
||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
|
||||
Reference in New Issue
Block a user