"""Benchmark comparing a Triton kernel against a torch.compile reference for
looking up the KV-cache location of each request's last prefix token."""

import os

import torch
import triton
import triton.language as tl


@torch.compile(dynamic=True)
def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    # For each request, return req_to_token[pool_index, prefix_len - 1],
    # or -1 for requests without a cached prefix (prefix_len <= 0).
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles BLOCK_SIZE requests.
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    # Out-of-range lanes load prefix_len = 0, so they also fail token_mask
    # below and never dereference req_to_token.
    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    # Requests with prefix_len <= 0 receive the sentinel -1 via `other`.
    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result


def test_get_last_loc():
    max_batch = 4097
    max_context_len = 6148
    batch_size = 20

    # Initialize input tensors; negative prefix lengths are included so the
    # no-prefix (-1) branch is exercised.
    req_to_token = torch.zeros(
        (max_batch, max_context_len), dtype=torch.int32, device="cuda"
    )
    req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
    pre_lens = torch.randint(
        -max_context_len // 2,
        max_context_len,
        (batch_size,),
        dtype=torch.int64,
        device="cuda",
    )

    last_loc_res = get_last_loc_triton(req_to_token, req_pool_indices, pre_lens)
    last_loc_ref = get_last_loc_torch(req_to_token, req_pool_indices, pre_lens)

    # Compare results
    torch.testing.assert_close(last_loc_res, last_loc_ref)


def get_benchmark():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=batch_sizes,
            line_arg="provider",
            line_vals=["reference", "triton"],
            line_names=["PyTorch", "Triton"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="us",
            plot_name="get-last-loc-performance",
            args={},
        )
    )
    def benchmark(batch_size, provider):
        max_batch = 2048
        max_context_len = 16384

        req_to_token = torch.zeros(
            (max_batch, max_context_len), dtype=torch.int32, device="cuda"
        )
        req_pool_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")
        pre_lens = torch.randint(
            -max_context_len // 2,
            max_context_len,
            (batch_size,),
            dtype=torch.int64,
            device="cuda",
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "reference":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_torch(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )
        elif provider == "triton":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: get_last_loc_triton(req_to_token, req_pool_indices, pre_lens),
                quantiles=quantiles,
            )

        # do_bench returns the (median, 20th, 80th) percentiles in ms; report
        # microseconds, with the faster quantile as y_min and the slower as y_max.
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark


def run_benchmark(save_path: str = "./configs/benchmark_ops/get_last_loc/"):
    """Run benchmark and save results"""
    # Ensure save path exists
    os.makedirs(save_path, exist_ok=True)

    # Run correctness test
    test_get_last_loc()
    print("Correctness test passed!")

    # Run performance test
    benchmark = get_benchmark()
    benchmark.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/get_last_loc/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    run_benchmark(args.save_path)
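
# ---------------------------------------------------------------------------
# Illustrative example (a minimal sketch, not part of the benchmark): the
# tensor values below are hypothetical and only demonstrate the semantics
# shared by both implementations. For a request whose KV indices live in
# row r of req_to_token and whose cached prefix has length L, the result is
# req_to_token[r, L - 1]; requests with L <= 0 map to -1.
#
#   >>> req_to_token = torch.arange(12, dtype=torch.int32, device="cuda").reshape(3, 4)
#   >>> req_pool_indices = torch.tensor([0, 2], dtype=torch.int64, device="cuda")
#   >>> prefix_lens = torch.tensor([3, 0], dtype=torch.int64, device="cuda")
#   >>> get_last_loc_triton(req_to_token, req_pool_indices, prefix_lens)
#   tensor([ 2, -1], device='cuda:0')
# ---------------------------------------------------------------------------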