Tune paged attention parameters for AMD GPU. (#3255)
This commit is contained in:
committed by
GitHub
parent
959dca4fc7
commit
d9eb9358cc
@@ -181,6 +181,9 @@ def _decode_att_m_fwd(
|
|||||||
logit_cap,
|
logit_cap,
|
||||||
):
|
):
|
||||||
BLOCK = 64
|
BLOCK = 64
|
||||||
|
# [TODO] work around SGPR limit on MI3xx
|
||||||
|
if is_hip_:
|
||||||
|
BLOCK = 8
|
||||||
NUM_KV_SPLITS = num_kv_splits
|
NUM_KV_SPLITS = num_kv_splits
|
||||||
Lk = k_buffer.shape[-1]
|
Lk = k_buffer.shape[-1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
@@ -194,6 +197,8 @@ def _decode_att_m_fwd(
|
|||||||
num_warps = 4
|
num_warps = 4
|
||||||
else:
|
else:
|
||||||
num_warps = 2
|
num_warps = 2
|
||||||
|
if is_hip_:
|
||||||
|
num_warps = 1
|
||||||
|
|
||||||
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
@@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
|
num_stages = 2
|
||||||
if is_hip_:
|
if is_hip_:
|
||||||
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||||
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||||
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
num_stages = 1
|
||||||
|
|
||||||
_fwd_grouped_kernel_stage1[grid](
|
_fwd_grouped_kernel_stage1[grid](
|
||||||
q,
|
q,
|
||||||
@@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
NUM_KV_SPLITS=NUM_KV_SPLITS,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
num_warps=4,
|
num_warps=4,
|
||||||
num_stages=2,
|
num_stages=num_stages,
|
||||||
Lk=Lk,
|
Lk=Lk,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
**extra_kargs,
|
**extra_kargs,
|
||||||
|
|||||||
@@ -273,6 +273,10 @@ class ServerArgs:
|
|||||||
) and check_gguf_file(self.model_path):
|
) and check_gguf_file(self.model_path):
|
||||||
self.quantization = self.load_format = "gguf"
|
self.quantization = self.load_format = "gguf"
|
||||||
|
|
||||||
|
# AMD-specific Triton attention KV splits default number
|
||||||
|
if is_hip():
|
||||||
|
self.triton_attention_num_kv_splits = 16
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
# Model and port args
|
# Model and port args
|
||||||
|
|||||||
Reference in New Issue
Block a user