diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
index fa7d77f28..af4f1a0e0 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)

     def create_weights(
         self,