Tiny cleanup fp4 gemm calls (#11537)

This commit is contained in:
fzyzcjy
2025-10-14 05:48:22 +08:00
committed by GitHub
parent 8e51049f56
commit 065ce81574

View File

@@ -852,25 +852,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
if enable_flashinfer_fp4_gemm: if enable_flashinfer_fp4_gemm:
w = layer.weight.T w = layer.weight.T
w_scale_interleaved = layer.weight_scale_interleaved.T w_scale_interleaved = layer.weight_scale_interleaved.T
if USE_CUTLASS_BACKEND_FOR_FP4_GEMM: out = fp4_gemm(
out = fp4_gemm( x_fp4,
x_fp4, w,
w, x_scale_interleaved,
x_scale_interleaved, w_scale_interleaved,
w_scale_interleaved, layer.alpha,
layer.alpha, output_dtype,
output_dtype, **(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
backend="cutlass", )
)
else:
out = fp4_gemm(
x_fp4,
w,
x_scale_interleaved,
w_scale_interleaved,
layer.alpha,
output_dtype,
)
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out.view(*output_shape) return out.view(*output_shape)