[Performance, Triton Kernel Args] extend_attention, optimize kern args to _fwd_kernel (#1941)
This commit is contained in:
@@ -25,6 +25,7 @@ import triton.language as tl
|
|||||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||||
context_attention_fwd,
|
context_attention_fwd,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
is_cuda_available = torch.cuda.is_available()
|
is_cuda_available = torch.cuda.is_available()
|
||||||
if is_cuda_available:
|
if is_cuda_available:
|
||||||
@@ -311,6 +312,10 @@ def extend_attention_fwd(
|
|||||||
num_warps = 4 if Lk <= 64 else 8
|
num_warps = 4 if Lk <= 64 else 8
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
|
extra_kargs = {}
|
||||||
|
if is_hip():
|
||||||
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|
||||||
_fwd_kernel[grid](
|
_fwd_kernel[grid](
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
k_extend,
|
||||||
@@ -348,6 +353,7 @@ def extend_attention_fwd(
|
|||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
|
**extra_kargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user