[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)
Co-authored-by: sighingnow <sighingnow@gmail.com>
This commit is contained in:
@@ -8,6 +8,124 @@ def maybe_contiguous(x):
|
||||
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
|
||||
|
||||
|
||||
# Sparse attention utils
|
||||
def convert_vertical_slash_indexes(
|
||||
q_seqlens: torch.Tensor, # [BATCH, ]
|
||||
kv_seqlens: torch.Tensor, # [BATCH, ]
|
||||
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
|
||||
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
|
||||
context_size: int,
|
||||
block_size_M: int,
|
||||
block_size_N: int,
|
||||
causal: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size = slash_indexes.size(0)
|
||||
num_heads = slash_indexes.size(1)
|
||||
nnz_slash = slash_indexes.size(2)
|
||||
nnz_vertical = vertical_indexes.size(2)
|
||||
num_rows = (context_size + block_size_M - 1) // block_size_M
|
||||
|
||||
block_count = torch.zeros(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
block_offset = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_slash,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
column_count = torch.zeros(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
column_index = torch.zeros(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_vertical,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
slash_indexes,
|
||||
context_size,
|
||||
block_size_M,
|
||||
block_size_N,
|
||||
causal,
|
||||
)
|
||||
return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes_mergehead(
|
||||
q_seqlens: torch.Tensor, # [BATCH, ]
|
||||
kv_seqlens: torch.Tensor, # [BATCH, ]
|
||||
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
|
||||
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
|
||||
# [N_HEADS] : different head use different number of indices
|
||||
vertical_indices_count: torch.Tensor,
|
||||
slash_indices_count: torch.Tensor,
|
||||
context_size: int,
|
||||
block_size_M: int,
|
||||
block_size_N: int,
|
||||
causal: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
batch_size = slash_indexes.size(0)
|
||||
num_heads = slash_indexes.size(1)
|
||||
nnz_slash = slash_indexes.size(2)
|
||||
nnz_vertical = vertical_indexes.size(2)
|
||||
num_rows = (context_size + block_size_M - 1) // block_size_M
|
||||
|
||||
block_count = torch.empty(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
block_offset = torch.empty(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_slash,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
column_count = torch.empty(
|
||||
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
|
||||
)
|
||||
column_index = torch.empty(
|
||||
batch_size,
|
||||
num_heads,
|
||||
num_rows,
|
||||
nnz_vertical,
|
||||
dtype=q_seqlens.dtype,
|
||||
device=q_seqlens.device,
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
|
||||
block_count,
|
||||
block_offset,
|
||||
column_count,
|
||||
column_index,
|
||||
q_seqlens,
|
||||
kv_seqlens,
|
||||
vertical_indexes,
|
||||
slash_indexes,
|
||||
vertical_indices_count,
|
||||
slash_indices_count,
|
||||
context_size,
|
||||
block_size_M,
|
||||
block_size_N,
|
||||
causal,
|
||||
)
|
||||
return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def sparse_attn_func(
|
||||
q,
|
||||
k,
|
||||
|
||||
Reference in New Issue
Block a user