From c02e3139149d5f0c318a3b292d389a58f172b6ba Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Fri, 31 Jan 2025 19:56:02 +0800 Subject: [PATCH] Fix block wise fp8 torch compile (#3232) --- python/sglang/srt/layers/quantization/fp8.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b0b5b8952..f5a0005a2 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase): weight_scale, requires_grad=False ) layer.input_scale = None + else: + layer.weight = torch.nn.Parameter( + layer.weight.data, requires_grad=False + ) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights.