Fix block wise fp8 torch compile (#3232)
This commit is contained in:
@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
weight_scale, requires_grad=False
|
weight_scale, requires_grad=False
|
||||||
)
|
)
|
||||||
layer.input_scale = None
|
layer.input_scale = None
|
||||||
|
else:
|
||||||
|
layer.weight = torch.nn.Parameter(
|
||||||
|
layer.weight.data, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.weight_scale_inv = torch.nn.Parameter(
|
||||||
|
layer.weight_scale_inv.data, requires_grad=False
|
||||||
|
)
|
||||||
return
|
return
|
||||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
|
|||||||
Reference in New Issue
Block a user