sgl scaled_fp8_quant support output padding (#4861)
This commit is contained in:
@@ -50,6 +50,7 @@ if _is_cuda:
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
num_token_padding: Optional[int] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -59,6 +60,8 @@ if _is_cuda:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
num_token_padding (Optional[int]): If specified, pad the first dimension
|
||||
of the output to at least this value.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
@@ -75,6 +78,8 @@ if _is_cuda:
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
if num_token_padding:
|
||||
shape = (max(num_token_padding, input.shape[0]), shape[1])
|
||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||
|
||||
if scale is None:
|
||||
|
||||
Reference in New Issue
Block a user