[Performance, Triton] Optimize over mask compute to tl.load in fused_moe_kernel (#1980)
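Specializes fused_moe_kernel on a new even_Ks: tl.constexpr flag: when K is a multiple of BLOCK_SIZE_K, the inner-loop tl.load calls for the A and B tiles drop the K-bounds mask, so the mask is neither computed nor applied on the hot path. invoke_fused_moe_kernel derives the flag once on the host from K = B.shape[2] - padding_size. The commit also forwards AMD-specific launch options (waves_per_eu, matrix_instr_nonkdim, kpack) to _fwd_grouped_kernel_stage1 on HIP builds.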
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
     num_warps = 4

+    extra_kargs = {}
+    if is_hip():
+        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,

@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,
+        **extra_kargs,
     )
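On non-HIP builds extra_kargs stays empty, so this launch is unchanged; on ROCm the waves_per_eu, matrix_instr_nonkdim, and kpack knobs are consumed by Triton's AMD backend (see the two links in the diff above).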
@@ -54,6 +54,7 @@ def fused_moe_kernel(
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
     use_fp8: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,16 +131,24 @@ def fused_moe_kernel(
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
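For illustration, here is a minimal, self-contained sketch of the same specialization pattern, with hypothetical names (copy_kernel, EVEN_K) that are not part of this PR. Because the flag is a tl.constexpr, Triton resolves the branch at compile time: the even variant is compiled with no bounds mask at all, which is exactly what removes the per-iteration mask computation from the loads.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr):
    offs = tl.program_id(0) * BLOCK_K + tl.arange(0, BLOCK_K)
    if EVEN_K:
        # Every block is full: no bounds mask is computed or applied.
        x = tl.load(x_ptr + offs)
    else:
        # The last block may be partial: mask off the out-of-range lanes.
        x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)
    tl.store(y_ptr + offs, x, mask=offs < n)


def copy(x: torch.Tensor, block_k: int = 128) -> torch.Tensor:
    # Decide evenness once on the host, mirroring how
    # invoke_fused_moe_kernel computes even_ks from K % BLOCK_SIZE_K.
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, block_k),)
    copy_kernel[grid](x, y, n, BLOCK_K=block_k, EVEN_K=(n % block_k == 0))
    return y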
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )

+    K = B.shape[2] - padding_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_ks = True
+    else:
+        even_ks = False
+
     fused_moe_kernel[grid](
         A,
         B,

@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
         top_k=top_k,
         compute_type=compute_type,
         use_fp8=use_fp8,
+        even_Ks=even_ks,
         **config,
     )
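Since even_Ks is a tl.constexpr, each value of the flag produces a separately compiled (and cached) kernel, so the if/else in the inner loop has no run-time cost; the masked path is only generated for shapes whose K is not a multiple of BLOCK_SIZE_K.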