Use FlashInfer's TRTLLM FP8 Blockscale GEMM (#8588)
This commit is contained in:
@@ -161,16 +161,16 @@ def flashinfer_gemm_w8a8_block_fp8_linear(
|
|||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
|
|
||||||
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=False
|
input_2d, block_size[1], column_major_scales=True
|
||||||
)
|
)
|
||||||
|
# TRTLLM requires column-major scaling factors
|
||||||
output = gemm_fp8_nt_groupwise(
|
output = gemm_fp8_nt_groupwise(
|
||||||
q_input,
|
q_input,
|
||||||
weight,
|
weight,
|
||||||
x_scale,
|
x_scale,
|
||||||
weight_scale,
|
weight_scale,
|
||||||
scale_major_mode="K",
|
|
||||||
out_dtype=input_2d.dtype,
|
out_dtype=input_2d.dtype,
|
||||||
|
backend="trtllm",
|
||||||
)
|
)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user