refactor apply_w8a8_block_fp8_linear in fp (#6545)

This commit is contained in:
ChangyiYang
2025-05-29 00:15:11 -07:00
committed by GitHub
parent 7e41290082
commit 485a023bd8
5 changed files with 283 additions and 120 deletions

View File

@@ -9,7 +9,9 @@ from deep_gemm import get_col_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)
def get_weight_shapes(args):