remove activation dependency in fused_moe (#3433)
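On CUDA builds, the fused MoE path now imports silu_and_mul and gelu_and_mul directly from sgl_kernel and uses them in fused_experts_impl; other backends keep the existing ops.* fallback, so the activation ops dependency is only needed off the CUDA path.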
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
 is_hip_flag = is_hip()
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 @triton.jit
 def fused_moe_kernel(
@@ -989,9 +998,15 @@ def fused_experts_impl(
         )
 
         if activation == "silu":
-            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         elif activation == "gelu":
-            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported activation: {activation=}")
 
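For reference, a PyTorch-only sketch of what the fused activation kernels compute, following the argument order used at the CUDA call sites above (input first, preallocated output second). The _ref names, the exact-GELU choice, and the usage shapes are illustrative assumptions, not sgl_kernel's actual API.

import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half: gate = x[..., :d], up = x[..., d:].
    d = x.shape[-1] // 2
    # out = SiLU(gate) * up, written into the preallocated buffer.
    out.copy_(F.silu(x[..., :d]) * x[..., d:])
    return out


def gelu_and_mul_ref(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Same gated-activation pattern with GELU (exact variant assumed here).
    d = x.shape[-1] // 2
    out.copy_(F.gelu(x[..., :d]) * x[..., d:])
    return out


# Shapes mirroring the call sites: intermediate_cache1.view(-1, N) is
# (tokens * topk, N) and intermediate_cache2 is (tokens * topk, N // 2).
x = torch.randn(8, 64)
out = torch.empty(8, 32)
silu_and_mul_ref(x, out)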