Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)
Co-authored-by: Stefan He <bhe@linkedin.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -40,3 +42,60 @@ class CustomOp(nn.Module):
|
||||
return self.forward_hip
|
||||
else:
|
||||
return self.forward_native
|
||||
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP8 (8-bit floating point) format.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input tensor to be quantized
|
||||
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
||||
If None, scales will be computed dynamically.
|
||||
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
||||
determines the quantization granularity:
|
||||
- True: compute scale per token
|
||||
- False: compute single scale per tensor
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- quantized_tensor: The FP8 quantized version of input
|
||||
- scale_tensor: The scaling factors used for quantization
|
||||
|
||||
Raises:
|
||||
AssertionError: If input is not 2D or if static scale's numel != 1
|
||||
"""
|
||||
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
||||
shape = input.shape
|
||||
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
||||
|
||||
if scale is None:
|
||||
# Dynamic scaling
|
||||
if use_per_token_if_dynamic:
|
||||
scale = torch.empty(
|
||||
(shape[0], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
sgl_per_token_quant_fp8(input, output, scale)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=False
|
||||
) # False for dynamic
|
||||
else:
|
||||
# Static scaling
|
||||
assert (
|
||||
scale.numel() == 1
|
||||
), f"Expected scalar scale, got numel={scale.numel()}"
|
||||
sgl_per_tensor_quant_fp8(
|
||||
input, output, scale, is_static=True
|
||||
) # True for static
|
||||
|
||||
return output, scale
|
||||
|
||||
Reference in New Issue
Block a user