Apply sgl w8a8 fp8 kernel (#3148)
This commit is contained in:
@@ -250,9 +250,11 @@ class ModelConfig:
|
||||
"compressed-tensors",
|
||||
"experts_int8",
|
||||
"w8a8_int8",
|
||||
"w8a8_fp8",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"w8a8_int8": ["compressed-tensors", "compressed_tensors"]
|
||||
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
|
||||
"w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
|
||||
}
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
from sglang.srt.layers.parameter import (
|
||||
BasevLLMParameter,
|
||||
BlockQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
PerTensorScaleParameter,
|
||||
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,6 +16,7 @@ __all__ = [
|
||||
"ModelWeightParameter",
|
||||
"ChannelQuantScaleParameter",
|
||||
"GroupQuantScaleParameter",
|
||||
"BlockQuantScaleParameter",
|
||||
"PackedColumnParameter",
|
||||
"RowvLLMParameter",
|
||||
]
|
||||
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
|
||||
pass
|
||||
|
||||
|
||||
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
"""
|
||||
Parameter class for weight scales loaded for weights with
|
||||
block-wise quantization. Uses both column and row parallelism.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PerTensorScaleParameter(BasevLLMParameter):
|
||||
"""
|
||||
Parameter class for scales where the number of scales is
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
|
||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
@@ -50,6 +51,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"qqq": QQQConfig,
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"w8a8_int8": W8A8Int8Config,
|
||||
"w8a8_fp8": W8A8Fp8Config,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
|
||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
|
||||
@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d,
|
||||
apply_fp8_linear,
|
||||
convert_to_channelwise,
|
||||
cutlass_fp8_supported,
|
||||
per_tensor_dequantize,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||
from sglang.srt.layers.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
BlockQuantScaleParameter,
|
||||
apply_fp8_linear,
|
||||
apply_w8a8_block_fp8_linear,
|
||||
cutlass_fp8_supported,
|
||||
input_to_float8,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.utils import (
|
||||
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
|
||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
||||
# so extend the weight scales to be channelwise.
|
||||
if self.use_marlin:
|
||||
assert weight_scale.numel() == 1
|
||||
weight_scale = convert_to_channelwise(
|
||||
weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
|
||||
if self.cutlass_fp8_supported or self.use_marlin:
|
||||
# apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
|
||||
qweight, weight_scale = per_token_group_quant_fp8(
|
||||
layer.weight, layer.weight.shape[-1]
|
||||
)
|
||||
weight_scale = weight_scale.t().contiguous()
|
||||
else:
|
||||
# per-tensor quantization
|
||||
qweight, weight_scale = input_to_float8(layer.weight)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
||||
# so extend the weight scales to be channelwise.
|
||||
if self.use_marlin:
|
||||
|
||||
# cutlass sgl-kernel and marlin only support per-channel scale
|
||||
if self.cutlass_fp8_supported or self.use_marlin:
|
||||
weight = layer.weight
|
||||
weight_scale = convert_to_channelwise(
|
||||
layer.weight_scale, layer.logical_widths
|
||||
)
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
else:
|
||||
# Dequant -> Quant with max scale so we can run per tensor.
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if is_hip_:
|
||||
if is_hip():
|
||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=weight_scale,
|
||||
|
||||
@@ -29,7 +29,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
if _is_cuda:
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,7 +70,8 @@ def _per_token_group_quant_fp8(
|
||||
# Quant
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||
y_s = _absmax / fp8_max
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
y_s_inv = 1.0 / y_s
|
||||
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||
tl.store(y_s_ptr, y_s)
|
||||
@@ -140,7 +141,7 @@ def per_token_group_quant_fp8(
|
||||
x: The input tenosr with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
|
||||
dtype: The dype of output tensor.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
||||
@@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8(
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def sglang_per_token_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
):
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
x_s = torch.empty(
|
||||
x.shape[0],
|
||||
1,
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
sgl_per_token_quant_fp8(x, x_q, x_s)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _static_quant_fp8(
|
||||
# Pointers to inputs and output
|
||||
y_ptr,
|
||||
y_q_ptr,
|
||||
y_s_ptr,
|
||||
y_s_repeat_ptr,
|
||||
# Stride of input
|
||||
y_stride,
|
||||
# Collums of input
|
||||
N,
|
||||
# Information for float8
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
# Meta-parameters
|
||||
BLOCK: tl.constexpr,
|
||||
REPEAT_SCALE: tl.constexpr,
|
||||
):
|
||||
"""A Triton-accelerated function to perform quantization using the given scale on a
|
||||
tensor
|
||||
|
||||
This function converts the tensor values into float8 values.
|
||||
"""
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
g_id = tl.program_id(0)
|
||||
y_ptr += g_id * y_stride
|
||||
y_q_ptr += g_id * y_stride
|
||||
if REPEAT_SCALE:
|
||||
y_s_repeat_ptr += g_id
|
||||
|
||||
cols = tl.arange(0, BLOCK) # N <= BLOCK
|
||||
mask = cols < N
|
||||
|
||||
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
y_s = tl.load(y_s_ptr).to(tl.float32)
|
||||
y_s_inv = 1.0 / y_s
|
||||
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||
if REPEAT_SCALE:
|
||||
tl.store(y_s_repeat_ptr, y_s)
|
||||
|
||||
|
||||
def static_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
x_s: torch.Tensor,
|
||||
repeat_scale: bool = False,
|
||||
dtype: torch.dtype = fp8_type_,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform static quantization using the given scale on an input tensor `x`.
|
||||
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
|
||||
Args:
|
||||
x: The input tenosr with ndim >= 2.
|
||||
x_s: The quantization scale.
|
||||
repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
|
||||
dtype: The dype of output tensor.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
||||
"""
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
assert x_s.numel() == 1, "only supports per-tensor scale"
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_max = finfo.max
|
||||
|
||||
if is_hip_:
|
||||
fp8_max = 224.0
|
||||
|
||||
fp8_min = -fp8_max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
M = x.numel() // x.shape[-1]
|
||||
N = x.shape[-1]
|
||||
if repeat_scale:
|
||||
x_s_repeat = torch.empty(
|
||||
(M, 1),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
else:
|
||||
x_s_repeat = None
|
||||
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||
num_stages = 1
|
||||
_static_quant_fp8[(M,)](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
x_s_repeat,
|
||||
N,
|
||||
N,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
BLOCK=BLOCK,
|
||||
REPEAT_SCALE=repeat_scale,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
x_s = x_s_repeat if repeat_scale else x_s
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
# Pointers to inputs and output
|
||||
|
||||
@@ -2,13 +2,23 @@ import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
|
||||
from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_cuda_version,
|
||||
get_device_capability,
|
||||
is_hip,
|
||||
)
|
||||
|
||||
use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
|
||||
"USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
|
||||
)
|
||||
|
||||
is_hip_ = is_hip()
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
@@ -18,6 +28,25 @@ _is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
if _is_cuda:
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
|
||||
|
||||
if use_vllm_cutlass_w8a8_fp8_kernel:
|
||||
from vllm import _custom_ops as ops
|
||||
else:
|
||||
from sgl_kernel import fp8_scaled_mm
|
||||
|
||||
|
||||
def cutlass_fp8_supported():
|
||||
if not _is_cuda:
|
||||
return False
|
||||
major, minor = get_device_capability()
|
||||
cuda_version = get_cuda_version()
|
||||
if major >= 9:
|
||||
return cuda_version >= (12, 0)
|
||||
elif major == 8 and minor == 9:
|
||||
return cuda_version >= (12, 4)
|
||||
return False
|
||||
|
||||
|
||||
def normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight: torch.Tensor,
|
||||
@@ -158,10 +187,121 @@ def block_quant_to_tensor_quant(
|
||||
return x_q_tensor, scale
|
||||
|
||||
|
||||
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
"""
|
||||
Parameter class for weight scales loaded for weights with
|
||||
block-wise quantization. Uses both column and row parallelism.
|
||||
"""
|
||||
def apply_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
input_scale_ub: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_fp8_supported: bool = True,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[1]]
|
||||
|
||||
pass
|
||||
# cutlass w8a8 fp8 sgl-kernel only supports per-token scale
|
||||
if input_scale is not None:
|
||||
assert input_scale.numel() == 1
|
||||
# broadcast per-tensor scale to per-token scale when supporting cutlass
|
||||
qinput, x_scale = static_quant_fp8(
|
||||
input_2d, input_scale, repeat_scale=cutlass_fp8_supported
|
||||
)
|
||||
else:
|
||||
# default use per-token quantization if dynamic
|
||||
if _is_cuda:
|
||||
qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
|
||||
else:
|
||||
qinput, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, group_size=input_2d.shape[1]
|
||||
)
|
||||
|
||||
if cutlass_fp8_supported:
|
||||
if use_vllm_cutlass_w8a8_fp8_kernel:
|
||||
# Fall back to vllm cutlass w8a8 fp8 kernel
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
weight_scale.numel() == weight.shape[1]
|
||||
), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
|
||||
output = fp8_scaled_mm(
|
||||
qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
# torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
else:
|
||||
per_tensor_weights = weight_scale.numel() == 1
|
||||
per_tensor_activations = x_scale.numel() == 1
|
||||
|
||||
if per_tensor_weights and per_tensor_activations:
|
||||
# Fused GEMM_DQ
|
||||
output = torch._scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
|
||||
|
||||
else:
|
||||
# Fallback for channelwise case, where we use unfused DQ
|
||||
# due to limitations with scaled_mm
|
||||
|
||||
# Symmetric quantized GEMM by definition computes the following:
|
||||
# C = (s_x * X) (s_w * W) + bias
|
||||
# This is equivalent to dequantizing the weights and activations
|
||||
# before applying a GEMM.
|
||||
#
|
||||
# In order to compute quantized operands, a quantized kernel
|
||||
# will rewrite the above like so:
|
||||
# C = s_w * s_x * (X * W) + bias
|
||||
#
|
||||
# For the scaled_mm fallback case, we break this down, since it
|
||||
# does not support s_w being a vector.
|
||||
|
||||
# Making sure the dummy tensor is on the same device as the weight
|
||||
global TORCH_DEVICE_IDENTITY
|
||||
if TORCH_DEVICE_IDENTITY.device != weight.device:
|
||||
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
|
||||
|
||||
# GEMM
|
||||
# This computes C = (X * W).
|
||||
# Output in fp32 to allow subsequent ops to happen in-place
|
||||
output = torch._scaled_mm(
|
||||
qinput,
|
||||
weight,
|
||||
scale_a=TORCH_DEVICE_IDENTITY,
|
||||
scale_b=TORCH_DEVICE_IDENTITY,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
# Unpad (undo num_token_padding)
|
||||
output = torch.narrow(output, 0, 0, input_2d.shape[0])
|
||||
x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
|
||||
|
||||
# DQ
|
||||
# C = sw * sx * (X * W) + bias
|
||||
output = output * x_scale * weight_scale.t()
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear,
|
||||
convert_to_channelwise,
|
||||
cutlass_fp8_supported,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
|
||||
|
||||
# Initialize logger for the module
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
||||
layer.weight, layer.weight_scale, layer.logical_widths
|
||||
)
|
||||
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
|
||||
# cutlass sgl-kernel only supports per-channel scale
|
||||
if self.cutlass_fp8_supported:
|
||||
max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
||||
|
||||
|
||||
126
python/sglang/srt/layers/quantization/w8a8_fp8.py
Normal file
126
python/sglang/srt/layers/quantization/w8a8_fp8.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.linear import LinearMethodBase
|
||||
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import (
|
||||
apply_fp8_linear,
|
||||
cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
|
||||
class W8A8Fp8Config(QuantizationConfig):
|
||||
"""Config class for W8A8 FP8 Quantization.
|
||||
|
||||
- Weight: static, per-channel, symmetric
|
||||
- Activation: dynamic, per-token, symmetric
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
|
||||
@classmethod
|
||||
def get_name(self) -> str:
|
||||
return "w8a8_fp8"
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
|
||||
return cls()
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
from sglang.srt.layers.linear import LinearBase
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return W8A8Fp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: W8A8Fp8Config):
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale.detach()
|
||||
if is_hip():
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale
|
||||
)
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs
|
||||
):
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
self.logical_widths = output_partition_sizes
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
):
|
||||
return apply_fp8_linear(
|
||||
x,
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
bias=bias,
|
||||
cutlass_fp8_supported=self.cutlass_fp8_supported,
|
||||
)
|
||||
@@ -405,6 +405,7 @@ class ServerArgs:
|
||||
"gguf",
|
||||
"modelopt",
|
||||
"w8a8_int8",
|
||||
"w8a8_fp8",
|
||||
],
|
||||
help="The quantization method.",
|
||||
)
|
||||
|
||||
@@ -52,11 +52,13 @@ import triton
|
||||
import zmq
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from packaging import version as pkg_version
|
||||
from packaging.version import Version, parse
|
||||
from starlette.routing import Mount
|
||||
from torch import nn
|
||||
from torch.func import functional_call
|
||||
from torch.library import Library
|
||||
from torch.profiler import ProfilerActivity, profile, record_function
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
from triton.runtime.cache import (
|
||||
FileCacheManager,
|
||||
default_cache_dir,
|
||||
@@ -1431,6 +1433,12 @@ def rank0_print(msg: str):
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
def get_cuda_version():
|
||||
if torch.version.cuda:
|
||||
return tuple(map(int, torch.version.cuda.split(".")))
|
||||
return (0, 0)
|
||||
|
||||
|
||||
def launch_dummy_health_check_server(host, port):
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Response
|
||||
|
||||
@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
|
||||
@@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
||||
out, scale = per_token_group_quant_fp8(x, group_size)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
|
||||
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
|
||||
)
|
||||
self.assertTrue(torch.allclose(scale, ref_scale))
|
||||
|
||||
@@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
||||
self._per_token_group_quant_fp8(*params)
|
||||
|
||||
|
||||
# For test
|
||||
def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
|
||||
"""Function to perform static quantization on an input tensor `x` using native torch.
|
||||
|
||||
It converts the tensor values into float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
"""
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
assert x_s.numel() == 1, "only supports per-tensor scale"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
|
||||
x_s_inv = 1.0 / x_s
|
||||
x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
class TestStaticQuantFP8(unittest.TestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
SEEDS = [0]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available():
|
||||
raise unittest.SkipTest("CUDA is not available")
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
def _static_quant_fp8(self, num_tokens, d, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
x = torch.rand(num_tokens, d, dtype=dtype)
|
||||
fp8_max = torch.finfo(torch.float8_e4m3fn).max
|
||||
x_s = x.max() / fp8_max
|
||||
|
||||
with torch.inference_mode():
|
||||
ref_out, _ = native_static_quant_fp8(x, x_s)
|
||||
out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
|
||||
)
|
||||
|
||||
def test_static_quant_fp8(self):
|
||||
for params in itertools.product(
|
||||
self.NUM_TOKENS,
|
||||
self.D,
|
||||
self.DTYPES,
|
||||
self.SEEDS,
|
||||
):
|
||||
with self.subTest(
|
||||
num_tokens=params[0],
|
||||
d=params[1],
|
||||
dtype=params[2],
|
||||
seed=params[3],
|
||||
):
|
||||
self._static_quant_fp8(*params)
|
||||
|
||||
|
||||
# For test
|
||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||
|
||||
Reference in New Issue
Block a user