sgl scaled_fp8_quant: support output padding (#4861)

This commit is contained in:
Xiaoyu Zhang
2025-04-02 23:53:57 +08:00
committed by GitHub
parent 3fadc64793
commit e9c6ce461d
3 changed files with 61 additions and 4 deletions

View File

@@ -457,12 +457,9 @@ class Fp8LinearOp:
qinput, x_scale = sgl_scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic,
)
if self.output_padding:
pad_size = max(self.output_padding - qinput.shape[0], 0)
if pad_size > 0:
qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
else:
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,