Refine pre_reorder_triton_kernel slightly to improve performance (#6627)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel
|
||||
|
||||
|
||||
def benchmark_pre_reorder(batch_size, topk, model_config):
|
||||
hidden_size = model_config["hidden_size"]
|
||||
block_size = model_config["block_size"]
|
||||
expert_range = model_config["expert_range"]
|
||||
|
||||
input_ptr = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
|
||||
gateup_input_ptr = torch.zeros(
|
||||
batch_size * topk, hidden_size, dtype=torch.float16, device="cuda"
|
||||
)
|
||||
src2dst_ptr = torch.randint(
|
||||
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
topk_ids_ptr = torch.randint(
|
||||
expert_range[0],
|
||||
expert_range[1] + 1,
|
||||
(batch_size, topk),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
a1_scales_ptr = torch.rand(
|
||||
expert_range[1] - expert_range[0] + 1, dtype=torch.float32, device="cuda"
|
||||
)
|
||||
|
||||
input_ptr = input_ptr.view(-1)
|
||||
gateup_input_ptr = gateup_input_ptr.view(-1)
|
||||
src2dst_ptr = src2dst_ptr.view(-1)
|
||||
topk_ids_ptr = topk_ids_ptr.view(-1)
|
||||
|
||||
def run_kernel():
|
||||
pre_reorder_triton_kernel[(batch_size,)](
|
||||
input_ptr,
|
||||
gateup_input_ptr,
|
||||
src2dst_ptr,
|
||||
topk_ids_ptr,
|
||||
a1_scales_ptr,
|
||||
expert_range[0],
|
||||
expert_range[1],
|
||||
topk,
|
||||
hidden_size,
|
||||
block_size,
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
run_kernel()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ms, _, _ = triton.testing.do_bench(run_kernel, quantiles=[0.5, 0.2, 0.8])
|
||||
return ms
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--hidden-size", type=int, required=True)
|
||||
parser.add_argument("--block-size", type=int, default=512)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_config = {
|
||||
"hidden_size": args.hidden_size,
|
||||
"block_size": args.block_size,
|
||||
"expert_range": (0, 255),
|
||||
}
|
||||
|
||||
batch_sizes = [64, 128, 256, 512, 640, 768, 1024]
|
||||
topks = [2, 4, 8]
|
||||
configs = list(itertools.product(batch_sizes, topks))
|
||||
|
||||
# Prepare results dict: keys = topk, each row is indexed by batch_size
|
||||
results_dict = {topk: {} for topk in topks}
|
||||
|
||||
for batch_size, topk in configs:
|
||||
ms = benchmark_pre_reorder(batch_size, topk, model_config)
|
||||
results_dict[topk][batch_size] = ms
|
||||
|
||||
# Build dataframe
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"batch_size": batch_sizes,
|
||||
**{
|
||||
f"TopK={topk}": [results_dict[topk].get(bs, None) for bs in batch_sizes]
|
||||
for topk in topks
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
print("\npre-reorder-performance:")
|
||||
print(df.to_string(index=False, float_format="%.6f"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -184,8 +184,10 @@ def pre_reorder_triton_kernel(
|
||||
src_idx = tl.program_id(0)
|
||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||
|
||||
src_ptr = input_ptr + src_idx * hidden_size
|
||||
|
||||
vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for idx in range(topk):
|
||||
expert_id = tl.load(topk_ids_ptr + idx)
|
||||
if expert_id >= start_expert_id and expert_id <= end_expert_id:
|
||||
@@ -197,7 +199,7 @@ def pre_reorder_triton_kernel(
|
||||
dst_idx = tl.load(src2dst_ptr + idx)
|
||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
|
||||
out_data = (in_data * scale).to(OutDtype)
|
||||
@@ -481,8 +483,11 @@ def post_reorder_triton_kernel(
|
||||
|
||||
computed = False
|
||||
store_ptr = output_ptr + src_idx * hidden_size
|
||||
|
||||
vec = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
|
||||
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
|
||||
@@ -499,7 +504,7 @@ def post_reorder_triton_kernel(
|
||||
|
||||
if computed == False:
|
||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||
offset = start_offset + vec
|
||||
mask = offset < hidden_size
|
||||
tl.store(
|
||||
store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
|
||||
|
||||
Reference in New Issue
Block a user