restruct compressed_tensors_w8a8_fp8 (#5475)
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-    Fp8LinearOp,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -29,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -149,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )

@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
@@ -5,7 +6,7 @@ import torch
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 
 try:
-    from vllm import _custom_ops as vllm_ops
+    from vllm import _custom_ops as ops
 
     VLLM_AVAILABLE = True
 except ImportError:
@@ -234,6 +235,43 @@ def channel_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -241,211 +279,38 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]
 
-    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
-            # final solution should be: 1. add support to per-tensor activation scaling.
-            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
-            if _is_hip and weight_scale.numel() == 1:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = per_token_group_quant_fp8(
-                    input_2d, group_size=input_2d.shape[1]
-                )
-
-    if cutlass_fp8_supported:
-        if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vllm cutlass w8a8 fp8 kernel
-            output = vllm_ops.cutlass_scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-        else:
-            assert (
-                weight_scale.numel() == weight.shape[1]
-            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-            output = fp8_scaled_mm(
-                qinput,
-                weight,
-                x_scale,
-                weight_scale,
-                out_dtype=input.dtype,
-                bias=bias,
-            )
-        return output.view(*output_shape)
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            # C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            # C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY is None:
-                TORCH_DEVICE_IDENTITY = torch.ones(
-                    1, dtype=torch.float32, device=weight.device
-                )
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
+    if compressed_tensor_quant:
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         # for sgl-kernel fp8_scaled_mm, it support per channel W now
-        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
-            if _is_cuda:
-                qinput, x_scale = scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    scale_ub=input_scale_ub,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                 # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = vllm_ops.cutlass_scaled_mm(
+                output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
                     out_dtype=input.dtype,
@@ -471,20 +336,21 @@ class Fp8LinearOp:
         # so fallback to naive if per channel or per token
         else:
             # Maybe apply padding to output, see comment in __init__
-            if _is_cuda:
-                qinput, x_scale = scaled_fp8_quant(
+            qinput, x_scale = (
+                scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-            else:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
+                if _is_cuda
+                else ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
+            )
 
             per_tensor_weights = weight_scale.numel() == 1
             per_tensor_activations = x_scale.numel() == 1
@@ -499,12 +365,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale,
                     bias=bias,
                 )
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-
-                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             elif (
                 use_per_token_if_dynamic
@@ -527,10 +388,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale.t(),
                     bias=bias,
                 )
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                output = output.view(*output_shape)
-                return output
-
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             else:
                 # Fallback for channelwise case, where we use unfused DQ
@@ -547,36 +405,110 @@ class Fp8LinearOp:
                 #
                 # For the scaled_mm fallback case, we break this down, since it
                 # does not support s_w being a vector.
-
-                # GEMM
-                # This computes C = (X * W).
-                # Output in fp32 to allow subsequent ops to happen in-place
-
-                # Making sure the dummy tensor is on the same device as the weight
-                global TORCH_DEVICE_IDENTITY
-                if TORCH_DEVICE_IDENTITY is None:
-                    TORCH_DEVICE_IDENTITY = torch.ones(
-                        1, dtype=torch.float32, device=weight.device
-                    )
-
-                output = torch._scaled_mm(
+                return _apply_fallback_scaled_mm(
                     qinput,
                     weight,
-                    scale_a=TORCH_DEVICE_IDENTITY,
-                    scale_b=TORCH_DEVICE_IDENTITY,
-                    out_dtype=torch.float32,
+                    x_scale,
+                    weight_scale,
+                    input_2d.shape,
+                    output_shape,
+                    bias,
+                    input.dtype,
                 )
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-                # Unpad (undo num_token_padding)
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-                # DQ
-                # C = sw * sx * (X * W) + bias
-                output = output * x_scale * weight_scale.t()
-                if bias is not None:
-                    output = output + bias
-                return output.to(dtype=input.dtype).view(*output_shape)
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default use per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
+
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
+                        scale_a=x_scale,
+                        scale_b=weight_scale,
+                        bias=bias,
+                    )
+                else:
+                    assert (
+                        weight_scale.numel() == weight.shape[1]
+                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                    output = fp8_scaled_mm(
+                        qinput,
+                        weight,
+                        x_scale,
+                        weight_scale,
+                        out_dtype=input.dtype,
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            # C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            # C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )

test/srt/models/test_compressed_tensors_models.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestCompressedTensorsLlama3FP8(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "RedHatAI/Meta-Llama-3.1-8B-FP8"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.45)
+
+
+if __name__ == "__main__":
+    unittest.main()

@@ -20,6 +20,7 @@ suites = {
         TestFile("models/test_generation_models.py", 103),
         TestFile("models/test_grok_models.py", 60),
         TestFile("models/test_qwen_models.py", 82),
+        TestFile("models/test_compressed_tensors_models.py", 100),
         TestFile("models/test_reward_models.py", 83),
         TestFile("models/test_gme_qwen_models.py", 45),
         TestFile("models/test_clip_models.py", 100),
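
Note on the call-site change: after this commit the compressed-tensors W8A8-FP8 scheme no longer holds an `Fp8LinearOp` instance and instead calls the module-level `apply_fp8_linear` with two new keyword flags. The snippet below is a minimal usage sketch of that new call pattern, not part of the commit; the tensor shapes, dtypes, and dummy data are illustrative, and it assumes an SGLang installation with an FP8-capable CUDA device.

import torch

from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear

# Hypothetical per-channel FP8 weight and scale; a real CompressedTensorsW8A8Fp8
# layer loads these from a compressed-tensors checkpoint instead.
in_features, out_features = 512, 256
weight = torch.randn(in_features, out_features, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(out_features, 1, device="cuda", dtype=torch.float32)

x = torch.randn(8, in_features, device="cuda", dtype=torch.bfloat16)

# New call pattern from this commit: per-token dynamic activation scales and the
# compressed-tensors code path are selected via keyword flags.
y = apply_fp8_linear(
    input=x,
    weight=weight,
    weight_scale=weight_scale,
    input_scale=None,  # dynamic activation quantization
    bias=None,
    use_per_token_if_dynamic=True,
    compressed_tensor_quant=True,
)
print(y.shape)  # torch.Size([8, 256])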