Fix block wise fp8 torch compile (#3232)

This commit is contained in:
Ke Bao
2025-01-31 19:56:02 +08:00
committed by GitHub
parent 734daedd8f
commit c02e313914

View File

@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale, requires_grad=False
)
layer.input_scale = None
else:
layer.weight = torch.nn.Parameter(
layer.weight.data, requires_grad=False
)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
)
return
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.