Tiny cleanup fp4 gemm calls (#11537)
This commit is contained in:
@@ -852,25 +852,15 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
|
||||
if enable_flashinfer_fp4_gemm:
|
||||
w = layer.weight.T
|
||||
w_scale_interleaved = layer.weight_scale_interleaved.T
|
||||
if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
|
||||
out = fp4_gemm(
|
||||
x_fp4,
|
||||
w,
|
||||
x_scale_interleaved,
|
||||
w_scale_interleaved,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
backend="cutlass",
|
||||
)
|
||||
else:
|
||||
out = fp4_gemm(
|
||||
x_fp4,
|
||||
w,
|
||||
x_scale_interleaved,
|
||||
w_scale_interleaved,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
)
|
||||
out = fp4_gemm(
|
||||
x_fp4,
|
||||
w,
|
||||
x_scale_interleaved,
|
||||
w_scale_interleaved,
|
||||
layer.alpha,
|
||||
output_dtype,
|
||||
**(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
|
||||
)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
Reference in New Issue
Block a user