diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 8e4f84714..ed13195c5 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -16,7 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-    Fp8LinearOp,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -29,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -149,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 7377ab73b..33519a49c 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -1,3 +1,4 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
@@ -5,7 +6,7 @@ import torch
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 
 try:
-    from vllm import _custom_ops as vllm_ops
+    from vllm import _custom_ops as ops
 
     VLLM_AVAILABLE = True
 except ImportError:
@@ -234,6 +235,43 @@ def channel_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def _process_scaled_mm_output(output, input_2d_shape, output_shape):
+    if type(output) is tuple and len(output) == 2:
+        output = output[0]
+    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)
+
+
+def _apply_fallback_scaled_mm(
+    qinput,
+    weight,
+    x_scale,
+    weight_scale,
+    input_2d_shape,
+    output_shape,
+    bias,
+    input_dtype,
+):
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32, device=weight.device)
+
+    output = torch._scaled_mm(
+        qinput,
+        weight,
+        scale_a=TORCH_DEVICE_IDENTITY,
+        scale_b=TORCH_DEVICE_IDENTITY,
+        out_dtype=torch.float32,
+    )
+
+    output = _process_scaled_mm_output(output, input_2d_shape, output_shape)
+    x_scale = torch.narrow(x_scale, 0, 0, input_2d_shape[0])
+
+    output = output * x_scale * weight_scale.t()
+    if bias is not None:
+        output = output + bias
+    return output.to(dtype=input_dtype)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -241,211 +279,38 @@ def apply_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
-    cutlass_fp8_supported: bool = True,
+    cutlass_fp8_supported: bool = cutlass_fp8_supported(),
     use_per_token_if_dynamic: bool = False,
+    pad_output: Optional[bool] = None,
+    compressed_tensor_quant: bool = False,
 ) -> torch.Tensor:
+    # Note: we pad the input because torch._scaled_mm is more performant
+    # for matrices with batch dimension > 16.
+    # This could change in the future.
+    # We also don't pad when using torch.compile,
+    # as it breaks with dynamic shapes.
+    if pad_output is None:
+        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+    output_padding = 17 if pad_output else None
+
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[1]]
 
-    # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
-    if input_scale is not None:
-        assert input_scale.numel() == 1
-        # broadcast per-tensor scale to per-token scale when supporting cutlass
-        qinput, x_scale = static_quant_fp8(
-            input_2d, input_scale, repeat_scale=cutlass_fp8_supported
-        )
-    else:
-        # default use per-token quantization if dynamic
-        if _is_cuda:
-            qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
-        else:
-            # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
-            # final solution should be: 1. add support to per-tensor activation scaling.
-            # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
-            if _is_hip and weight_scale.numel() == 1:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = per_token_group_quant_fp8(
-                    input_2d, group_size=input_2d.shape[1]
-                )
-
-    if cutlass_fp8_supported:
-        if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vllm cutlass w8a8 fp8 kernel
-            output = vllm_ops.cutlass_scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-        else:
-            assert (
-                weight_scale.numel() == weight.shape[1]
-            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-            output = fp8_scaled_mm(
-                qinput,
-                weight,
-                x_scale,
-                weight_scale,
-                out_dtype=input.dtype,
-                bias=bias,
-            )
-        return output.view(*output_shape)
-
-    # torch.scaled_mm supports per tensor weights + activations only
-    # so fallback to naive if per channel or per token
-    else:
-        per_tensor_weights = weight_scale.numel() == 1
-        per_tensor_activations = x_scale.numel() == 1
-
-        if per_tensor_weights and per_tensor_activations:
-            # Fused GEMM_DQ
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-
-            return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
-
-        else:
-            # Fallback for channelwise case, where we use unfused DQ
-            # due to limitations with scaled_mm
-
-            # Symmetric quantized GEMM by definition computes the following:
-            #   C = (s_x * X) (s_w * W) + bias
-            # This is equivalent to dequantizing the weights and activations
-            # before applying a GEMM.
-            #
-            # In order to compute quantized operands, a quantized kernel
-            # will rewrite the above like so:
-            #   C = s_w * s_x * (X * W) + bias
-            #
-            # For the scaled_mm fallback case, we break this down, since it
-            # does not support s_w being a vector.
-
-            # Making sure the dummy tensor is on the same device as the weight
-            global TORCH_DEVICE_IDENTITY
-            if TORCH_DEVICE_IDENTITY is None:
-                TORCH_DEVICE_IDENTITY = torch.ones(
-                    1, dtype=torch.float32, device=weight.device
-                )
-
-            # GEMM
-            # This computes C = (X * W).
-            # Output in fp32 to allow subsequent ops to happen in-place
-            output = torch._scaled_mm(
-                qinput,
-                weight,
-                scale_a=TORCH_DEVICE_IDENTITY,
-                scale_b=TORCH_DEVICE_IDENTITY,
-                out_dtype=torch.float32,
-            )
-            # A fix for discrepancy in scaled_mm which returns tuple
-            # for torch < 2.5 and a single value in torch >= 2.5
-            if type(output) is tuple and len(output) == 2:
-                output = output[0]
-            # Unpad (undo num_token_padding)
-            output = torch.narrow(output, 0, 0, input_2d.shape[0])
-            x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
-
-            # DQ
-            # C = sw * sx * (X * W) + bias
-            output = output * x_scale * weight_scale.t()
-            if bias is not None:
-                output = output + bias
-            return output.to(dtype=input.dtype).view(*output_shape)
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
-# TODO(luka): follow similar pattern for marlin and block-fp8-linear
-# https://github.com/vllm-project/vllm/issues/14397
-class Fp8LinearOp:
-    """
-    This class executes a FP8 linear layer using cutlass if supported and
-    torch.scaled_mm otherwise.
-    It needs to be a class instead of a method so that config can be read
-    in the __init__ method, as reading config is not allowed inside forward.
-    """
-
-    def __init__(
-        self,
-        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
-        use_per_token_if_dynamic: bool = False,
-        pad_output: Optional[bool] = None,
-    ):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-
-        # Note: we pad the input because torch._scaled_mm is more performant
-        # for matrices with batch dimension > 16.
-        # This could change in the future.
-        # We also don't pad when using torch.compile,
-        # as it breaks with dynamic shapes.
-        if pad_output is None:
-            enable_torch_compile = get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
-            pad_output = not enable_torch_compile
-        self.output_padding = 17 if pad_output else None
-
-    def apply(
-        self,
-        input: torch.Tensor,
-        weight: torch.Tensor,
-        weight_scale: torch.Tensor,
-        input_scale: Optional[torch.Tensor] = None,
-        input_scale_ub: Optional[torch.Tensor] = None,
-        bias: Optional[torch.Tensor] = None,
-        # TODO(luka) remove this parameter in favor of __init__
-        use_per_token_if_dynamic: Optional[bool] = None,
-    ) -> torch.Tensor:
-        # ops.scaled_fp8_quant supports both dynamic and static quant.
-        # If dynamic, layer.input_scale is None and x_scale computed from x.
-        # If static, layer.input_scale is scalar and x_scale is input_scale.
-
-        # View input as 2D matrix for fp8 methods
-        input_2d = input.view(-1, input.shape[-1])
-        output_shape = [*input.shape[:-1], weight.shape[1]]
-
-        # TODO(luka) this is here because currently MLA only decides this
-        # during the forward method instead of in __init__.
-        if use_per_token_if_dynamic is None:
-            use_per_token_if_dynamic = self.use_per_token_if_dynamic
-
+    if compressed_tensor_quant:
         # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
         # for sgl-kernel fp8_scaled_mm, it supports per-channel W now
-        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
-            if _is_cuda:
-                qinput, x_scale = scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
-            else:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
-                    input_2d,
-                    input_scale,
-                    scale_ub=input_scale_ub,
-                    use_per_token_if_dynamic=use_per_token_if_dynamic,
-                )
+        if cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            qinput, x_scale = scaled_fp8_quant(
+                input_2d,
+                input_scale,
+                use_per_token_if_dynamic=use_per_token_if_dynamic,
+            )
 
             # Fused GEMM_DQ
             if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
                 # Fall back to vllm cutlass w8a8 fp8 kernel
-                output = vllm_ops.cutlass_scaled_mm(
+                output = ops.cutlass_scaled_mm(
                     qinput,
                     weight,
                     out_dtype=input.dtype,
@@ -471,20 +336,21 @@ class Fp8LinearOp:
         # so fallback to naive if per channel or per token
         else:
            # Maybe apply padding to output, see the padding note at the top of this function
-            if _is_cuda:
-                qinput, x_scale = scaled_fp8_quant(
+            qinput, x_scale = (
+                scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
-            else:
-                qinput, x_scale = vllm_ops.scaled_fp8_quant(
+                if _is_cuda
+                else ops.scaled_fp8_quant(
                     input_2d,
                     input_scale,
-                    num_token_padding=self.output_padding,
+                    num_token_padding=output_padding,
                     use_per_token_if_dynamic=use_per_token_if_dynamic,
                 )
+            )
 
             per_tensor_weights = weight_scale.numel() == 1
             per_tensor_activations = x_scale.numel() == 1
@@ -499,12 +365,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale,
                     bias=bias,
                 )
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-
-                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             elif (
                 use_per_token_if_dynamic
@@ -527,10 +388,7 @@ class Fp8LinearOp:
                     scale_b=weight_scale.t(),
                     bias=bias,
                 )
-
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                output = output.view(*output_shape)
-                return output
+                return _process_scaled_mm_output(output, input_2d.shape, output_shape)
 
             else:
                 # Fallback for channelwise case, where we use unfused DQ
@@ -547,36 +405,110 @@ class Fp8LinearOp:
                 #
                 # For the scaled_mm fallback case, we break this down, since it
                 # does not support s_w being a vector.
-
-                # GEMM
-                # This computes C = (X * W).
-                # Output in fp32 to allow subsequent ops to happen in-place
-
-                # Making sure the dummy tensor is on the same device as the weight
-                global TORCH_DEVICE_IDENTITY
-                if TORCH_DEVICE_IDENTITY is None:
-                    TORCH_DEVICE_IDENTITY = torch.ones(
-                        1, dtype=torch.float32, device=weight.device
-                    )
-
-                output = torch._scaled_mm(
+                return _apply_fallback_scaled_mm(
                     qinput,
                     weight,
-                    scale_a=TORCH_DEVICE_IDENTITY,
-                    scale_b=TORCH_DEVICE_IDENTITY,
-                    out_dtype=torch.float32,
+                    x_scale,
+                    weight_scale,
+                    input_2d.shape,
+                    output_shape,
+                    bias,
+                    input.dtype,
                 )
-                # A fix for discrepancy in scaled_mm which returns tuple
-                # for torch < 2.5 and a single value in torch >= 2.5
-                if type(output) is tuple and len(output) == 2:
-                    output = output[0]
-                # Unpad (undo num_token_padding)
-                output = torch.narrow(output, 0, 0, input_2d.shape[0])
-                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
+    else:
+        # cutlass w8a8 fp8 sgl-kernel only supports per-token scale
+        if input_scale is not None:
+            assert input_scale.numel() == 1
+            # broadcast per-tensor scale to per-token scale when supporting cutlass
+            qinput, x_scale = static_quant_fp8(
+                input_2d, input_scale, repeat_scale=cutlass_fp8_supported
+            )
+        else:
+            # default to per-token quantization if dynamic
+            if _is_cuda:
+                qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
+            else:
+                # TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
+                # final solution should be: 1. add support to per-tensor activation scaling.
+                # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
+                if _is_hip and weight_scale.numel() == 1:
+                    qinput, x_scale = ops.scaled_fp8_quant(
+                        input_2d,
+                        input_scale,
+                        use_per_token_if_dynamic=use_per_token_if_dynamic,
+                    )
+                else:
+                    qinput, x_scale = per_token_group_quant_fp8(
+                        input_2d, group_size=input_2d.shape[1]
+                    )
 
-                # DQ
-                # C = sw * sx * (X * W) + bias
-                output = output * x_scale * weight_scale.t()
-                if bias is not None:
-                    output = output + bias
-                return output.to(dtype=input.dtype).view(*output_shape)
+        if cutlass_fp8_supported:
+            try:
+                if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                    # Fall back to vllm cutlass w8a8 fp8 kernel
+                    output = ops.cutlass_scaled_mm(
+                        qinput,
+                        weight,
+                        out_dtype=input.dtype,
+                        scale_a=x_scale,
+                        scale_b=weight_scale,
+                        bias=bias,
+                    )
+                else:
+                    assert (
+                        weight_scale.numel() == weight.shape[1]
+                    ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                    output = fp8_scaled_mm(
+                        qinput,
+                        weight,
+                        x_scale,
+                        weight_scale,
+                        out_dtype=input.dtype,
+                        bias=bias,
+                    )
+                return output.view(*output_shape)
+            except (ImportError, NameError, AttributeError):
+                pass
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        per_tensor_weights = weight_scale.numel() == 1
+        per_tensor_activations = x_scale.numel() == 1
+
+        if per_tensor_weights and per_tensor_activations:
+            # Fused GEMM_DQ
+            output = torch._scaled_mm(
+                qinput,
+                weight,
+                out_dtype=input.dtype,
+                scale_a=x_scale,
+                scale_b=weight_scale,
+                bias=bias,
+            )
+            return _process_scaled_mm_output(output, input_2d.shape, output_shape)
+
+        else:
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
+
+            # Symmetric quantized GEMM by definition computes the following:
+            #   C = (s_x * X) (s_w * W) + bias
+            # This is equivalent to dequantizing the weights and activations
+            # before applying a GEMM.
+            #
+            # In order to compute quantized operands, a quantized kernel
+            # will rewrite the above like so:
+            #   C = s_w * s_x * (X * W) + bias
+            #
+            # For the scaled_mm fallback case, we break this down, since it
+            # does not support s_w being a vector.
+            return _apply_fallback_scaled_mm(
+                qinput,
+                weight,
+                x_scale,
+                weight_scale,
+                input_2d.shape,
+                output_shape,
+                bias,
+                input.dtype,
+            )
diff --git a/test/srt/models/test_compressed_tensors_models.py b/test/srt/models/test_compressed_tensors_models.py
new file mode 100644
index 000000000..b069008d0
--- /dev/null
+++ b/test/srt/models/test_compressed_tensors_models.py
@@ -0,0 +1,46 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestCompressedTensorsLlama3FP8(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "RedHatAI/Meta-Llama-3.1-8B-FP8"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.45)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 3c94b2ba3..3f7d846a5 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -20,6 +20,7 @@ suites = {
        TestFile("models/test_generation_models.py", 103),
        TestFile("models/test_grok_models.py", 60),
        TestFile("models/test_qwen_models.py", 82),
+       TestFile("models/test_compressed_tensors_models.py", 100),
        TestFile("models/test_reward_models.py", 83),
        TestFile("models/test_gme_qwen_models.py", 45),
        TestFile("models/test_clip_models.py", 100),
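
Two behaviors are folded into the new `_process_scaled_mm_output` helper: normalizing the return type of `torch._scaled_mm` (an `(output, amax)` tuple before torch 2.5, a single tensor from 2.5 onward) and trimming any rows added via `num_token_padding`. A minimal standalone sketch of the same logic, with made-up shapes for the demo:

```python
import torch

def process_scaled_mm_output(output, input_2d_shape, output_shape):
    # torch < 2.5 returned (out, amax); torch >= 2.5 returns just out.
    if type(output) is tuple and len(output) == 2:
        output = output[0]
    # Drop padded rows (undo num_token_padding) and restore the caller's shape.
    return torch.narrow(output, 0, 0, input_2d_shape[0]).view(*output_shape)

# 17 rows came back from a padded GEMM, but only 10 tokens are real
# (batch 2 x seq 5), so the result is trimmed and reshaped to [2, 5, 64].
padded = torch.randn(17, 64)
out = process_scaled_mm_output(padded, (10, 128), [2, 5, 64])
assert out.shape == (2, 5, 64)
```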
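
On the `pad_output` default: the input is padded to 17 rows so `torch._scaled_mm` hits its faster path for batch dimensions above 16, and padding is skipped when `SGLANG_ENABLE_TORCH_COMPILE` is set because it breaks dynamic-shape tracing. A rough sketch of that decision, reading the env var directly instead of through sglang's `get_bool_env_var` and using plain `torch.nn.functional.pad` in place of the quant kernel's `num_token_padding` argument:

```python
import os
import torch

# Hypothetical stand-in for get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE").
compile_enabled = os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "false").lower() in ("1", "true")
output_padding = None if compile_enabled else 17

num_tokens, hidden = 3, 64
qinput = torch.randn(num_tokens, hidden)
if output_padding is not None and num_tokens < output_padding:
    # Append zero rows up to 17; they are narrow()-ed away again after the GEMM.
    qinput = torch.nn.functional.pad(qinput, (0, 0, 0, output_padding - num_tokens))
assert qinput.shape[0] == 17
```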
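
For intuition about what `sglang_per_token_quant_fp8(input_2d)` hands back in the dynamic path: an FP8 tensor plus one scale per token (row), so `x_scale.numel() > 1` and the per-tensor `torch._scaled_mm` fast path cannot be used. The reference below is an assumption-level re-derivation of per-token e4m3 quantization, not the actual kernel; the name `per_token_quant_fp8_ref` is invented for the sketch:

```python
import torch

def per_token_quant_fp8_ref(x: torch.Tensor):
    # Illustrative only: one scale per row, mapping each row's absmax to the
    # e4m3 maximum (448). The real kernel fuses this on the GPU.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / finfo.max
    qx = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qx, scale.float()

x = torch.randn(8, 16)
qx, x_scale = per_token_quant_fp8_ref(x)
assert qx.dtype == torch.float8_e4m3fn and x_scale.shape == (8, 1)
```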
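
Finally, the algebra that `_apply_fallback_scaled_mm` relies on: per-token scales act as a row scaling of X and per-channel scales as a column scaling of W, so dequantize-then-GEMM equals GEMM-then-dequantize. That is why the fallback can run `torch._scaled_mm` with identity scales and apply `x_scale` and `weight_scale.t()` afterwards. A small fp32 check of the identity with stand-in tensors (no FP8 involved):

```python
import torch

M, K, N = 4, 8, 5
X = torch.randn(M, K)   # stand-in for the quantized activations
W = torch.randn(K, N)   # stand-in for the quantized weights
s_x = torch.rand(M, 1)  # per-token activation scales
s_w = torch.rand(N, 1)  # per-channel weight scales (shape mirrors weight_scale)
bias = torch.randn(N)

# Dequantize first, then GEMM: C = (s_x * X) @ (s_w * W) + bias
ref = (s_x * X) @ (W * s_w.t()) + bias
# GEMM first, then dequantize, as the fallback does: C = (X @ W) * s_x * s_w^T + bias
out = (X @ W) * s_x * s_w.t() + bias

assert torch.allclose(ref, out, atol=1e-5)
```

Running the GEMM with `out_dtype=torch.float32` keeps full precision for the scale multiply and bias add before the final cast back to the activation dtype.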