AMD Prefill optimize (#3665)
Co-authored-by: AMD-dteng <dteng@amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
@@ -74,6 +74,7 @@ def _fwd_kernel(
|
|||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
USE_CUSTOM_MASK: tl.constexpr,
|
USE_CUSTOM_MASK: tl.constexpr,
|
||||||
|
STORE_TRANSPOSE: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -272,9 +273,18 @@ def _fwd_kernel(
|
|||||||
+ cur_head * stride_oh
|
+ cur_head * stride_oh
|
||||||
+ offs_dv[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
tl.store(
|
if STORE_TRANSPOSE:
|
||||||
O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
|
tl.store(
|
||||||
)
|
O_Extend + offs_o.T,
|
||||||
|
(acc / deno[:, None]).T,
|
||||||
|
mask=(mask_m[:, None] & mask_dv[None, :]).T,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tl.store(
|
||||||
|
O_Extend + offs_o,
|
||||||
|
acc / deno[:, None],
|
||||||
|
mask=mask_m[:, None] & mask_dv[None, :],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def extend_attention_fwd(
|
def extend_attention_fwd(
|
||||||
@@ -319,8 +329,8 @@ def extend_attention_fwd(
|
|||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
if is_hip_:
|
if is_hip_:
|
||||||
BLOCK_M, BLOCK_N = (64, 64)
|
BLOCK_M, BLOCK_N = (32, 32)
|
||||||
num_warps = 4
|
num_warps = 2
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
|
||||||
@@ -388,6 +398,7 @@ def extend_attention_fwd(
|
|||||||
Lq=Lq,
|
Lq=Lq,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||||
|
STORE_TRANSPOSE=is_hip_,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
**extra_kargs,
|
**extra_kargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user