Fix run time error in ROCm platform (#5147)
Co-authored-by: wunhuang <wunhuang@amd.com> Co-authored-by: root <root@dell300x-pla-t10-17.pla.dcgpu>
This commit is contained in:
@@ -171,6 +171,7 @@ def input_to_float8(
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
fp8_max = finfo.max
|
||||
if _is_hip:
|
||||
dtype = torch.float8_e4m3fnuz
|
||||
fp8_max = 224.0
|
||||
scale = fp8_max / amax
|
||||
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
|
||||
|
||||
Reference in New Issue
Block a user