From 087ab832236ef264746d8c75af8cd8752f56ca6b Mon Sep 17 00:00:00 2001
From: HAI
Date: Sun, 10 Nov 2024 18:54:43 -0800
Subject: [PATCH] [Performance, Triton] Optimize over mask compute to tl.load in fused_moe_kernel (#1980)

---
 .../attention/triton_ops/decode_attention.py  |  7 +++++
 .../sglang/srt/layers/fused_moe/fused_moe.py  | 30 ++++++++++++++-----
 2 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 6c2a62dcd..b87062d36 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
 
     num_warps = 4
 
+    extra_kargs = {}
+    if is_hip():
+        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,
+        **extra_kargs,
     )
 
 
diff --git a/python/sglang/srt/layers/fused_moe/fused_moe.py b/python/sglang/srt/layers/fused_moe/fused_moe.py
index 646cea14d..3e8c2eae0 100644
--- a/python/sglang/srt/layers/fused_moe/fused_moe.py
+++ b/python/sglang/srt/layers/fused_moe/fused_moe.py
@@ -54,6 +54,7 @@ def fused_moe_kernel(
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
     use_fp8: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,16 +131,24 @@ def fused_moe_kernel(
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )
 
+    K = B.shape[2] - padding_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_ks = True
+    else:
+        even_ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
         top_k=top_k,
         compute_type=compute_type,
         use_fp8=use_fp8,
+        even_Ks=even_ks,
         **config,
     )
 