AMD: set weights and scaling numbers properly for block FP8 (#2637)
@@ -272,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, requires_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
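
For context, this is roughly what a normalize_e4m3fn_to_e4m3fnuz helper does. The sketch below is modeled on the vLLM-style utility of the same name and is an illustration, not necessarily the exact implementation in this repo: the same bit pattern encodes half the value in float8_e4m3fnuz (exponent bias 8) compared with float8_e4m3fn (bias 7), and the 0x80 pattern means -0.0 in e4m3fn but NaN in e4m3fnuz, so the helper remaps that pattern and doubles the scales to keep dequantized values unchanged.

from typing import Optional, Tuple
import torch

def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 (-128 as int8) is -0.0 in e4m3fn but NaN in e4m3fnuz: zero it out.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)
    # Same bits encode half the value under the fnuz exponent bias, so the
    # scales are doubled to keep weight * scale unchanged after dequantization.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale
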
@@ -369,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=layer.input_scale,
+                input_scale=None,
                 bias=bias,
             )
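
The change to input_scale=None above lines up with the dynamic activation scheme used for block FP8: no static input scale is stored, and activation scales are computed per token at run time. A minimal sketch of that idea (hypothetical helper, not code from this commit):

import torch

def dynamic_per_token_fp8_quant(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz):
    # Scale each token (row) so its largest magnitude maps to the FP8 max value.
    fp8_max = torch.finfo(fp8_dtype).max  # 240.0 for e4m3fnuz, 448.0 for e4m3fn
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max
    x_q = (x / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q, scale  # dequantize later as x_q.to(x.dtype) * scale
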
@@ -553,6 +566,30 @@ class Fp8MoEMethod:
 
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
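
As a quick standalone sanity check of the invariant behind both hunks (my own sketch, assuming a PyTorch build with float8 dtypes; not part of the commit): reinterpreting e4m3fn bits as e4m3fnuz halves every representable value, so doubling the scale leaves the dequantized tensor unchanged.

import torch

w_fp8 = (torch.randn(4, 4) * 0.1).to(torch.float8_e4m3fn)
scale = torch.tensor(0.5)
before = w_fp8.to(torch.float32) * scale

as_int8 = w_fp8.view(torch.int8).clone()
as_int8[as_int8 == -128] = 0  # remap e4m3fn -0.0, which would be NaN in e4m3fnuz
w_fnuz = as_int8.view(torch.float8_e4m3fnuz)
after = w_fnuz.to(torch.float32) * (scale * 2.0)

assert torch.equal(before, after)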