Remove the `.T.contiguous()` scale copies before the FlashInfer groupwise FP8 GEMM (#6804)
This commit is contained in:
@@ -166,11 +166,13 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
|
||||
input_2d, block_size[1], column_major_scales=False
|
||||
)
|
||||
|
||||
x_scale_input = x_scale.T.contiguous()
|
||||
weight_scale_input = weight_scale.T.contiguous()
|
||||
|
||||
output = gemm_fp8_nt_groupwise(
|
||||
q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
|
||||
q_input,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
scale_major_mode="K",
|
||||
out_dtype=input_2d.dtype,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
|
||||
Reference in New Issue
Block a user