[kernel optimize] benchmark write_req_to_token_pool_triton and optimize kernel (#2509)
This commit is contained in:
@@ -0,0 +1,345 @@
|
|||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def write_req_to_token_pool_triton(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
req_to_token_ptr_stride: tl.constexpr,
|
||||||
|
):
|
||||||
|
BLOCK_SIZE: tl.constexpr = 512
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
|
req_pool_index = tl.load(req_pool_indices + pid)
|
||||||
|
pre_len = tl.load(pre_lens + pid)
|
||||||
|
seq_len = tl.load(seq_lens + pid)
|
||||||
|
|
||||||
|
# TODO: optimize this?
|
||||||
|
cumsum_start = 0
|
||||||
|
for i in range(pid):
|
||||||
|
cumsum_start += tl.load(extend_lens + i)
|
||||||
|
|
||||||
|
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
|
||||||
|
for i in range(num_loop):
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||||
|
mask = offset < (seq_len - pre_len)
|
||||||
|
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
|
||||||
|
tl.store(
|
||||||
|
req_to_token_ptr
|
||||||
|
+ req_pool_index * req_to_token_ptr_stride
|
||||||
|
+ offset
|
||||||
|
+ pre_len,
|
||||||
|
value,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def write_req_to_token_pool_triton_optimize(
|
||||||
|
req_to_token_ptr, # [max_batch, max_context_len]
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
req_to_token_ptr_stride: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid_batch = tl.program_id(0)
|
||||||
|
pid_token = tl.program_id(1)
|
||||||
|
|
||||||
|
req_pool_index = tl.load(req_pool_indices + pid_batch)
|
||||||
|
pre_len = tl.load(pre_lens + pid_batch)
|
||||||
|
seq_len = tl.load(seq_lens + pid_batch)
|
||||||
|
extend_len = seq_len - pre_len
|
||||||
|
|
||||||
|
cumsum_start = 0
|
||||||
|
for i in range(pid_batch):
|
||||||
|
cumsum_start += tl.load(extend_lens + i)
|
||||||
|
|
||||||
|
token_start = pid_token * BLOCK_SIZE
|
||||||
|
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
actual_offset = token_start + offset
|
||||||
|
mask = actual_offset < extend_len
|
||||||
|
|
||||||
|
src_ptr = out_cache_loc + cumsum_start + actual_offset
|
||||||
|
src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
||||||
|
value = tl.load(src_ptr, mask=mask)
|
||||||
|
dst_ptr = (
|
||||||
|
req_to_token_ptr
|
||||||
|
+ req_pool_index * req_to_token_ptr_stride
|
||||||
|
+ actual_offset
|
||||||
|
+ pre_len
|
||||||
|
)
|
||||||
|
dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE)
|
||||||
|
|
||||||
|
tl.store(dst_ptr, value, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def write_req_to_token_pool_reference(
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
pre_lens: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
extend_lens: torch.Tensor,
|
||||||
|
out_cache_loc: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Reference implementation using PyTorch"""
|
||||||
|
for i in range(len(req_pool_indices)):
|
||||||
|
req_pool_idx = req_pool_indices[i].item()
|
||||||
|
pre_len = pre_lens[i].item()
|
||||||
|
seq_len = seq_lens[i].item()
|
||||||
|
extend_len = extend_lens[i].item()
|
||||||
|
|
||||||
|
cumsum_start = sum(extend_lens[:i].tolist())
|
||||||
|
|
||||||
|
# Copy values from out_cache_loc to req_to_token
|
||||||
|
req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[
|
||||||
|
cumsum_start : cumsum_start + extend_len
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_req_to_token_pool():
|
||||||
|
max_batch = 4097
|
||||||
|
max_context_len = 6148
|
||||||
|
batch_size = 1
|
||||||
|
extend_len = 14
|
||||||
|
|
||||||
|
# Initialize input tensors
|
||||||
|
req_to_token = torch.zeros(
|
||||||
|
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda")
|
||||||
|
pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda")
|
||||||
|
seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda")
|
||||||
|
extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda")
|
||||||
|
out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
# Create copies for reference implementation
|
||||||
|
req_to_token_ref = req_to_token.clone()
|
||||||
|
req_to_token_opt = req_to_token.clone()
|
||||||
|
|
||||||
|
# Run original triton kernel
|
||||||
|
write_req_to_token_pool_triton[(batch_size,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run optimized triton kernel
|
||||||
|
def grid(batch_size, extend_len):
|
||||||
|
num_token_blocks = triton.cdiv(extend_len, 512)
|
||||||
|
return (batch_size, num_token_blocks)
|
||||||
|
|
||||||
|
write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)](
|
||||||
|
req_to_token_opt,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run reference implementation
|
||||||
|
write_req_to_token_pool_reference(
|
||||||
|
req_to_token_ref,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
||||||
|
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
||||||
|
|
||||||
|
# Test case 2: batch size > 1
|
||||||
|
batch_size = 3
|
||||||
|
extend_lens_list = [14, 20, 30]
|
||||||
|
total_extend_len = sum(extend_lens_list)
|
||||||
|
|
||||||
|
req_to_token = torch.zeros(
|
||||||
|
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda")
|
||||||
|
pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda")
|
||||||
|
seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda")
|
||||||
|
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
||||||
|
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
req_to_token_ref = req_to_token.clone()
|
||||||
|
req_to_token_opt = req_to_token.clone()
|
||||||
|
|
||||||
|
# Run original triton kernel
|
||||||
|
write_req_to_token_pool_triton[(batch_size,)](
|
||||||
|
req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run optimized triton kernel
|
||||||
|
max_extend_len = max(extend_lens_list)
|
||||||
|
write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)](
|
||||||
|
req_to_token_opt,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run reference implementation
|
||||||
|
write_req_to_token_pool_reference(
|
||||||
|
req_to_token_ref,
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
torch.testing.assert_close(req_to_token, req_to_token_ref)
|
||||||
|
torch.testing.assert_close(req_to_token_opt, req_to_token_ref)
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark():
|
||||||
|
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
|
||||||
|
extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||||
|
configs = list(itertools.product(batch_sizes, extend_lens))
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "extend_len"],
|
||||||
|
x_vals=configs,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["reference", "triton", "triton_optimize"],
|
||||||
|
line_names=["PyTorch", "Triton", "Triton Optimized"],
|
||||||
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name="write-req-to-token-pool-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, extend_len, provider):
|
||||||
|
max_batch = 256
|
||||||
|
max_context_len = 16384
|
||||||
|
|
||||||
|
extend_lens_list = [extend_len] * batch_size
|
||||||
|
total_extend_len = sum(extend_lens_list)
|
||||||
|
|
||||||
|
req_to_token = torch.zeros(
|
||||||
|
(max_batch, max_context_len), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda")
|
||||||
|
pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8
|
||||||
|
seq_lens = pre_lens + extend_len
|
||||||
|
extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda")
|
||||||
|
out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "reference":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: write_req_to_token_pool_reference(
|
||||||
|
req_to_token.clone(),
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif provider == "triton":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: write_req_to_token_pool_triton[(batch_size,)](
|
||||||
|
req_to_token.clone(),
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run_optimized():
|
||||||
|
block_size = 128 if extend_len <= 1024 else 512
|
||||||
|
grid_config = (batch_size, triton.cdiv(extend_len, block_size))
|
||||||
|
write_req_to_token_pool_triton_optimize[grid_config](
|
||||||
|
req_to_token.clone(),
|
||||||
|
req_pool_indices,
|
||||||
|
pre_lens,
|
||||||
|
seq_lens,
|
||||||
|
extend_lens,
|
||||||
|
out_cache_loc,
|
||||||
|
max_context_len,
|
||||||
|
BLOCK_SIZE=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
run_optimized, quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"):
|
||||||
|
"""Run benchmark and save results"""
|
||||||
|
|
||||||
|
# Ensure save path exists
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
|
||||||
|
# Run correctness test
|
||||||
|
test_write_req_to_token_pool()
|
||||||
|
print("Correctness test passed!")
|
||||||
|
|
||||||
|
# Run performance test
|
||||||
|
benchmark = get_benchmark()
|
||||||
|
benchmark.run(print_data=True, save_path=save_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/benchmark_ops/write_req_to_token_pool/",
|
||||||
|
help="Path to save benchmark results",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
run_benchmark(args.save_path)
|
||||||
Reference in New Issue
Block a user