diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 2a87f70d1..ff10f0a56 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 is_hip_ = is_hip()
+if is_hip_ and get_bool_env_var("CK_MOE"):
+    from aiter import gemm_a8w8_blockscale
+
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
+    elif is_hip_ and get_bool_env_var("CK_MOE"):
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = torch.zeros(
+            [q_input.shape[0], weight.shape[0]],
+            dtype=input.dtype,
+            device=q_input.device,
+        )
+        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
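
For reviewers, a minimal sketch of how this new aiter path could be smoke-tested, assuming a ROCm build with aiter installed and the `apply_w8a8_block_fp8_linear(input, weight, block_size, weight_scale, ...)` signature already present in this file. The shapes, dtypes, and random scales below are illustrative assumptions, not from the patch; the output is numerically meaningless and only exercises dispatch. Note that `CK_MOE` must be set before sglang is imported, since `from aiter import gemm_a8w8_blockscale` runs at module import time:

```python
import os

os.environ["CK_MOE"] = "1"  # must be set before sglang import (gates the aiter import above)

import torch
from sglang.srt.layers.quantization.fp8_utils import apply_w8a8_block_fp8_linear

M, N, K = 32, 4096, 4096      # illustrative GEMM shape
block_n, block_k = 128, 128   # typical block-quant tile size for FP8 weights

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

# Weight stored pre-quantized to FP8 (e4m3fnuz is ROCm's FP8 flavor), with one
# scale per (block_n x block_k) tile; random values here, real scales in practice.
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fnuz)
w_scale = torch.rand(N // block_n, K // block_k, dtype=torch.float32, device="cuda")

# With is_hip_ true and CK_MOE set, this should take the new
# gemm_a8w8_blockscale branch added in this diff.
y = apply_w8a8_block_fp8_linear(x, w, [block_n, block_k], w_scale)
print(y.shape)  # expected: torch.Size([32, 4096])
```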