diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index b7afd62e7..b2654f1f7 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -292,27 +292,33 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
-        if Lq <= 256:
-            BLOCK_M, BLOCK_N = (128, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
-    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        if Lq <= 128:
-            BLOCK_M, BLOCK_N = (128, 128)
-        elif Lq <= 256:
-            BLOCK_M, BLOCK_N = (64, 64)
-        else:
-            BLOCK_M, BLOCK_N = (32, 64)
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
     else:
-        BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
-    num_warps = 4 if Lk <= 64 else 8
     num_stages = 1
 
    extra_kargs = {}
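
Beyond hoisting a dedicated HIP (ROCm) branch to the front, the patch also moves the `num_warps` assignment out of the launch-grid section and into the device-selection block: on HIP, `num_warps` is now fixed at 4 regardless of `Lk`, whereas previously it would have been 8 for `Lk > 64`. For readability, the resulting heuristic can be summarized as a standalone function. The sketch below is illustrative only, not SGLang API: the parameters `is_hip`, `is_cuda_available`, and `cuda_capability` are hypothetical stand-ins for the module-level `is_hip_`, `is_cuda_available`, and `CUDA_CAPABILITY[0]` used in the real file, and `Lq`/`Lk` are the query/key head dimensions local to `extend_attention_fwd`.

```python
# Illustrative sketch of the post-patch tuning heuristic; parameter names
# are hypothetical stand-ins, not the SGLang module's actual interface.
def pick_launch_config(
    Lq: int, Lk: int, is_hip: bool, is_cuda_available: bool, cuda_capability: int
):
    """Return (BLOCK_M, BLOCK_N, num_warps) as chosen by the patched code."""
    if is_hip:
        # ROCm: one fixed tile shape and warp count for all head dims.
        return 64, 64, 4
    # CUDA: tile shape depends on compute capability and query head dim Lq.
    if is_cuda_available and cuda_capability >= 9:
        block_m, block_n = (128, 64) if Lq <= 256 else (32, 64)
    elif is_cuda_available and cuda_capability >= 8:
        if Lq <= 128:
            block_m, block_n = 128, 128
        elif Lq <= 256:
            block_m, block_n = 64, 64
        else:
            block_m, block_n = 32, 64
    else:
        block_m, block_n = (64, 64) if Lq <= 128 else (32, 32)
    # Non-HIP path: warp count still scales with the key head dim Lk.
    num_warps = 4 if Lk <= 64 else 8
    return block_m, block_n, num_warps
```

For example, under this sketch `pick_launch_config(128, 128, False, True, 9)` yields `(128, 64, 8)` (Hopper-class tile shape, 8 warps since `Lk > 64`), while any HIP device yields `(64, 64, 4)`.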