From 2163586e6312e9b5ca9da8e6ca49d8a082d62a09 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Thu, 29 May 2025 14:10:28 +0800 Subject: [PATCH] [feat] triton kernel for get_last_loc (#6676) --- python/sglang/srt/managers/schedule_batch.py | 66 +++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ee613c573..59b5471b6 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1810,10 +1810,72 @@ def write_req_to_token_pool_triton( ) -@torch.compile(dynamic=True, backend=get_compiler_backend()) -def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor): +def get_last_loc( + req_to_token: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, +) -> torch.Tensor: + if global_server_args_dict["attention_backend"] != "torch_native": + impl = get_last_loc_triton + else: + impl = get_last_loc_torch + + return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor) + + +def get_last_loc_torch( + req_to_token: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, +) -> torch.Tensor: return torch.where( prefix_lens_tensor > 0, req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1], torch.full_like(prefix_lens_tensor, -1), ) + + +@triton.jit +def get_last_loc_kernel( + req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + result, + num_tokens, + req_to_token_stride, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE + mask = offset < num_tokens + + prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0) + req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0) + + token_mask = prefix_lens > 0 + token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1) + tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1) + + tl.store(result + offset, tokens, mask=mask) + + +def get_last_loc_triton( + req_to_token: torch.Tensor, + req_pool_indices_tensor: torch.Tensor, + prefix_lens_tensor: torch.Tensor, +) -> torch.Tensor: + BLOCK_SIZE = 256 + num_tokens = prefix_lens_tensor.shape[0] + result = torch.empty_like(prefix_lens_tensor) + grid = (triton.cdiv(num_tokens, BLOCK_SIZE),) + + get_last_loc_kernel[grid]( + req_to_token, + req_pool_indices_tensor, + prefix_lens_tensor, + result, + num_tokens, + req_to_token.stride(0), + BLOCK_SIZE, + ) + return result