[ROCm] fix dtype (#4510)
This commit is contained in:
committed by
GitHub
parent
5493c3343e
commit
5f9b2c62ff
@@ -108,10 +108,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight, layer.weight.shape[-1]
|
||||
)
|
||||
weight_scale = weight_scale.t().contiguous()
|
||||
if _is_hip:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight, weight_scale=weight_scale
|
||||
)
|
||||
else:
|
||||
# if cutlass not supported, we fall back to use torch._scaled_mm
|
||||
# which requires per tensor quantization on weight
|
||||
qweight, weight_scale = input_to_float8(layer.weight)
|
||||
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||
qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
|
||||
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user