Change extend attention kernel launch parameter for ROCm platform to … (#2610)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -292,27 +292,33 @@ def extend_attention_fwd(
|
|||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
if is_hip_:
|
||||||
if Lq <= 256:
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
BLOCK_M, BLOCK_N = (128, 64)
|
num_warps = 4
|
||||||
else:
|
|
||||||
BLOCK_M, BLOCK_N = (32, 64)
|
|
||||||
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
|
||||||
if Lq <= 128:
|
|
||||||
BLOCK_M, BLOCK_N = (128, 128)
|
|
||||||
elif Lq <= 256:
|
|
||||||
BLOCK_M, BLOCK_N = (64, 64)
|
|
||||||
else:
|
|
||||||
BLOCK_M, BLOCK_N = (32, 64)
|
|
||||||
else:
|
else:
|
||||||
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||||
|
if Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (128, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
|
elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
|
||||||
|
if Lq <= 128:
|
||||||
|
BLOCK_M, BLOCK_N = (128, 128)
|
||||||
|
elif Lq <= 256:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (32, 64)
|
||||||
|
else:
|
||||||
|
BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
|
||||||
|
|
||||||
|
num_warps = 4 if Lk <= 64 else 8
|
||||||
|
|
||||||
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
sm_scale = sm_scale or 1.0 / (Lq**0.5)
|
||||||
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
|
||||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||||
|
|
||||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
num_warps = 4 if Lk <= 64 else 8
|
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user