sgl scaled_fp8_quant support output padding (#4861)
This commit is contained in:
@@ -50,6 +50,7 @@ if _is_cuda:
|
|||||||
def scaled_fp8_quant(
|
def scaled_fp8_quant(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
scale: Optional[torch.Tensor] = None,
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
num_token_padding: Optional[int] = None,
|
||||||
use_per_token_if_dynamic: bool = False,
|
use_per_token_if_dynamic: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -59,6 +60,8 @@ if _is_cuda:
|
|||||||
input (torch.Tensor): Input tensor to be quantized
|
input (torch.Tensor): Input tensor to be quantized
|
||||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||||
If None, scales will be computed dynamically.
|
If None, scales will be computed dynamically.
|
||||||
|
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||||
|
of the output to at least this value.
|
||||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||||
determines the quantization granularity:
|
determines the quantization granularity:
|
||||||
- True: compute scale per token
|
- True: compute scale per token
|
||||||
@@ -75,6 +78,8 @@ if _is_cuda:
|
|||||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||||
shape = input.shape
|
shape = input.shape
|
||||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
if num_token_padding:
|
||||||
|
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||||
|
|
||||||
if scale is None:
|
if scale is None:
|
||||||
|
|||||||
@@ -457,12 +457,9 @@ class Fp8LinearOp:
|
|||||||
qinput, x_scale = sgl_scaled_fp8_quant(
|
qinput, x_scale = sgl_scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
input_scale,
|
input_scale,
|
||||||
|
num_token_padding=self.output_padding,
|
||||||
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||||
)
|
)
|
||||||
if self.output_padding:
|
|
||||||
pad_size = max(self.output_padding - qinput.shape[0], 0)
|
|
||||||
if pad_size > 0:
|
|
||||||
qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
|
|
||||||
else:
|
else:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(
|
qinput, x_scale = ops.scaled_fp8_quant(
|
||||||
input_2d,
|
input_2d,
|
||||||
|
|||||||
@@ -82,6 +82,61 @@ if is_cuda:
|
|||||||
dequantize_per_token(ref_y, scale, dtype),
|
dequantize_per_token(ref_y, scale, dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
|
def test_scaled_fp8_quant_with_padding(dtype) -> None:
|
||||||
|
original_rows = 5
|
||||||
|
x = (torch.randn(size=(original_rows, 16), device="cuda") * 13).to(dtype)
|
||||||
|
|
||||||
|
padding_size = 10
|
||||||
|
|
||||||
|
# Test with dynamic quantization
|
||||||
|
y_dynamic, scale_dynamic = scaled_fp8_quant(
|
||||||
|
x, None, num_token_padding=padding_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify output shape has the padded size
|
||||||
|
assert y_dynamic.shape[0] == padding_size
|
||||||
|
assert y_dynamic.shape[1] == x.shape[1]
|
||||||
|
|
||||||
|
# Verify that the actual data in the non-padded region is correctly quantized
|
||||||
|
y_without_padding, scale_without_padding = scaled_fp8_quant(x, None)
|
||||||
|
torch.testing.assert_close(y_dynamic[:original_rows], y_without_padding)
|
||||||
|
|
||||||
|
# Test with static quantization
|
||||||
|
# First get a scale
|
||||||
|
_, scale = scaled_fp8_quant(x, None)
|
||||||
|
|
||||||
|
# Then use it for static quantization with padding
|
||||||
|
y_static, _ = scaled_fp8_quant(x, scale, num_token_padding=padding_size)
|
||||||
|
|
||||||
|
# Verify output shape has the padded size
|
||||||
|
assert y_static.shape[0] == padding_size
|
||||||
|
assert y_static.shape[1] == x.shape[1]
|
||||||
|
|
||||||
|
# Verify that the actual data in the non-padded region is correctly quantized
|
||||||
|
y_static_without_padding, _ = scaled_fp8_quant(x, scale)
|
||||||
|
torch.testing.assert_close(y_static[:original_rows], y_static_without_padding)
|
||||||
|
|
||||||
|
# Test with per-token dynamic quantization
|
||||||
|
y_per_token, scale_per_token = scaled_fp8_quant(
|
||||||
|
x, None, num_token_padding=padding_size, use_per_token_if_dynamic=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify output shape has the padded size
|
||||||
|
assert y_per_token.shape[0] == padding_size
|
||||||
|
assert y_per_token.shape[1] == x.shape[1]
|
||||||
|
|
||||||
|
# Verify that the actual data in the non-padded region is correctly quantized
|
||||||
|
y_per_token_without_padding, scale_per_token_without_padding = scaled_fp8_quant(
|
||||||
|
x, None, use_per_token_if_dynamic=True
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
y_per_token[:original_rows], y_per_token_without_padding
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
scale_per_token[:original_rows], scale_per_token_without_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Run the specific test function directly
|
# Run the specific test function directly
|
||||||
|
|||||||
Reference in New Issue
Block a user