ROCM: AITER BLOCK GEMM (#4075)

This commit is contained in:
yigex
2025-03-05 19:10:49 +08:00
committed by GitHub
parent e5760bc40a
commit 5be8f1ed98

View File

@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_bool_env_var, is_hip
is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm
@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif is_hip_ and get_bool_env_var("CK_MOE"):
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = torch.zeros(
[q_input.shape[0], weight.shape[0]],
dtype=input.dtype,
device=q_input.device,
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
else:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False