diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index 52a72d7fe..8c588bd9c 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -25,6 +25,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
+from sglang.srt.utils import is_hip
 
 is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
@@ -311,6 +312,10 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
+    extra_kargs = {}
+    if is_hip():
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -348,6 +353,7 @@ def extend_attention_fwd(
         Lv=Lv,
         num_warps=num_warps,
         num_stages=num_stages,
+        **extra_kargs,
     )
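
Note (not part of the patch): below is a minimal standalone sketch of the pattern this change applies, i.e. forwarding AMD/ROCm-specific Triton launch hints (`waves_per_eu`, `matrix_instr_nonkdim`, `kpack`) only when PyTorch is a HIP build, so the CUDA launch path stays exactly as before. The `vector_add_kernel` and the local `is_hip()` helper are illustrative stand-ins, not code from this repository.

```python
import torch
import triton
import triton.language as tl


def is_hip() -> bool:
    # Same check sglang.srt.utils.is_hip performs: HIP (ROCm) builds of
    # torch expose a non-None torch.version.hip.
    return torch.version.hip is not None


@triton.jit
def vector_add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)

    # Mirror the patch: start from an empty dict so NVIDIA launches are
    # untouched, and only add the AMD-specific compiler hints (values copied
    # from the patch) when running on a HIP build, where Triton accepts them
    # as launch-time options.
    extra_kargs = {}
    if is_hip():
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    vector_add_kernel[grid](x, y, out, n, BLOCK=1024, num_warps=4, **extra_kargs)
    return out
```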