diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 9dafbb513..6c2a62dcd 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -24,6 +24,8 @@ It supports page size = 1. import triton import triton.language as tl +from sglang.srt.utils import is_hip + @triton.jit def tanh(x): @@ -553,6 +555,12 @@ def _decode_grouped_softmax_reducev_fwd( Lv = v_buffer.shape[-1] BLOCK_DMODEL = triton.next_power_of_2(Lv) + extra_kargs = {} + if is_hip(): + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + _fwd_grouped_kernel_stage2[grid]( logits, v_buffer, @@ -575,6 +583,7 @@ def _decode_grouped_softmax_reducev_fwd( Lv=Lv, num_warps=num_warps, num_stages=1, + **extra_kargs, )