sgl scaled_fp8_quant support output padding (#4861)

This commit is contained in:
Xiaoyu Zhang
2025-04-02 23:53:57 +08:00
committed by GitHub
parent 3fadc64793
commit e9c6ce461d
3 changed files with 61 additions and 4 deletions

View File

@@ -50,6 +50,7 @@ if _is_cuda:
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -59,6 +60,8 @@ if _is_cuda:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
num_token_padding (Optional[int]): If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
@@ -75,6 +78,8 @@ if _is_cuda:
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:

View File

@@ -457,12 +457,9 @@ class Fp8LinearOp:
qinput, x_scale = sgl_scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
if self.output_padding:
pad_size = max(self.output_padding - qinput.shape[0], 0)
if pad_size > 0:
qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
else:
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,