From 2d4ce1b7928d253144bc4b030a643af2b9267b40 Mon Sep 17 00:00:00 2001 From: HAI Date: Wed, 30 Oct 2024 17:33:36 -0700 Subject: [PATCH] =?UTF-8?q?[Performance,=20Triton=20Kernel=20Args]=20=5Fde?= =?UTF-8?q?code=5Fgrouped=5Fsoftmax=5Freducev=5Ffwd=E2=80=A6=20(#1845)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../srt/layers/attention/triton_ops/decode_attention.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 9dafbb513..6c2a62dcd 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -24,6 +24,8 @@ It supports page size = 1. import triton import triton.language as tl +from sglang.srt.utils import is_hip + @triton.jit def tanh(x): @@ -553,6 +555,12 @@ def _decode_grouped_softmax_reducev_fwd( Lv = v_buffer.shape[-1] BLOCK_DMODEL = triton.next_power_of_2(Lv) + extra_kargs = {} + if is_hip(): + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + _fwd_grouped_kernel_stage2[grid]( logits, v_buffer, @@ -575,6 +583,7 @@ def _decode_grouped_softmax_reducev_fwd( Lv=Lv, num_warps=num_warps, num_stages=1, + **extra_kargs, )