[ROCm] fix dtype (#4510)

This commit is contained in:
yiakwy-xpu-ml-framework-team
2025-03-17 20:20:50 +08:00
committed by GitHub
parent 5493c3343e
commit 5f9b2c62ff

View File

@@ -108,10 +108,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                 layer.weight, layer.weight.shape[-1]
             )
             weight_scale = weight_scale.t().contiguous()
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
         else:
             # if cutlass not supported, we fall back to use torch._scaled_mm
             # which requires per tensor quantization on weight
-            qweight, weight_scale = input_to_float8(layer.weight)
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
         # Update the layer with the new values.
         layer.weight = Parameter(qweight.t(), requires_grad=False)