[Performance, Triton Kernel Args] _decode_grouped_softmax_reducev_fwd… (#1845)
This commit is contained in:
@@ -24,6 +24,8 @@ It supports page size = 1.
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def tanh(x):
|
def tanh(x):
|
||||||
@@ -553,6 +555,12 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
|
extra_kargs = {}
|
||||||
|
if is_hip():
|
||||||
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||||
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||||
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|
||||||
_fwd_grouped_kernel_stage2[grid](
|
_fwd_grouped_kernel_stage2[grid](
|
||||||
logits,
|
logits,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
@@ -575,6 +583,7 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
|
**extra_kargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user