From 70dc2fbe2d1ebecbb9b1a052f864253c446ec301 Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Fri, 27 Dec 2024 16:32:17 +0800
Subject: [PATCH] =?UTF-8?q?Change=20extend=20attention=20kernel=20launch?=
 =?UTF-8?q?=20parameter=20for=20ROCm=20platform=20to=20=E2=80=A6=20(#2610)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: wunhuang
Co-authored-by: HAI
---
 .../attention/triton_ops/extend_attention.py  | 34 +++++++++++--------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index b7afd62e7..b2654f1f7 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -292,27 +292,33 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
-        if Lq <= 256:
-            BLOCK_M, BLOCK_N = (128, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
-    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
     else:
-        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
-    num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
     extra_kargs = {}