Fix run time error in ROCm platform (#5147)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: root <root@dell300x-pla-t10-17.pla.dcgpu>
This commit is contained in:
kk
2025-04-08 13:49:40 +08:00
committed by GitHub
parent 27f8e6b9c1
commit 5a144a8ab9
3 changed files with 23 additions and 3 deletions

View File

@@ -171,6 +171,7 @@ def input_to_float8(
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
fp8_max = finfo.max
if _is_hip:
dtype = torch.float8_e4m3fnuz
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)