Remove vllm ops scaled fp8 quant and accelerate per token quant by 20-28% (#4215)

Co-authored-by: Stefan He <bhe@linkedin.com>
This commit is contained in:
Stefan He
2025-03-12 00:08:03 -07:00
committed by GitHub
parent 7130a7cea9
commit e0917e6bd0
5 changed files with 202 additions and 37 deletions

View File

@@ -1,3 +1,5 @@
from typing import Optional
import torch
from torch import nn
@@ -40,3 +42,60 @@ class CustomOp(nn.Module):
return self.forward_hip
else:
return self.forward_native
if _is_cuda:
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 (8-bit floating point) format.
Args:
input (torch.Tensor): Input tensor to be quantized
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
If None, scales will be computed dynamically.
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
determines the quantization granularity:
- True: compute scale per token
- False: compute single scale per tensor
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- quantized_tensor: The FP8 quantized version of input
- scale_tensor: The scaling factors used for quantization
Raises:
AssertionError: If input is not 2D or if static scale's numel != 1
"""
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
shape = input.shape
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
output = torch.empty(shape, device=input.device, dtype=out_dtype)
if scale is None:
# Dynamic scaling
if use_per_token_if_dynamic:
scale = torch.empty(
(shape[0], 1), device=input.device, dtype=torch.float32
)
sgl_per_token_quant_fp8(input, output, scale)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=False
) # False for dynamic
else:
# Static scaling
assert (
scale.numel() == 1
), f"Expected scalar scale, got numel={scale.numel()}"
sgl_per_tensor_quant_fp8(
input, output, scale, is_static=True
) # True for static
return output, scale