ROCM: AITER BLOCK GEMM (#4075)
This commit is contained in:
@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
from aiter import gemm_a8w8_blockscale
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
if _is_cuda:
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
|
||||
output = fp8_blockwise_scaled_mm(
|
||||
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
||||
)
|
||||
elif is_hip_ and get_bool_env_var("CK_MOE"):
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False
|
||||
)
|
||||
output = torch.zeros(
|
||||
[q_input.shape[0], weight.shape[0]],
|
||||
dtype=input.dtype,
|
||||
device=q_input.device,
|
||||
)
|
||||
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(
|
||||
input_2d, block_size[1], column_major_scales=False
|
||||
|
||||
Reference in New Issue
Block a user