sgl: scaled_fp8_quant supports output padding (#4861)

This commit is contained in:
Xiaoyu Zhang
2025-04-02 23:53:57 +08:00
committed by GitHub
parent 3fadc64793
commit e9c6ce461d
3 changed files with 61 additions and 4 deletions

View File

@@ -82,6 +82,61 @@ if is_cuda:
dequantize_per_token(ref_y, scale, dtype),
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_with_padding(dtype) -> None:
    """Check that `num_token_padding` pads dim 0 of the quantized output.

    Covers three quantization modes — dynamic, static, and per-token
    dynamic — and for each verifies that (a) the output's first dimension
    equals the requested padding size and (b) the non-padded region is
    identical to the result of the same call without padding.
    """
    original_rows = 5
    # Scale the random input up so it exercises a non-trivial fp8 range.
    x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)

    padding_size = 10

    # --- Dynamic quantization ---
    y_dynamic, scale_dynamic = scaled_fp8_quant(
        x, None, num_token_padding=padding_size
    )
    # Output shape must be padded along dim 0, unchanged along dim 1.
    assert y_dynamic.shape[0] == padding_size
    assert y_dynamic.shape[1] == x.shape[1]
    # The non-padded region must match the unpadded reference result.
    y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
    torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)

    # --- Static quantization ---
    # Reuse the scale computed by the unpadded dynamic call above:
    # scaled_fp8_quant(x, None) is deterministic for a fixed x, so a
    # third quantization pass just to obtain the scale is redundant.
    y_static, _ = scaled_fp8_quant(
        x, scale_without_padding, num_token_padding=padding_size
    )
    assert y_static.shape[0] == padding_size
    assert y_static.shape[1] == x.shape[1]
    y_static_without_padding, _ = scaled_fp8_quant(x, scale_without_padding)
    torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)

    # --- Per-token dynamic quantization ---
    # Both the quantized tensor and the per-token scales are padded.
    y_per_token, scale_per_token = scaled_fp8_quant(
        x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
    )
    assert y_per_token.shape[0] == padding_size
    assert y_per_token.shape[1] == x.shape[1]
    y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
        x, None, use_per_token_if_dynamic=True
    )
    torch.testing.assert_close(
        y_per_token[:original_rows], y_per_token_without_padding
    )
    torch.testing.assert_close(
        scale_per_token[:original_rows], scale_per_token_without_padding
    )
if __name__ == "__main__":
# Run the specific test function directly