247 lines
8.0 KiB
Python
247 lines
8.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Helper functions for 3D sparse pattern
|
|
# These function are not optimized and very inefficient.
|
|
# Avoid calling them too frequent or use a cache mechanism.
|
|
|
|
from functools import lru_cache
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from vllm.triton_utils import triton
|
|
|
|
|
|
class csr_matrix:
|
|
"""Simple implementation of CSR matrix conversion without scipy.
|
|
This replaced scipy.sparse.csr_matrix() previously used."""
|
|
|
|
def __init__(self, input_array):
|
|
if not isinstance(input_array, np.ndarray):
|
|
raise ValueError("Input must be a NumPy array")
|
|
|
|
self.shape = input_array.shape
|
|
rows, cols = self.shape
|
|
data = []
|
|
indices = []
|
|
indptr = [0]
|
|
|
|
for i in range(rows):
|
|
for j in range(cols):
|
|
if input_array[i, j]:
|
|
data.append(input_array[i, j])
|
|
indices.append(j)
|
|
indptr.append(len(indices))
|
|
|
|
self.data = np.array(data)
|
|
self.indices = np.array(indices)
|
|
self.indptr = np.array(indptr)
|
|
|
|
|
|
def dense_to_crow_col(x: torch.Tensor):
|
|
"""Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
|
|
NOTE: col_indices padded -1
|
|
"""
|
|
device = x.device
|
|
pad = -1
|
|
dim = x.dim()
|
|
assert x.dim() in (2, 3)
|
|
if x.dim() == 2:
|
|
x = x[None]
|
|
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
|
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
|
cols = [torch.from_numpy(xi.indices) for xi in x]
|
|
max_cols = max(len(xi) for xi in cols)
|
|
cols = [
|
|
torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
|
|
for xi in cols
|
|
]
|
|
cols = torch.vstack(cols)
|
|
if dim == 2:
|
|
crows = crows[0]
|
|
cols = cols[0]
|
|
return crows.to(device), cols.to(device)
|
|
|
|
|
|
def crow_col_to_dense(crows: torch.Tensor,
|
|
cols: torch.Tensor,
|
|
dtype: torch.dtype = torch.float16):
|
|
dim = crows.dim()
|
|
if dim == 1:
|
|
crows = crows[None]
|
|
cols = cols[None]
|
|
device = crows.device
|
|
crows, cols = crows.cpu(), cols.cpu() # faster in cpu
|
|
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
|
|
x = torch.zeros(shape, dtype=dtype)
|
|
for i in range(shape[0]):
|
|
for j in range(shape[1]):
|
|
x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
|
|
if dim == 1:
|
|
x = x[0]
|
|
return x.to(device)
|
|
|
|
|
|
def dense_to_ccol_row(x: torch.Tensor):
|
|
"""Similar, but to CSC format"""
|
|
x = x.transpose(-2, -1)
|
|
return dense_to_crow_col(x)
|
|
|
|
|
|
def ccol_row_to_dense(ccol: torch.Tensor,
|
|
rows: torch.Tensor,
|
|
dtype: torch.dtype = torch.float16):
|
|
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
|
|
|
|
|
|
def _get_sparse_attn_mask_homo_head(
|
|
q_len: int,
|
|
max_seqlen: int,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
block_size: int = 128,
|
|
local_blocks: int = 4,
|
|
vert_stride: int = 4,
|
|
return_dense: bool = False,
|
|
):
|
|
"""
|
|
:return: a tuple of 3:
|
|
- tuple of crow_indices, col_indices representation
|
|
of CSR format.
|
|
- block dense mask
|
|
- all token dense mask (be aware that it can be
|
|
OOM if it is too big) if `return_dense==True`,
|
|
otherwise, None
|
|
"""
|
|
with torch.no_grad():
|
|
num_blocks = triton.cdiv(max_seqlen, block_size)
|
|
q_pos = torch.arange(num_blocks)[:, None]
|
|
k_pos = torch.arange(num_blocks)[None]
|
|
mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
|
|
block_mask_dense = (((q_pos >= k_pos)
|
|
& ((q_pos - k_pos < local_blocks)
|
|
| mask_vert_strided)).to(device).to(dtype))
|
|
num_blocks_q = triton.cdiv(q_len, block_size)
|
|
block_mask_dense_output = (dense_to_crow_col(
|
|
block_mask_dense[-num_blocks_q:].contiguous()))
|
|
if return_dense:
|
|
mask_dense = torch.kron(
|
|
block_mask_dense,
|
|
block_mask_dense.new_ones((block_size, block_size)),
|
|
)
|
|
causal_mask = torch.tril(torch.ones(
|
|
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
|
mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
|
|
return (
|
|
block_mask_dense_output,
|
|
block_mask_dense,
|
|
mask_dense,
|
|
)
|
|
else:
|
|
return (
|
|
block_mask_dense_output,
|
|
block_mask_dense,
|
|
None,
|
|
)
|
|
|
|
|
|
def binary_mask_to_bias(mask_dense: torch.Tensor):
|
|
mask_dense = 1 - mask_dense
|
|
mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
|
|
return mask_dense
|
|
|
|
|
|
def get_head_sliding_step(n_heads: int,
|
|
vert_stride: int,
|
|
homo_head: bool = False):
|
|
if homo_head:
|
|
return 0
|
|
return max(1, int(vert_stride / n_heads))
|
|
|
|
|
|
@lru_cache
|
|
def get_sparse_attn_mask(
|
|
n_heads: int,
|
|
q_len: int,
|
|
max_seqlen: int,
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
block_size: int = 64,
|
|
local_blocks: int = 4,
|
|
vert_stride: int = 4,
|
|
homo_head: bool = True,
|
|
return_dense: bool = False,
|
|
dense_mask_type: str = "binary",
|
|
):
|
|
"""
|
|
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
|
or "bias" (-inf for skip token, 0 or others)
|
|
:return: a tuple of 3:
|
|
- tuple of crow_indices, col_indices representation
|
|
of CSR format.
|
|
- block dense mask
|
|
- all token dense mask (be aware that it can be OOM if it
|
|
is too big) if `return_dense==True`, otherwise, None
|
|
"""
|
|
assert dense_mask_type in ("binary", "bias")
|
|
if homo_head:
|
|
with torch.no_grad():
|
|
(crow, col), block_mask_dense, mask_dense = (
|
|
_get_sparse_attn_mask_homo_head(
|
|
q_len,
|
|
max_seqlen,
|
|
dtype,
|
|
device,
|
|
block_size,
|
|
local_blocks,
|
|
vert_stride,
|
|
return_dense,
|
|
))
|
|
crow = crow[None].expand(n_heads, crow.shape[0])
|
|
col = col[None].expand(n_heads, col.shape[0])
|
|
if return_dense:
|
|
mask_dense = mask_dense[None].expand(n_heads,
|
|
*mask_dense.shape)
|
|
if dense_mask_type == "bias":
|
|
mask_dense = binary_mask_to_bias(mask_dense)
|
|
return (crow, col), block_mask_dense, mask_dense
|
|
|
|
with torch.no_grad():
|
|
num_blocks = triton.cdiv(max_seqlen, block_size)
|
|
q_pos = torch.arange(num_blocks)[None, :, None]
|
|
k_pos = torch.arange(num_blocks)[None, None]
|
|
head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
|
|
mask_vert_strided = [
|
|
(torch.arange(num_blocks) + h * head_sliding_step + 1) %
|
|
vert_stride == 0 for h in range(n_heads)
|
|
]
|
|
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
|
|
block_mask_dense = (((q_pos >= k_pos)
|
|
& ((q_pos - k_pos < local_blocks)
|
|
| mask_vert_strided)).to(device).to(dtype))
|
|
num_blocks_q = triton.cdiv(q_len, block_size)
|
|
block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
|
|
if return_dense:
|
|
mask_dense = torch.kron(
|
|
block_mask_dense,
|
|
block_mask_dense.new_ones((block_size, block_size)),
|
|
)
|
|
causal_mask = torch.tril(torch.ones(
|
|
max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
|
|
mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
|
|
if dense_mask_type == "bias":
|
|
mask_dense = binary_mask_to_bias(mask_dense)
|
|
|
|
return (
|
|
dense_to_crow_col(block_mask_dense_output),
|
|
block_mask_dense,
|
|
mask_dense,
|
|
)
|
|
else:
|
|
return (
|
|
dense_to_crow_col(block_mask_dense_output),
|
|
block_mask_dense,
|
|
None,
|
|
)
|