Fix ci test "test_eval_fp8_accuracy" failed (#5185)
Co-authored-by: wunhuang <wunhuang@amd.com>
This commit is contained in:
@@ -243,9 +243,19 @@ def apply_fp8_linear(
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
|
qinput, x_scale = sglang_per_token_quant_fp8(input_2d)
|
||||||
else:
|
else:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(
|
# TODO(kkhuang): temporarily enforce per-tensor activation scaling if weight is per-tensor scaling
|
||||||
input_2d, input_scale, use_per_token_if_dynamic=use_per_token_if_dynamic
|
# final solution should be: 1. add support to per-tensor activation scaling.
|
||||||
)
|
# 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
|
||||||
|
if _is_hip and weight_scale.numel() == 1:
|
||||||
|
qinput, x_scale = ops.scaled_fp8_quant(
|
||||||
|
input_2d,
|
||||||
|
input_scale,
|
||||||
|
use_per_token_if_dynamic=use_per_token_if_dynamic,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qinput, x_scale = per_token_group_quant_fp8(
|
||||||
|
input_2d, group_size=input_2d.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
if cutlass_fp8_supported:
|
if cutlass_fp8_supported:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user