diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 6f103bcc6..489cc6d4b 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -250,9 +250,11 @@ class ModelConfig: "compressed-tensors", "experts_int8", "w8a8_int8", + "w8a8_fp8", ] compatible_quantization_methods = { - "w8a8_int8": ["compressed-tensors", "compressed_tensors"] + "w8a8_int8": ["compressed-tensors", "compressed_tensors"], + "w8a8_fp8": ["compressed-tensors", "compressed_tensors"], } if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 919bcced3..85748fa74 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -18,6 +18,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.parameter import ( BasevLLMParameter, + BlockQuantScaleParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, @@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 78be67982..b3fc6b440 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -16,6 +16,7 @@ __all__ = [ "ModelWeightParameter", "ChannelQuantScaleParameter", "GroupQuantScaleParameter", + "BlockQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter", ] @@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter): pass +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. 
+ """ + + pass + + class PerTensorScaleParameter(BasevLLMParameter): """ Parameter class for scales where the number of scales is diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 1ef8f4381..c09fb5a1a 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config +from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -50,6 +51,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, "w8a8_int8": W8A8Int8Config, + "w8a8_fp8": W8A8Fp8Config, } diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index 1470ca427..ce526cd6a 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -13,12 +13,11 @@ from sglang.srt.layers.linear import ( LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear from sglang.srt.utils import set_weight_attrs diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index e296756b5..44a3cba8a 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -16,9 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, - apply_fp8_linear, convert_to_channelwise, - cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale, ) @@ -29,14 +27,21 @@ from sglang.srt.layers.linear import ( LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_utils import ( - BlockQuantScaleParameter, + apply_fp8_linear, apply_w8a8_block_fp8_linear, + cutlass_fp8_supported, + input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.utils import ( @@ -305,15 +310,15 @@ class Fp8LinearMethod(LinearMethodBase): layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. 
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
-                assert weight_scale.numel() == 1
-                weight_scale = convert_to_channelwise(
-                    weight_scale.expand(len(layer.logical_widths)), layer.logical_widths
+            if self.cutlass_fp8_supported or self.use_marlin:
+                # apply per-channel quantization by default, as cutlass sgl-kernel and marlin only support per-channel scale
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
                 )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # per-tensor quantization
+                qweight, weight_scale = input_to_float8(layer.weight)
 
             # Update the layer with the new values.
             layer.weight = Parameter(qweight.t(), requires_grad=False)
@@ -330,23 +335,19 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
-            # If using marlin (w8a16), kernel uses channelwise weights,
-            # so extend the weight scales to be channelwise.
-            if self.use_marlin:
+
+            # cutlass sgl-kernel and marlin only support per-channel scale
+            if self.cutlass_fp8_supported or self.use_marlin:
                 weight = layer.weight
                 weight_scale = convert_to_channelwise(
                     layer.weight_scale, layer.logical_widths
                 )
-
-            # If using w8a8, torch._scaled_mm needs per tensor, so
-            # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
                 weight = layer.weight
                 weight_scale = layer.weight_scale
 
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if is_hip_:
+                if is_hip():
                     weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
                         weight_scale=weight_scale,
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 47f310a24..54c07f909 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -29,7 +29,7 @@ fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_fp8
+    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
@@ -70,7 +70,8 @@ def _per_token_group_quant_fp8(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
@@ -140,7 +141,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+        dtype: The dtype of the output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
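For orientation (not part of the patch): the per-token-group scheme used by `_per_token_group_quant_fp8` and `per_token_group_quant_fp8` above computes one scale per `group_size`-wide slice of the last dimension, `s = max(|x|) / fp8_max`, then `q = clamp(x / s, fp8_min, fp8_max)`. A plain-PyTorch sketch of that math (the reference function name is hypothetical):

```python
import torch


def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    """Unoptimized reference for per-token-group FP8 quantization."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    assert x.shape[-1] % group_size == 0
    groups = x.reshape(-1, group_size).to(torch.float32)
    # One scale per group; clamp the absmax away from zero like the kernel's eps.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / finfo.max
    q = (groups / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], x.shape[-1] // group_size)
```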
@@ -241,6 +242,132 @@ def sglang_per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def sglang_per_token_quant_fp8(
+    x: torch.Tensor,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[0],
+        1,
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_quant_fp8(x, x_q, x_s)
+
+    return x_q, x_s
+
+
+@triton.jit
+def _static_quant_fp8(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    y_s_repeat_ptr,
+    # Stride of input
+    y_stride,
+    # Columns of input
+    N,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+    REPEAT_SCALE: tl.constexpr,
+):
+    """A Triton-accelerated function to perform quantization using the given scale on a
+    tensor
+
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * y_stride
+    y_q_ptr += g_id * y_stride
+    if REPEAT_SCALE:
+        y_s_repeat_ptr += g_id
+
+    cols = tl.arange(0, BLOCK)  # N <= BLOCK
+    mask = cols < N
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    y_s = tl.load(y_s_ptr).to(tl.float32)
+    y_s_inv = 1.0 / y_s
+    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    if REPEAT_SCALE:
+        tl.store(y_s_repeat_ptr, y_s)
+
+
+def static_quant_fp8(
+    x: torch.Tensor,
+    x_s: torch.Tensor,
+    repeat_scale: bool = False,
+    dtype: torch.dtype = fp8_type_,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Function to perform static quantization using the given scale on an input tensor `x`.
+
+    It converts the tensor values into signed float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+
+    Args:
+        x: The input tensor with ndim >= 2.
+        x_s: The quantization scale.
+        repeat_scale: Whether to broadcast per-tensor scale to per-channel scale.
+        dtype: The dtype of the output tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+ """ + assert x.is_contiguous(), "`x` is not contiguous" + assert x_s.numel() == 1, "only supports per-tensor scale" + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + if is_hip_: + fp8_max = 224.0 + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // x.shape[-1] + N = x.shape[-1] + if repeat_scale: + x_s_repeat = torch.empty( + (M, 1), + device=x.device, + dtype=torch.float32, + ) + else: + x_s_repeat = None + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _static_quant_fp8[(M,)]( + x, + x_q, + x_s, + x_s_repeat, + N, + N, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + REPEAT_SCALE=repeat_scale, + num_warps=num_warps, + num_stages=num_stages, + ) + x_s = x_s_repeat if repeat_scale else x_s + return x_q, x_s + + @triton.jit def _w8a8_block_fp8_matmul( # Pointers to inputs and output diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index ff10f0a56..feaae26f6 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -2,13 +2,23 @@ import os from typing import List, Optional, Tuple import torch +from packaging.version import Version -from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, + static_quant_fp8, w8a8_block_fp8_matmul, ) -from sglang.srt.utils import get_bool_env_var, is_hip +from sglang.srt.utils import ( + get_bool_env_var, + get_cuda_version, + get_device_capability, + is_hip, +) + +use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get( + "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False +) is_hip_ = is_hip() if is_hip_ and get_bool_env_var("CK_MOE"): @@ -18,6 +28,25 @@ _is_cuda = torch.cuda.is_available() and torch.version.cuda if _is_cuda: from sgl_kernel import fp8_blockwise_scaled_mm + from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8 + + if use_vllm_cutlass_w8a8_fp8_kernel: + from vllm import _custom_ops as ops + else: + from sgl_kernel import fp8_scaled_mm + + +def cutlass_fp8_supported(): + if not _is_cuda: + return False + major, minor = get_device_capability() + cuda_version = get_cuda_version() + if major >= 9: + return cuda_version >= (12, 0) + elif major == 8 and minor == 9: + return cuda_version >= (12, 4) + return False + def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, @@ -158,10 +187,121 @@ def block_quant_to_tensor_quant( return x_q_tensor, scale -class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): - """ - Parameter class for weight scales loaded for weights with - block-wise quantization. Uses both column and row parallelism. 
- """ +def apply_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_fp8_supported: bool = True, + use_per_token_if_dynamic: bool = False, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] - pass + # cutlass w8a8 fp8 sgl-kernel only supports per-token scale + if input_scale is not None: + assert input_scale.numel() == 1 + # broadcast per-tensor scale to per-token scale when supporting cutlass + qinput, x_scale = static_quant_fp8( + input_2d, input_scale, repeat_scale=cutlass_fp8_supported + ) + else: + # default use per-token quantization if dynamic + if _is_cuda: + qinput, x_scale = sglang_per_token_quant_fp8(input_2d) + else: + qinput, x_scale = per_token_group_quant_fp8( + input_2d, group_size=input_2d.shape[1] + ) + + if cutlass_fp8_supported: + if use_vllm_cutlass_w8a8_fp8_kernel: + # Fall back to vllm cutlass w8a8 fp8 kernel + output = ops.cutlass_scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + else: + assert ( + weight_scale.numel() == weight.shape[1] + ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" + output = fp8_scaled_mm( + qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias + ) + return output.view(*output_shape) + + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + else: + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + + # GEMM + # This computes C = (X * W). 
+ # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index a28e0aeea..c26012da2 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -7,7 +7,7 @@ import torch from torch.nn.parameter import Parameter from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, + convert_to_channelwise, cutlass_fp8_supported, requantize_with_max_scale, ) @@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear # Initialize logger for the module logger = logging.getLogger(__name__) @@ -161,6 +162,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase): layer.weight, layer.weight_scale, layer.logical_widths ) layer.weight = Parameter(quantized_weight.t(), requires_grad=False) + # cutlass sgl-kernel only supports per-channel scale + if self.cutlass_fp8_supported: + max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py new file mode 100644 index 000000000..0adedc68f --- /dev/null +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -0,0 +1,126 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8_utils import ( + apply_fp8_linear, + cutlass_fp8_supported, + normalize_e4m3fn_to_e4m3fnuz, +) +from sglang.srt.utils import is_hip + + +class W8A8Fp8Config(QuantizationConfig): + """Config class for W8A8 FP8 Quantization. 
+ + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_name(self) -> str: + return "w8a8_fp8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return W8A8Fp8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8A8Fp8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: W8A8Fp8Config): + self.cutlass_fp8_supported = cutlass_fp8_supported() + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = layer.weight + weight_scale = layer.weight_scale.detach() + if is_hip(): + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + return apply_fp8_linear( + x, + layer.weight, + layer.weight_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c5b8b920e..4e6fbdd49 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -405,6 +405,7 @@ class ServerArgs: "gguf", "modelopt", "w8a8_int8", + "w8a8_fp8", ], help="The quantization method.", ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1ce2862f9..8bfdbc0ed 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -52,11 +52,13 @@ import triton import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from packaging.version import Version, parse from starlette.routing import Mount from torch import nn from torch.func import functional_call from torch.library import Library from torch.profiler import ProfilerActivity, profile, record_function +from torch.utils.cpp_extension import CUDA_HOME from triton.runtime.cache import ( FileCacheManager, default_cache_dir, @@ -1431,6 +1433,12 @@ 
def rank0_print(msg: str): print(msg, flush=True) +def get_cuda_version(): + if torch.version.cuda: + return tuple(map(int, torch.version.cuda.split("."))) + return (0, 0) + + def launch_dummy_health_check_server(host, port): import uvicorn from fastapi import FastAPI, Response diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index 3a02531e6..b3da7690c 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -7,6 +7,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, + static_quant_fp8, w8a8_block_fp8_matmul, ) @@ -63,7 +64,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase): out, scale = per_token_group_quant_fp8(x, group_size) self.assertTrue( - torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15) + torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20) ) self.assertTrue(torch.allclose(scale, ref_scale)) @@ -85,6 +86,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase): self._per_token_group_quant_fp8(*params) +# For test +def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn): + """Function to perform static quantization on an input tensor `x` using native torch. + + It converts the tensor values into float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + """ + assert x.is_contiguous(), "`x` is not contiguous" + assert x_s.numel() == 1, "only supports per-tensor scale" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1]) + x_s_inv = 1.0 / x_s + x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype) + x_q = x_q.reshape(x.shape) + + return x_q, x_s + + +class TestStaticQuantFP8(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16, torch.float32] + NUM_TOKENS = [7, 83, 2048] + D = [512, 4096, 5120, 13824] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _static_quant_fp8(self, num_tokens, d, dtype, seed): + torch.manual_seed(seed) + + x = torch.rand(num_tokens, d, dtype=dtype) + fp8_max = torch.finfo(torch.float8_e4m3fn).max + x_s = x.max() / fp8_max + + with torch.inference_mode(): + ref_out, _ = native_static_quant_fp8(x, x_s) + out, _ = static_quant_fp8(x, x_s, repeat_scale=True) + + self.assertTrue( + torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50) + ) + + def test_static_quant_fp8(self): + for params in itertools.product( + self.NUM_TOKENS, + self.D, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + d=params[1], + dtype=params[2], + seed=params[3], + ): + self._static_quant_fp8(*params) + + # For test def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): """This function performs matrix multiplication with block-wise quantization using native torch.
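As a usage sketch for the new `static_quant_fp8` helper exercised by `TestStaticQuantFP8` above (assumes a CUDA device, triton/sgl-kernel available, and a PyTorch build with float8 support):

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import static_quant_fp8

x = torch.rand(16, 4096, dtype=torch.float16, device="cuda")
# Per-tensor scale derived from the input's maximum, as in the unit test above.
x_s = (x.max() / torch.finfo(torch.float8_e4m3fn).max).to(torch.float32)
x_q, scale = static_quant_fp8(x, x_s, repeat_scale=True)
# repeat_scale=True broadcasts the single scale to one scale per row (per token).
assert x_q.shape == x.shape and scale.shape == (16, 1)
```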
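End to end, the new method is selected the same way as `w8a8_int8`. A minimal sketch, assuming an FP8 compressed-tensors checkpoint at a placeholder path and that the offline `sgl.Engine` forwards keyword arguments to `ServerArgs` (the CLI equivalent is `--quantization w8a8_fp8`):

```python
import sglang as sgl

llm = sgl.Engine(
    model_path="path/to/fp8-compressed-tensors-checkpoint",  # placeholder path
    quantization="w8a8_fp8",
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
llm.shutdown()
```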