Apply sgl w8a8 fp8 kernel (#3148)
This commit is contained in:
@@ -250,9 +250,11 @@ class ModelConfig:
|
|||||||
"compressed-tensors",
|
"compressed-tensors",
|
||||||
"experts_int8",
|
"experts_int8",
|
||||||
"w8a8_int8",
|
"w8a8_int8",
|
||||||
|
"w8a8_fp8",
|
||||||
]
|
]
|
||||||
compatible_quantization_methods = {
|
compatible_quantization_methods = {
|
||||||
"w8a8_int8": ["compressed-tensors", "compressed_tensors"]
|
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
|
||||||
|
"w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
|
||||||
}
|
}
|
||||||
if self.quantization is not None:
|
if self.quantization is not None:
|
||||||
self.quantization = self.quantization.lower()
|
self.quantization = self.quantization.lower()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.parameter import (
|
from sglang.srt.layers.parameter import (
|
||||||
BasevLLMParameter,
|
BasevLLMParameter,
|
||||||
|
BlockQuantScaleParameter,
|
||||||
PackedColumnParameter,
|
PackedColumnParameter,
|
||||||
PackedvLLMParameter,
|
PackedvLLMParameter,
|
||||||
PerTensorScaleParameter,
|
PerTensorScaleParameter,
|
||||||
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
|
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ __all__ = [
|
|||||||
"ModelWeightParameter",
|
"ModelWeightParameter",
|
||||||
"ChannelQuantScaleParameter",
|
"ChannelQuantScaleParameter",
|
||||||
"GroupQuantScaleParameter",
|
"GroupQuantScaleParameter",
|
||||||
|
"BlockQuantScaleParameter",
|
||||||
"PackedColumnParameter",
|
"PackedColumnParameter",
|
||||||
"RowvLLMParameter",
|
"RowvLLMParameter",
|
||||||
]
|
]
|
||||||
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """Scale parameter for weights that use block-wise quantization.

    The class adds no behavior of its own: it only combines the
    column-parallel and row-parallel parameter bases, so the scales can be
    handled along both dimensions.
    """
|
||||||
|
|
||||||
|
|
||||||
class PerTensorScaleParameter(BasevLLMParameter):
|
class PerTensorScaleParameter(BasevLLMParameter):
|
||||||
"""
|
"""
|
||||||
Parameter class for scales where the number of scales is
|
Parameter class for scales where the number of scales is
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
|
|||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||||
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
|
||||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
|
||||||
|
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||||
|
|
||||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||||
@@ -50,6 +51,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|||||||
"qqq": QQQConfig,
|
"qqq": QQQConfig,
|
||||||
"experts_int8": ExpertsInt8Config,
|
"experts_int8": ExpertsInt8Config,
|
||||||
"w8a8_int8": W8A8Int8Config,
|
"w8a8_int8": W8A8Int8Config,
|
||||||
|
"w8a8_fp8": W8A8Fp8Config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
UnquantizedLinearMethod,
|
UnquantizedLinearMethod,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
|
|
||||||
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
|
||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
|
|||||||
@@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
|||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
apply_fp8_linear,
|
|
||||||
convert_to_channelwise,
|
convert_to_channelwise,
|
||||||
cutlass_fp8_supported,
|
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
@@ -29,14 +27,21 @@ from sglang.srt.layers.linear import (
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
UnquantizedLinearMethod,
|
UnquantizedLinearMethod,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from sglang.srt.layers.parameter import (
|
||||||
|
BlockQuantScaleParameter,
|
||||||
|
ModelWeightParameter,
|
||||||
|
PerTensorScaleParameter,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
from sglang.srt.layers.quantization.fp8_utils import (
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
BlockQuantScaleParameter,
|
apply_fp8_linear,
|
||||||
apply_w8a8_block_fp8_linear,
|
apply_w8a8_block_fp8_linear,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
input_to_float8,
|
||||||
normalize_e4m3fn_to_e4m3fnuz,
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
if self.cutlass_fp8_supported or self.use_marlin:
|
||||||
|
# apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale
|
||||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
qweight, weight_scale = per_token_group_quant_fp8(
|
||||||
# so extend the weight scales to be channelwise.
|
layer.weight, layer.weight.shape[-1]
|
||||||
if self.use_marlin:
|
|
||||||
assert weight_scale.numel() == 1
|
|
||||||
weight_scale = convert_to_channelwise(
|
|
||||||
weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
|
|
||||||
)
|
)
|
||||||
|
weight_scale = weight_scale.t().contiguous()
|
||||||
|
else:
|
||||||
|
# per-tensor quantization
|
||||||
|
qweight, weight_scale = input_to_float8(layer.weight)
|
||||||
|
|
||||||
# Update the layer with the new values.
|
# Update the layer with the new values.
|
||||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||||
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
layer.input_scale = torch.nn.Parameter(
|
layer.input_scale = torch.nn.Parameter(
|
||||||
layer.input_scale.data, requires_grad=False
|
layer.input_scale.data, requires_grad=False
|
||||||
)
|
)
|
||||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
|
||||||
# so extend the weight scales to be channelwise.
|
# cutlass sgl-kernel and marlin only support per-channel scale
|
||||||
if self.use_marlin:
|
if self.cutlass_fp8_supported or self.use_marlin:
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
weight_scale = convert_to_channelwise(
|
weight_scale = convert_to_channelwise(
|
||||||
layer.weight_scale, layer.logical_widths
|
layer.weight_scale, layer.logical_widths
|
||||||
)
|
)
|
||||||
|
|
||||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
|
||||||
# requantize the logical shards as a single weight.
|
|
||||||
else:
|
else:
|
||||||
# Dequant -> Quant with max scale so we can run per tensor.
|
# Dequant -> Quant with max scale so we can run per tensor.
|
||||||
weight = layer.weight
|
weight = layer.weight
|
||||||
weight_scale = layer.weight_scale
|
weight_scale = layer.weight_scale
|
||||||
|
|
||||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
if is_hip_:
|
if is_hip():
|
||||||
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight=weight,
|
weight=weight,
|
||||||
weight_scale=weight_scale,
|
weight_scale=weight_scale,
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
|||||||
|
|
||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8
|
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -70,7 +70,8 @@ def _per_token_group_quant_fp8(
|
|||||||
# Quant
|
# Quant
|
||||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||||
y_s = _absmax / fp8_max
|
y_s = _absmax / fp8_max
|
||||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
y_s_inv = 1.0 / y_s
|
||||||
|
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||||
|
|
||||||
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
tl.store(y_q_ptr + cols, y_q, mask=mask)
|
||||||
tl.store(y_s_ptr, y_s)
|
tl.store(y_s_ptr, y_s)
|
||||||
@@ -140,7 +141,7 @@ def per_token_group_quant_fp8(
|
|||||||
x: The input tenosr with ndim >= 2.
|
x: The input tensor with ndim >= 2.
|
||||||
group_size: The group size used for quantization.
|
group_size: The group size used for quantization.
|
||||||
eps: The minimum to avoid dividing zero.
|
eps: The minimum to avoid dividing zero.
|
||||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
|
dtype: The dtype of output tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
|
||||||
@@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8(
|
|||||||
return x_q, x_s
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
def sglang_per_token_quant_fp8(
    x: torch.Tensor,
    dtype: torch.dtype = fp8_type_,
):
    """Quantize `x` to fp8 with the sgl-kernel per-token quantization op.

    Args:
        x: Contiguous input tensor. One scale is allocated per entry of its
            first dimension.
        dtype: Element type of the quantized output buffer.

    Returns:
        A `(quantized, scales)` tuple: `quantized` has the shape of `x` and
        dtype `dtype`; `scales` is a float32 tensor of shape
        `(x.shape[0], 1)`. Both buffers are handed to
        `sgl_per_token_quant_fp8`, which is expected to fill them in place.
    """
    assert x.is_contiguous(), "`x` is not contiguous"

    # Output buffers live on the same device as the input.
    target_device = x.device
    quantized = torch.empty_like(x, device=target_device, dtype=dtype)
    scales = torch.empty(
        (x.shape[0], 1), device=target_device, dtype=torch.float32
    )

    sgl_per_token_quant_fp8(x, quantized, scales)
    return quantized, scales
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _static_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    y_s_repeat_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
    REPEAT_SCALE: tl.constexpr,
):
    """Triton kernel that quantizes one row of a tensor with a given scale.

    Each program instance handles the row selected by `tl.program_id(0)`:
    it loads up to `BLOCK` columns (masked to `N`), multiplies them by the
    reciprocal of the single scale read from `y_s_ptr`, clamps the result
    to `[fp8_min, fp8_max]` and stores it with the element type of
    `y_q_ptr`. When `REPEAT_SCALE` is set, the scale is additionally written
    to this row's slot in `y_s_repeat_ptr`.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    # The same row stride is applied to both the input and the output.
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    if REPEAT_SCALE:
        # One repeated-scale slot per row.
        y_s_repeat_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    # Masked-out columns are read as 0.0 and never stored.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    y_s = tl.load(y_s_ptr).to(tl.float32)
    # Multiply by the reciprocal instead of dividing every element.
    y_s_inv = 1.0 / y_s
    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    if REPEAT_SCALE:
        tl.store(y_s_repeat_ptr, y_s)
|
||||||
|
|
||||||
|
|
||||||
|
def static_quant_fp8(
    x: torch.Tensor,
    x_s: torch.Tensor,
    repeat_scale: bool = False,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to float8 using an already-known per-tensor scale.

    Args:
        x: The contiguous input tensor. Its last dimension is treated as the
            columns of a row; all leading dimensions are flattened into rows.
        x_s: The quantization scale; must contain exactly one element.
        repeat_scale: Whether to also write the per-tensor scale once per row
            and return that `(rows, 1)` tensor instead of `x_s`.
        dtype: The dtype of the quantized output tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scale
        that goes with it (`x_s` itself, or its per-row repetition when
        `repeat_scale` is set).
    """
    assert x.is_contiguous(), "`x` is not contiguous"
    assert x_s.numel() == 1, "only supports per-tensor scale"

    # On ROCm the clamp bound is fixed at 224.0 instead of the dtype maximum.
    finfo_max = torch.finfo(dtype).max
    upper = 224.0 if is_hip_ else finfo_max
    lower = -upper

    quantized = torch.empty_like(x, device=x.device, dtype=dtype)
    num_cols = x.shape[-1]
    num_rows = x.numel() // num_cols

    repeated_scale = (
        torch.empty((num_rows, 1), device=x.device, dtype=torch.float32)
        if repeat_scale
        else None
    )

    # A single program covers a whole row, so the block must span all columns.
    block = triton.next_power_of_2(num_cols)
    # heuristics for number of warps
    warps = min(max(block // 256, 1), 8)

    # One program per row; the row stride equals the number of columns
    # because `x` is contiguous.
    _static_quant_fp8[(num_rows,)](
        x,
        quantized,
        x_s,
        repeated_scale,
        num_cols,
        num_cols,
        fp8_min=lower,
        fp8_max=upper,
        BLOCK=block,
        REPEAT_SCALE=repeat_scale,
        num_warps=warps,
        num_stages=1,
    )

    if repeat_scale:
        return quantized, repeated_scale
    return quantized, x_s
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _w8a8_block_fp8_matmul(
|
def _w8a8_block_fp8_matmul(
|
||||||
# Pointers to inputs and output
|
# Pointers to inputs and output
|
||||||
|
|||||||
@@ -2,13 +2,23 @@ import os
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
|
static_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
|
get_cuda_version,
|
||||||
|
get_device_capability,
|
||||||
|
is_hip,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Opt-in switch for the vLLM cutlass w8a8 fp8 kernel (the sgl-kernel
# implementation is used otherwise).
# `os.environ.get` returns the raw string, so any non-empty value (even "0"
# or "false") used to be truthy and enabled the vLLM path. Parse the variable
# as a boolean with the helper this module already uses for `CK_MOE`.
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var(
    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL"
)
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
is_hip_ = is_hip()
|
||||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||||
@@ -18,6 +28,25 @@ _is_cuda = torch.cuda.is_available() and torch.version.cuda
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
|
||||||
|
|
||||||
|
if use_vllm_cutlass_w8a8_fp8_kernel:
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
else:
|
||||||
|
from sgl_kernel import fp8_scaled_mm
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_fp8_supported():
    """Report whether the cutlass fp8 kernel path can be used here.

    Returns:
        False when not running on CUDA. Otherwise True only if the CUDA
        version meets the minimum for the device capability: at least
        (12, 0) for major >= 9, at least (12, 4) for capability (8, 9).
        Every other capability yields False.
    """
    if not _is_cuda:
        return False

    capability = get_device_capability()
    major, minor = capability
    toolkit = get_cuda_version()

    # Pick the minimum CUDA version required for this device generation.
    if major >= 9:
        required = (12, 0)
    elif (major, minor) == (8, 9):
        required = (12, 4)
    else:
        return False
    return toolkit >= required
|
||||||
|
|
||||||
|
|
||||||
def normalize_e4m3fn_to_e4m3fnuz(
|
def normalize_e4m3fn_to_e4m3fnuz(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
@@ -158,10 +187,121 @@ def block_quant_to_tensor_quant(
|
|||||||
return x_q_tensor, scale
|
return x_q_tensor, scale
|
||||||
|
|
||||||
|
|
||||||
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
    use_per_token_if_dynamic: bool = False,
) -> torch.Tensor:
    """Quantize `input` to fp8 and multiply it with an fp8 `weight`.

    Args:
        input: Activation tensor; it is viewed as 2D over its last dimension.
        weight: Quantized weight. `weight.shape[1]` is used as the output
            feature size.
        weight_scale: Scale(s) of `weight`. The sgl-kernel cutlass path
            requires one scale per output column.
        input_scale: Optional static activation scale with exactly one
            element. When None, the activation is quantized dynamically with
            one scale per row.
        input_scale_ub: Not referenced in this function body.
        bias: Optional bias added to the result.
        cutlass_fp8_supported: Selects the cutlass kernel path. It also
            decides whether a static `input_scale` is repeated per row.
            When False, `torch._scaled_mm` is used instead.
        use_per_token_if_dynamic: Not referenced in this function body.

    Returns:
        A tensor of shape `[*input.shape[:-1], weight.shape[1]]` with the
        dtype of `input`.
    """
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[1]]

    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
    if input_scale is not None:
        assert input_scale.numel() == 1
        # broadcast per-tensor scale to per-token scale when supporting cutlass
        qinput, x_scale = static_quant_fp8(
            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
        )
    else:
        # default use per-token quantization if dynamic
        if _is_cuda:
            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
        else:
            # One group spanning the whole row, i.e. one scale per row.
            qinput, x_scale = per_token_group_quant_fp8(
                input_2d, group_size=input_2d.shape[1]
            )

    if cutlass_fp8_supported:
        if use_vllm_cutlass_w8a8_fp8_kernel:
            # Fall back to vllm cutlass w8a8 fp8 kernel
            output = ops.cutlass_scaled_mm(
                qinput,
                weight,
                out_dtype=input.dtype,
                scale_a=x_scale,
                scale_b=weight_scale,
                bias=bias,
            )
        else:
            assert (
                weight_scale.numel() == weight.shape[1]
            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
            output = fp8_scaled_mm(
                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
            )
        return output.view(*output_shape)

    # torch.scaled_mm supports per tensor weights + activations only
    # so fallback to naive if per channel or per token
    else:
        per_tensor_weights = weight_scale.numel() == 1
        per_tensor_activations = x_scale.numel() == 1

        if per_tensor_weights and per_tensor_activations:
            # Fused GEMM_DQ
            output = torch._scaled_mm(
                qinput,
                weight,
                out_dtype=input.dtype,
                scale_a=x_scale,
                scale_b=weight_scale,
                bias=bias,
            )
            # A fix for discrepancy in scaled_mm which returns tuple
            # for torch < 2.5 and a single value in torch >= 2.5
            if type(output) is tuple and len(output) == 2:
                output = output[0]

            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)

        else:
            # Fallback for channelwise case, where we use unfused DQ
            # due to limitations with scaled_mm

            # Symmetric quantized GEMM by definition computes the following:
            #   C = (s_x * X) (s_w * W) + bias
            # This is equivalent to dequantizing the weights and activations
            # before applying a GEMM.
            #
            # In order to compute quantized operands, a quantized kernel
            # will rewrite the above like so:
            #   C = s_w * s_x * (X * W) + bias
            #
            # For the scaled_mm fallback case, we break this down, since it
            # does not support s_w being a vector.

            # Making sure the dummy tensor is on the same device as the weight
            # NOTE(review): TORCH_DEVICE_IDENTITY is not defined in the portion
            # of the module visible here; confirm that a module-level tensor
            # with this name exists before this branch can run.
            global TORCH_DEVICE_IDENTITY
            if TORCH_DEVICE_IDENTITY.device != weight.device:
                TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)

            # GEMM
            # This computes C = (X * W).
            # Output in fp32 to allow subsequent ops to happen in-place
            output = torch._scaled_mm(
                qinput,
                weight,
                scale_a=TORCH_DEVICE_IDENTITY,
                scale_b=TORCH_DEVICE_IDENTITY,
                out_dtype=torch.float32,
            )
            # A fix for discrepancy in scaled_mm which returns tuple
            # for torch < 2.5 and a single value in torch >= 2.5
            if type(output) is tuple and len(output) == 2:
                output = output[0]
            # Unpad (undo num_token_padding)
            output = torch.narrow(output, 0, 0, input_2d.shape[0])
            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])

            # DQ
            # C = sw * sx * (X * W) + bias
            output = output * x_scale * weight_scale.t()
            if bias is not None:
                output = output + bias
            return output.to(dtype=input.dtype).view(*output_shape)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
apply_fp8_linear,
|
convert_to_channelwise,
|
||||||
cutlass_fp8_supported,
|
cutlass_fp8_supported,
|
||||||
requantize_with_max_scale,
|
requantize_with_max_scale,
|
||||||
)
|
)
|
||||||
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
|
||||||
|
|
||||||
# Initialize logger for the module
|
# Initialize logger for the module
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
|
|||||||
layer.weight, layer.weight_scale, layer.logical_widths
|
layer.weight, layer.weight_scale, layer.logical_widths
|
||||||
)
|
)
|
||||||
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
|
layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
|
||||||
|
# cutlass sgl-kernel only supports per-channel scale
|
||||||
|
if self.cutlass_fp8_supported:
|
||||||
|
max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths)
|
||||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||||
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
|
||||||
|
|
||||||
|
|||||||
126
python/sglang/srt/layers/quantization/w8a8_fp8.py
Normal file
126
python/sglang/srt/layers/quantization/w8a8_fp8.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from sglang.srt.layers.linear import LinearMethodBase
|
||||||
|
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
||||||
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
|
QuantizationConfig,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_utils import (
|
||||||
|
apply_fp8_linear,
|
||||||
|
cutlass_fp8_supported,
|
||||||
|
normalize_e4m3fn_to_e4m3fnuz,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
|
|
||||||
|
class W8A8Fp8Config(QuantizationConfig):
    """Quantization config for the W8A8 FP8 scheme.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric

    The config carries no options; it only maps linear layers to
    `W8A8Fp8LinearMethod`.
    """

    def __init__(self):
        """Create the config; there is nothing to store."""

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this scheme."""
        half_precision_dtypes = [torch.float16, torch.bfloat16]
        return half_precision_dtypes

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this scheme."""
        # Presumably compute capability 8.9 encoded as major * 10 + minor;
        # confirm against the base-class contract.
        return 89

    @classmethod
    def get_name(cls) -> str:
        """Return the name this method is registered under."""
        return "w8a8_fp8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the config file names to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
        """Build the config; `config` is ignored because there are no options."""
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Return the linear method for `LinearBase` layers, else None."""
        # Imported inside the method, as in the original code.
        from sglang.srt.layers.linear import LinearBase

        if not isinstance(layer, LinearBase):
            return None
        return W8A8Fp8LinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
|
||||||
|
|
||||||
|
|
||||||
|
class W8A8Fp8LinearMethod(LinearMethodBase):
    """Linear method for fp8 weights with one scale per output row.

    `create_weights` registers an fp8 weight and a float32 scale column on
    the layer. `process_weights_after_loading` stores the weight transposed.
    `apply` runs the layer through `apply_fp8_linear` without a static input
    scale.
    """

    def __init__(self, quantization_config: W8A8Fp8Config):
        """Keep the config and record whether the cutlass path is usable."""
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-register the loaded weight (transposed) and its scale.

        On ROCm the weight and scale are first passed through
        `normalize_e4m3fn_to_e4m3fnuz`.
        """
        loaded_weight = layer.weight
        loaded_scale = layer.weight_scale.detach()
        if is_hip():
            # The third returned value is unused because no input scale is
            # passed in.
            loaded_weight, loaded_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=loaded_weight, weight_scale=loaded_scale
            )
        layer.weight = Parameter(loaded_weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(loaded_scale, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the `weight` and `weight_scale` parameters on `layer`.

        The weight has shape `(sum(output_partition_sizes),
        input_size_per_partition)` in `torch.float8_e4m3fn`; the scale is a
        float32 column with one entry per output row. `input_size`,
        `output_size` and `params_dtype` are not used here.
        """
        loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        # Total number of output rows across all logical partitions.
        total_output_rows = sum(output_partition_sizes)

        fp8_weight = ModelWeightParameter(
            data=torch.empty(
                total_output_rows,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", fp8_weight)

        channel_scale = ChannelQuantScaleParameter(
            data=torch.empty((total_output_rows, 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight_scale", channel_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """Run the quantized linear layer on `x` and return the result."""
        # No input scale is passed, so the activation is quantized dynamically.
        result = apply_fp8_linear(
            x,
            layer.weight,
            layer.weight_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
        return result
|
||||||
@@ -405,6 +405,7 @@ class ServerArgs:
|
|||||||
"gguf",
|
"gguf",
|
||||||
"modelopt",
|
"modelopt",
|
||||||
"w8a8_int8",
|
"w8a8_int8",
|
||||||
|
"w8a8_fp8",
|
||||||
],
|
],
|
||||||
help="The quantization method.",
|
help="The quantization method.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -52,11 +52,13 @@ import triton
|
|||||||
import zmq
|
import zmq
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
|
from packaging.version import Version, parse
|
||||||
from starlette.routing import Mount
|
from starlette.routing import Mount
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.func import functional_call
|
from torch.func import functional_call
|
||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
from torch.profiler import ProfilerActivity, profile, record_function
|
from torch.profiler import ProfilerActivity, profile, record_function
|
||||||
|
from torch.utils.cpp_extension import CUDA_HOME
|
||||||
from triton.runtime.cache import (
|
from triton.runtime.cache import (
|
||||||
FileCacheManager,
|
FileCacheManager,
|
||||||
default_cache_dir,
|
default_cache_dir,
|
||||||
@@ -1431,6 +1433,12 @@ def rank0_print(msg: str):
|
|||||||
print(msg, flush=True)
|
print(msg, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_version():
    """Return the CUDA version reported by `torch.version.cuda` as a tuple of ints.

    The dotted version string is split on "." and each component is
    converted with `int`. When `torch.version.cuda` is falsy (for example
    `None`), `(0, 0)` is returned instead.
    """
    cuda_version = torch.version.cuda
    if not cuda_version:
        return (0, 0)
    return tuple(int(component) for component in cuda_version.split("."))
|
||||||
|
|
||||||
|
|
||||||
def launch_dummy_health_check_server(host, port):
|
def launch_dummy_health_check_server(host, port):
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI, Response
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
|
static_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
|||||||
out, scale = per_token_group_quant_fp8(x, group_size)
|
out, scale = per_token_group_quant_fp8(x, group_size)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
|
torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(scale, ref_scale))
|
self.assertTrue(torch.allclose(scale, ref_scale))
|
||||||
|
|
||||||
@@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
|||||||
self._per_token_group_quant_fp8(*params)
|
self._per_token_group_quant_fp8(*params)
|
||||||
|
|
||||||
|
|
||||||
|
# For test
|
||||||
|
def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
|
||||||
|
"""Function to perform static quantization on an input tensor `x` using native torch.
|
||||||
|
|
||||||
|
It converts the tensor values into float8 values and returns the
|
||||||
|
quantized tensor along with the scaling factor used for quantization.
|
||||||
|
"""
|
||||||
|
assert x.is_contiguous(), "`x` is not contiguous"
|
||||||
|
assert x_s.numel() == 1, "only supports per-tensor scale"
|
||||||
|
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
fp8_min = finfo.min
|
||||||
|
fp8_max = finfo.max
|
||||||
|
|
||||||
|
x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
|
||||||
|
x_s_inv = 1.0 / x_s
|
||||||
|
x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||||
|
x_q = x_q.reshape(x.shape)
|
||||||
|
|
||||||
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
class TestStaticQuantFP8(unittest.TestCase):
    """Compare `static_quant_fp8` with the native torch reference quantizer."""

    # Parameter grid; every combination is run as its own sub-test.
    DTYPES = [torch.half, torch.bfloat16, torch.float32]
    NUM_TOKENS = [7, 83, 2048]
    D = [512, 4096, 5120, 13824]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        """Skip the class without CUDA; otherwise make "cuda" the default device."""
        if torch.cuda.is_available():
            torch.set_default_device("cuda")
            return
        raise unittest.SkipTest("CUDA is not available")

    def _static_quant_fp8(self, num_tokens, d, dtype, seed):
        """Quantize one random `(num_tokens, d)` tensor both ways and compare."""
        torch.manual_seed(seed)

        x = torch.rand(num_tokens, d, dtype=dtype)
        # Per-tensor scale that maps the largest input value onto the fp8 maximum.
        fp8_limit = torch.finfo(torch.float8_e4m3fn).max
        x_s = x.max() / fp8_limit

        with torch.inference_mode():
            expected, _ = native_static_quant_fp8(x, x_s)
            actual, _ = static_quant_fp8(x, x_s, repeat_scale=True)

        matches = torch.allclose(
            actual.to(torch.float32), expected.to(torch.float32), rtol=0.50
        )
        self.assertTrue(matches)

    def test_static_quant_fp8(self):
        """Run the comparison over the full parameter grid."""
        grid = itertools.product(
            self.NUM_TOKENS,
            self.D,
            self.DTYPES,
            self.SEEDS,
        )
        for num_tokens, d, dtype, seed in grid:
            with self.subTest(num_tokens=num_tokens, d=d, dtype=dtype, seed=seed):
                self._static_quant_fp8(num_tokens, d, dtype, seed)
|
||||||
|
|
||||||
|
|
||||||
# For test
|
# For test
|
||||||
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
|
||||||
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
"""This function performs matrix multiplication with block-wise quantization using native torch.
|
||||||
|
|||||||
Reference in New Issue
Block a user