Fix block wise fp8 torch compile (#3232)
This commit is contained in:
@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
weight_scale, requires_grad=False
|
||||
)
|
||||
layer.input_scale = None
|
||||
else:
|
||||
layer.weight = torch.nn.Parameter(
|
||||
layer.weight.data, requires_grad=False
|
||||
)
|
||||
layer.weight_scale_inv = torch.nn.Parameter(
|
||||
layer.weight_scale_inv.data, requires_grad=False
|
||||
)
|
||||
return
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
|
||||
Reference in New Issue
Block a user