[feat] triton kernel for get_last_loc (#6676)
This commit is contained in:
@@ -1810,10 +1810,72 @@ def write_req_to_token_pool_triton(
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def get_last_loc(req_to_token, req_pool_indices_tensor, prefix_lens_tensor):
|
||||
def get_last_loc(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices_tensor: torch.Tensor,
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if global_server_args_dict["attention_backend"] != "torch_native":
|
||||
impl = get_last_loc_triton
|
||||
else:
|
||||
impl = get_last_loc_torch
|
||||
|
||||
return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
|
||||
|
||||
|
||||
def get_last_loc_torch(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices_tensor: torch.Tensor,
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.where(
|
||||
prefix_lens_tensor > 0,
|
||||
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
|
||||
torch.full_like(prefix_lens_tensor, -1),
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def get_last_loc_kernel(
|
||||
req_to_token,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
result,
|
||||
num_tokens,
|
||||
req_to_token_stride,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
|
||||
mask = offset < num_tokens
|
||||
|
||||
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
|
||||
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
|
||||
|
||||
token_mask = prefix_lens > 0
|
||||
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
|
||||
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
|
||||
|
||||
tl.store(result + offset, tokens, mask=mask)
|
||||
|
||||
|
||||
def get_last_loc_triton(
|
||||
req_to_token: torch.Tensor,
|
||||
req_pool_indices_tensor: torch.Tensor,
|
||||
prefix_lens_tensor: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
BLOCK_SIZE = 256
|
||||
num_tokens = prefix_lens_tensor.shape[0]
|
||||
result = torch.empty_like(prefix_lens_tensor)
|
||||
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
|
||||
|
||||
get_last_loc_kernel[grid](
|
||||
req_to_token,
|
||||
req_pool_indices_tensor,
|
||||
prefix_lens_tensor,
|
||||
result,
|
||||
num_tokens,
|
||||
req_to_token.stride(0),
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user