[Feat] QWen-1M context support[1/2]: Update block sparse attention backend utils kernel (#5847)

Co-authored-by: sighingnow <sighingnow@gmail.com>
This commit is contained in:
PGFLMG
2025-04-29 02:03:17 +08:00
committed by GitHub
parent d364b9b0f2
commit ee71ed8a41
6 changed files with 763 additions and 1 deletions

View File

@@ -8,6 +8,124 @@ def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
# Sparse attention utils
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.zeros(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.zeros(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS] : different head use different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
block_offset = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
column_count = torch.empty(
batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
)
column_index = torch.empty(
batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device,
)
torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
block_count,
block_offset,
column_count,
column_index,
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
vertical_indices_count,
slash_indices_count,
context_size,
block_size_M,
block_size_N,
causal,
)
return block_count, block_offset, column_count, column_index
def sparse_attn_func(
q,
k,