"""Benchmark a Triton kernel against a torch.compile reference for computing
the "last cached location" of each request in a scheduler batch.

For request i the answer is ``req_to_token[req_pool_indices[i], prefix_lens[i] - 1]``
when ``prefix_lens[i] > 0``, and the sentinel ``-1`` otherwise (no cached prefix).
"""

import os

import torch
import triton
import triton.language as tl


@torch.compile(dynamic=True)
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Reference implementation using plain torch ops.

    Returns a tensor shaped like ``prefix_lens_tensor`` holding the last token
    location per request, or -1 where ``prefix_lens_tensor <= 0``.
    """
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Each program handles BLOCK_SIZE requests; trailing lanes are masked."""
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    # Out-of-range lanes read prefix_len == 0 (via `other=0`), which makes
    # token_mask False below — so the gather never touches invalid addresses
    # and no separate bounds check is needed for the load from req_to_token.
    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    # Requests without a cached prefix (prefix_len <= 0) get the sentinel -1.
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    """Triton implementation; same contract as :func:`get_last_loc_torch`."""
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result


def test_get_last_loc():
    """Check the Triton kernel against the torch reference on random inputs."""
    max_batch = 4097
    max_context_len = 6148
    batch_size = 20

    # Initialize input tensors. Negative prefix lengths are included on
    # purpose to exercise the sentinel (-1) path.
    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
    pre_lens = torch.randint(
        -max_context_len // 2,
        max_context_len,
        (batch_size,),
        dtype=torch.int64,
        device="cuda",
    )

    last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
    last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)

    # Compare results
    torch.testing.assert_close(last_loc_res, last_loc_ref)


def get_benchmark():
    """Build the triton.testing benchmark sweeping over batch sizes."""
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=batch_sizes,
            line_arg="provider",
            line_vals=["reference", "triton"],
            line_names=["PyTorch", "Triton"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="get-last-loc-performance",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        max_batch = 2048
        max_context_len = 16384

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
        pre_lens = torch.randint(
            -max_context_len // 2,
            max_context_len,
            (batch_size,),
            dtype=torch.int64,
            device="cuda",
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "reference":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )
        elif provider == "triton":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )
        else:
            # Fail loudly instead of hitting UnboundLocalError on `ms` below.
            raise ValueError(f"Unknown provider: {provider}")

        # Convert ms -> us for the report.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
    """Run benchmark and save results"""

    # Ensure save path exists
    os.makedirs(save_path, exist_ok=True)

    # Run correctness test
    test_get_last_loc()
    print("Correctness test passed!")

    # Run performance test
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/get_last_loc/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    run_benchmark(args.save_path)