ROCM: AITER BLOCK GEMM (#4075)
This commit is contained in:
@@ -8,9 +8,12 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
w8a8_block_fp8_matmul,
|
w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
|
|
||||||
is_hip_ = is_hip()
|
is_hip_ = is_hip()
|
||||||
|
if is_hip_ and get_bool_env_var("CK_MOE"):
|
||||||
|
from aiter import gemm_a8w8_blockscale
|
||||||
|
|
||||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||||
@@ -78,6 +81,16 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
output = fp8_blockwise_scaled_mm(
|
output = fp8_blockwise_scaled_mm(
|
||||||
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
|
||||||
)
|
)
|
||||||
|
elif is_hip_ and get_bool_env_var("CK_MOE"):
|
||||||
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
|
input_2d, block_size[1], column_major_scales=False
|
||||||
|
)
|
||||||
|
output = torch.zeros(
|
||||||
|
[q_input.shape[0], weight.shape[0]],
|
||||||
|
dtype=input.dtype,
|
||||||
|
device=q_input.device,
|
||||||
|
)
|
||||||
|
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
||||||
else:
|
else:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=False
|
input_2d, block_size[1], column_major_scales=False
|
||||||
|
|||||||
Reference in New Issue
Block a user