diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 240d86927..8ddbef82e 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, cutlass_fp8_supported, + input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.utils import is_hip @@ -22,12 +24,24 @@ _is_hip = is_hip() class W8A8Fp8Config(QuantizationConfig): """Config class for W8A8 FP8 Quantization. - - Weight: static, per-channel, symmetric - - Activation: dynamic, per-token, symmetric + Weight Quantization: + - Method: Static quantization + - Granularity: Per-channel + - Type: Symmetric + + Activation Quantization: + - Method: Dynamic quantization + - Granularity: Per-token + - Type: Symmetric + + Note: + - For models without offline quantization, weights will be quantized during model loading + - If CUTLASS is supported: Per-channel weight quantization is used + - If CUTLASS is not supported: Falls back to per-tensor weight quantization """ - def __init__(self): - pass + def __init__(self, is_checkpoint_fp8_serialized: bool = False): + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: @@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig): @classmethod def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config": - return cls() + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized) def get_quant_method( self, @@ -72,13 +88,35 @@ class W8A8Fp8LinearMethod(LinearMethodBase): 
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight = layer.weight - weight_scale = layer.weight_scale.detach() - if _is_hip: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, weight_scale=weight_scale - ) - layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + if self.quantization_config.is_checkpoint_fp8_serialized: + weight_scale = layer.weight_scale.detach() + # If the checkpoint was quantized offline with w8a8_fp8, load the weight and weight_scale directly. + if _is_hip: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, weight_scale=weight_scale + ) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + else: + # If the checkpoint was not quantized offline, quantize the weights here (per-channel, or per-tensor as a fallback). + if self.cutlass_fp8_supported: + # if cutlass is supported, we use cutlass_scaled_mm + # which requires per-channel quantization on weight + qweight, weight_scale = per_token_group_quant_fp8( + layer.weight, layer.weight.shape[-1] + ) + weight_scale = weight_scale.t().contiguous() + else: + # if cutlass is not supported, we fall back to torch._scaled_mm + # which requires per-tensor quantization on weight + qweight, weight_scale = input_to_float8(layer.weight) + + # Update the layer with the new values. 
+ layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None def create_weights( self, @@ -90,6 +128,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs ): + weight_dtype = ( + torch.float8_e4m3fn + if self.quantization_config.is_checkpoint_fp8_serialized + else params_dtype + ) weight_loader = extra_weight_attrs.get("weight_loader") self.logical_widths = output_partition_sizes @@ -98,7 +141,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase): data=torch.empty( sum(output_partition_sizes), input_size_per_partition, - dtype=torch.float8_e4m3fn, + dtype=weight_dtype, ), input_dim=1, output_dim=0, @@ -106,12 +149,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase): ) layer.register_parameter("weight", weight) - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) + if self.quantization_config.is_checkpoint_fp8_serialized: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + else: + layer.weight_scale = None def apply( self, diff --git a/test/srt/test_eval_fp8_accuracy.py b/test/srt/test_eval_fp8_accuracy.py index 07eb4dc04..25aa8a50d 100644 --- a/test/srt/test_eval_fp8_accuracy.py +++ b/test/srt/test_eval_fp8_accuracy.py @@ -6,6 +6,7 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST, DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST, + DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -40,33 +41,68 @@ class TestEvalFP8Accuracy(unittest.TestCase): class 
TestEvalFP8DynamicQuantAccuracy(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, + + def _run_test(self, model, other_args, expected_score): + base_url = DEFAULT_URL_FOR_TEST + other_args = other_args or [] + + process = popen_launch_server( + model, + base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], expected_score) + finally: + kill_process_tree(process.pid) + + def test_mmlu_offline_only(self): + """Test with offline quantization only.""" + self._run_test( + model=DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST, + other_args=[], + expected_score=0.64, + ) + + def test_mmlu_offline_and_online_override(self): + """Test with both offline and online quantization.""" + self._run_test( + model=DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST, other_args=["--quantization", "w8a8_fp8"], + # inference will use sgl kernel w/ online quant override + # we observed that the accuracy is higher than offline only + expected_score=0.64, ) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, + def test_mmlu_online_only(self): + """Test with online quantization only.""" + self._run_test( + model=DEFAULT_MODEL_NAME_FOR_TEST, + # inference will use sgl kernel w/ online quantization only + # we observed that the accuracy is higher than offline only + other_args=["--quantization", "w8a8_fp8"], + expected_score=0.64, ) - metrics = run_eval(args) - 
self.assertGreaterEqual(metrics["score"], 0.70) + def test_mmlu_fp16_baseline(self): + """Test with unquantized fp16 baseline.""" + self._run_test( + model=DEFAULT_MODEL_NAME_FOR_TEST, + other_args=[], + expected_score=0.64, + ) if __name__ == "__main__":