[Performance, Triton] Optimize over mask compute to tl.load in fused_moe_kernel (#1980)
This commit is contained in:
@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
|
|||||||
|
|
||||||
num_warps = 4
|
num_warps = 4
|
||||||
|
|
||||||
|
extra_kargs = {}
|
||||||
|
if is_hip():
|
||||||
|
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
|
||||||
|
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
|
||||||
|
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
|
||||||
|
|
||||||
_fwd_grouped_kernel_stage1[grid](
|
_fwd_grouped_kernel_stage1[grid](
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
Lk=Lk,
|
Lk=Lk,
|
||||||
|
**extra_kargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ def fused_moe_kernel(
|
|||||||
top_k: tl.constexpr,
|
top_k: tl.constexpr,
|
||||||
compute_type: tl.constexpr,
|
compute_type: tl.constexpr,
|
||||||
use_fp8: tl.constexpr,
|
use_fp8: tl.constexpr,
|
||||||
|
even_Ks: tl.constexpr,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||||
@@ -130,16 +131,24 @@ def fused_moe_kernel(
|
|||||||
# of fp32 values for higher accuracy.
|
# of fp32 values for higher accuracy.
|
||||||
# `accumulator` will be converted back to fp16 after the loop.
|
# `accumulator` will be converted back to fp16 after the loop.
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
|
|
||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||||
# Load the next block of A and B, generate a mask by checking the
|
# Load the next block of A and B, generate a mask by checking the
|
||||||
# K dimension.
|
# K dimension.
|
||||||
a = tl.load(
|
if even_Ks:
|
||||||
a_ptrs,
|
a = tl.load(
|
||||||
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
a_ptrs,
|
||||||
other=0.0,
|
mask=token_mask[:, None],
|
||||||
)
|
other=0.0,
|
||||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
)
|
||||||
|
b = tl.load(b_ptrs)
|
||||||
|
else:
|
||||||
|
a = tl.load(
|
||||||
|
a_ptrs,
|
||||||
|
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
|
||||||
# We accumulate along the K dimension.
|
# We accumulate along the K dimension.
|
||||||
if use_fp8:
|
if use_fp8:
|
||||||
accumulator = tl.dot(a, b, acc=accumulator)
|
accumulator = tl.dot(a, b, acc=accumulator)
|
||||||
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
|
|||||||
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
K = B.shape[2] - padding_size
|
||||||
|
if K % config["BLOCK_SIZE_K"] == 0:
|
||||||
|
even_ks = True
|
||||||
|
else:
|
||||||
|
even_ks = False
|
||||||
|
|
||||||
fused_moe_kernel[grid](
|
fused_moe_kernel[grid](
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
use_fp8=use_fp8,
|
use_fp8=use_fp8,
|
||||||
|
even_Ks=even_ks,
|
||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user