Use FlashInfer's TRTLLM FP8 Blockscale GEMM (#8588)

This commit is contained in:
Elfie Guo
2025-08-12 20:08:40 -07:00
committed by GitHub
parent 62f99e08b3
commit 8723b4f146

View File

@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = sglang_per_token_group_quant_fp8(
-        input_2d, block_size[1], column_major_scales=False
+        input_2d, block_size[1], column_major_scales=True
)
# TRTLLM requires column-major scaling factors
output = gemm_fp8_nt_groupwise(
q_input,
weight,
x_scale,
weight_scale,
scale_major_mode="K",
out_dtype=input_2d.dtype,
backend="trtllm",
)
if bias is not None: