[ROCm] fix dtype (#4510)
This commit is contained in:
committed by
GitHub
parent
5493c3343e
commit
5f9b2c62ff
@@ -108,10 +108,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         layer.weight, layer.weight.shape[-1]
     )
     weight_scale = weight_scale.t().contiguous()
+    if _is_hip:
+        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+            weight=weight, weight_scale=weight_scale
+        )
 else:
     # if cutlass not supported, we fall back to use torch._scaled_mm
     # which requires per tensor quantization on weight
-    qweight, weight_scale = input_to_float8(layer.weight)
+    fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
 # Update the layer with the new values.
 layer.weight = Parameter(qweight.t(), requires_grad=False)
Reference in New Issue
Block a user