use sgl_per_token_group_quant_fp8 kernel (#3493)
This commit is contained in:
@@ -33,6 +33,10 @@ _is_rocm = torch.cuda.is_available() and torch.version.hip
|
||||
if _is_cuda:
|
||||
from sgl_kernel import gelu_and_mul, silu_and_mul
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
if _is_cuda or _is_rocm:
|
||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||
|
||||
@@ -488,7 +492,10 @@ def invoke_fused_moe_kernel(
|
||||
else:
|
||||
assert len(block_shape) == 2
|
||||
block_n, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
if _is_cuda:
|
||||
A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
|
||||
else:
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
|
||||
Reference in New Issue
Block a user