[Performance, Triton Kernel Args] extend_attention, optimize kern args to _fwd_kernel (#1941)
This commit is contained in:
@@ -25,6 +25,7 @@ import triton.language as tl
|
||||
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
|
||||
context_attention_fwd,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
is_cuda_available = torch.cuda.is_available()
|
||||
if is_cuda_available:
|
||||
@@ -311,6 +312,10 @@ def extend_attention_fwd(
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
num_stages = 1
|
||||
|
||||
extra_kargs = {}
|
||||
if is_hip():
|
||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q_extend,
|
||||
k_extend,
|
||||
@@ -348,6 +353,7 @@ def extend_attention_fwd(
|
||||
Lv=Lv,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
**extra_kargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user