Files
enginex-c_series-vllm/vllm/attention/ops/blocksparse_attention/utils.py

247 lines
8.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Helper functions for 3D sparse pattern
# These function are not optimized and very inefficient.
# Avoid calling them too frequent or use a cache mechanism.
from functools import lru_cache
import numpy as np
import torch
from vllm.triton_utils import triton
class csr_matrix:
"""Simple implementation of CSR matrix conversion without scipy.
This replaced scipy.sparse.csr_matrix() previously used."""
def __init__(self, input_array):
if not isinstance(input_array, np.ndarray):
raise ValueError("Input must be a NumPy array")
self.shape = input_array.shape
rows, cols = self.shape
data = []
indices = []
indptr = [0]
for i in range(rows):
for j in range(cols):
if input_array[i, j]:
data.append(input_array[i, j])
indices.append(j)
indptr.append(len(indices))
self.data = np.array(data)
self.indices = np.array(indices)
self.indptr = np.array(indptr)
def dense_to_crow_col(x: torch.Tensor):
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
NOTE: col_indices padded -1
"""
device = x.device
pad = -1
dim = x.dim()
assert x.dim() in (2, 3)
if x.dim() == 2:
x = x[None]
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
cols = [torch.from_numpy(xi.indices) for xi in x]
max_cols = max(len(xi) for xi in cols)
cols = [
torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
for xi in cols
]
cols = torch.vstack(cols)
if dim == 2:
crows = crows[0]
cols = cols[0]
return crows.to(device), cols.to(device)
def crow_col_to_dense(crows: torch.Tensor,
cols: torch.Tensor,
dtype: torch.dtype = torch.float16):
dim = crows.dim()
if dim == 1:
crows = crows[None]
cols = cols[None]
device = crows.device
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
x = torch.zeros(shape, dtype=dtype)
for i in range(shape[0]):
for j in range(shape[1]):
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
if dim == 1:
x = x[0]
return x.to(device)
def dense_to_ccol_row(x: torch.Tensor):
"""Similar, but to CSC format"""
x = x.transpose(-2, -1)
return dense_to_crow_col(x)
def ccol_row_to_dense(ccol: torch.Tensor,
rows: torch.Tensor,
dtype: torch.dtype = torch.float16):
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
def _get_sparse_attn_mask_homo_head(
q_len: int,
max_seqlen: int,
dtype: torch.dtype,
device: torch.device,
block_size: int = 128,
local_blocks: int = 4,
vert_stride: int = 4,
return_dense: bool = False,
):
"""
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`,
otherwise, None
"""
with torch.no_grad():
num_blocks = triton.cdiv(max_seqlen, block_size)
q_pos = torch.arange(num_blocks)[:, None]
k_pos = torch.arange(num_blocks)[None]
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
block_mask_dense = (((q_pos >= k_pos)
& ((q_pos - k_pos < local_blocks)
| mask_vert_strided)).to(device).to(dtype))
num_blocks_q = triton.cdiv(q_len, block_size)
block_mask_dense_output = (dense_to_crow_col(
block_mask_dense[-num_blocks_q:].contiguous()))
if return_dense:
mask_dense = torch.kron(
block_mask_dense,
block_mask_dense.new_ones((block_size, block_size)),
)
causal_mask = torch.tril(torch.ones(
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
return (
block_mask_dense_output,
block_mask_dense,
mask_dense,
)
else:
return (
block_mask_dense_output,
block_mask_dense,
None,
)
def binary_mask_to_bias(mask_dense: torch.Tensor):
mask_dense = 1 - mask_dense
mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
return mask_dense
def get_head_sliding_step(n_heads: int,
vert_stride: int,
homo_head: bool = False):
if homo_head:
return 0
return max(1, int(vert_stride / n_heads))
@lru_cache
def get_sparse_attn_mask(
n_heads: int,
q_len: int,
max_seqlen: int,
dtype: torch.dtype,
device: torch.device,
block_size: int = 64,
local_blocks: int = 4,
vert_stride: int = 4,
homo_head: bool = True,
return_dense: bool = False,
dense_mask_type: str = "binary",
):
"""
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3:
- tuple of crow_indices, col_indices representation
of CSR format.
- block dense mask
- all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None
"""
assert dense_mask_type in ("binary", "bias")
if homo_head:
with torch.no_grad():
(crow, col), block_mask_dense, mask_dense = (
_get_sparse_attn_mask_homo_head(
q_len,
max_seqlen,
dtype,
device,
block_size,
local_blocks,
vert_stride,
return_dense,
))
crow = crow[None].expand(n_heads, crow.shape[0])
col = col[None].expand(n_heads, col.shape[0])
if return_dense:
mask_dense = mask_dense[None].expand(n_heads,
*mask_dense.shape)
if dense_mask_type == "bias":
mask_dense = binary_mask_to_bias(mask_dense)
return (crow, col), block_mask_dense, mask_dense
with torch.no_grad():
num_blocks = triton.cdiv(max_seqlen, block_size)
q_pos = torch.arange(num_blocks)[None, :, None]
k_pos = torch.arange(num_blocks)[None, None]
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
mask_vert_strided = [
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
vert_stride == 0 for h in range(n_heads)
]
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
block_mask_dense = (((q_pos >= k_pos)
& ((q_pos - k_pos < local_blocks)
| mask_vert_strided)).to(device).to(dtype))
num_blocks_q = triton.cdiv(q_len, block_size)
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
if return_dense:
mask_dense = torch.kron(
block_mask_dense,
block_mask_dense.new_ones((block_size, block_size)),
)
causal_mask = torch.tril(torch.ones(
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
if dense_mask_type == "bias":
mask_dense = binary_mask_to_bias(mask_dense)
return (
dense_to_crow_col(block_mask_dense_output),
block_mask_dense,
mask_dense,
)
else:
return (
dense_to_crow_col(block_mask_dense_output),
block_mask_dense,
None,
)