Files
enginex-mlu370-vllm/vllm-v0.6.2/tests/kernels/test_flash_attention.py
2026-02-04 17:22:39 +08:00

234 lines
11 KiB
Python

import numpy as np
import math
import random
import time
import pytest
import torch
from vllm_mlu.attention.ops.triton_flash_attention import triton_attention
class SelfAttention(torch.nn.Module):
    """Reference (pure PyTorch) scaled dot-product attention with softmax.

    Arguments
    ---------
    causal: If True, add a causal (upper-triangular -10000) mask to the
        scores; otherwise mask padded key/query positions using the
        ``cur_seq_len_t`` offsets passed to ``forward``.
    softmax_scale: The temperature to use for the softmax attention.
        (default: 1/sqrt(D) where D is the head size of ``q``, computed at
        runtime)
    """

    def __init__(self, causal=False, softmax_scale=None):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale

    # special alibi only support causal
    def build_alibi(self, slopes, block_size, n_heads, dtype, device='mlu'):
        """Build a causal ALiBi-style bias of ``-slope * sqrt(row - col)``.

        Entries above the diagonal (col > row) are filled with ``-inf``.
        ``n_heads`` is currently unused; the head count comes from
        ``slopes``. ``device`` defaults to ``'mlu'`` as before.
        """
        tril = torch.tril(torch.ones(1, 1, block_size, block_size, device=device))
        bias_rows = torch.arange(block_size, device=device).view(1, -1)
        bias_cols = torch.arange(block_size, device=device).view(-1, 1)
        # Negative differences (above the diagonal) give NaN here; they are
        # overwritten with -inf by the masked_fill below.
        bias = -torch.sqrt(bias_cols - bias_rows)
        bias = bias.view(1, block_size, block_size) * slopes.view(-1, 1, 1)
        bias = bias.masked_fill(tril == 0, float('-inf'))
        return bias.type(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                cur_seq_len_t: torch.Tensor, alibi_slope: torch.Tensor,
                attn_bias: torch.Tensor):
        """Implements the multihead softmax attention.

        Arguments
        ---------
        q: The tensor containing the query. (B, T, H, D)
        k: The tensor containing the key. (B, S, H, D)
        v: The tensor containing the value. (B, S, H, D)
        cur_seq_len_t: cumulative true sequence lengths, (B+1). Only used
            in the non-causal branch to mask padded positions.
        alibi_slope: (H) or (B, H), or None to disable ALiBi.
        attn_bias: (B, H, T, S) or (B, T, S), or None.

        Returns the attention output as a contiguous (B, T, H, D) tensor in
        ``v``'s dtype.
        """
        batch, seq_q, head, head_dim = q.shape
        seq_k = k.shape[1]
        device = q.device
        softmax_scale = self.softmax_scale
        if softmax_scale is None:
            # Documented default: 1/sqrt(head size), resolved at runtime.
            softmax_scale = 1.0 / math.sqrt(head_dim)
        scores = torch.einsum('bthd,bshd->bhts', q, k) * softmax_scale

        if alibi_slope is not None:
            if alibi_slope.dim() == 1:
                # Per-head slopes shared by every batch entry.
                slope = torch.zeros((batch, head), device=device)
                slope[:] = alibi_slope
            else:
                slope = alibi_slope
            slope = slope.reshape(batch, head, 1, 1)
            if self.causal:
                # Bias grows linearly with distance to the last key.
                relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32,
                                            device=device)
                slope_bias = relative_pos * slope
            else:
                # Distance between query i (right-aligned to the keys) and key j.
                row_idx = torch.arange(seq_q, dtype=torch.long, device=device).reshape(-1, 1)
                col_idx = torch.arange(seq_k, dtype=torch.long, device=device)
                relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx)
                slope_bias = -slope * relative_pos.to(dtype=slope.dtype)
            # if use special alibi
            # slope_bias = self.build_alibi(alibi_slope, seq_k, head, dtype=torch.float32)
            scores += slope_bias

        if attn_bias is not None:
            if attn_bias.dim() == 3:
                # (B, T, S) bias is shared across heads.
                attn_bias = attn_bias.unsqueeze(1)
            scores += attn_bias

        if self.causal:
            causal_mask = torch.triu(
                torch.full((seq_q, seq_k), -10000.0, device=scores.device), 1)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        else:
            # Mask padded keys and padded query rows of each sequence.
            for b in range(batch):
                true_seq_len = cur_seq_len_t[b + 1] - cur_seq_len_t[b]
                scores[b, ..., true_seq_len:] = -10000.0
                scores[b, :, true_seq_len:, :] = -10000.0

        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        output = torch.einsum('bhts,bshd->bthd', attention, v)
        return output.contiguous()
# Parametrization grids for the flash-attention tests in this module.
DTYPES = [torch.float16, torch.float32]
HEAD_SIZES = [96]  # head size grid for the first test function
NUM_QUERIES_PER_KV = [1]  # 1 == no grouped-query sharing
NUM_HEADS = [64, 256]
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention(num_heads: int, num_queries_per_kv: int,
                                head_size: int, dtype: torch.dtype) -> None:
    """Compare varlen ``triton_attention`` against ``SelfAttention``.

    Random packed queries/keys/values for ``BS`` sequences are fed to the
    Triton kernel. The reference is computed per sequence at the requested
    ``dtype``, which avoids a padded (BS, H, MAX_SEQ_LEN, MAX_SEQ_LEN)
    score tensor.

    Only the non-context (query) part of each sequence is attended to.
    ``triton_attention`` is called with packed q/k/v only, so no paged KV
    cache is built. ``num_queries_per_kv`` is currently unused
    (num_kv_heads == num_heads).

    Head sizes are split across test functions (96 here) because running
    them all in one pytest function reportedly caused memory conflicts.
    """
    device = "cuda"
    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    random.seed(1)
    # Seed torch as well so the tensor data (not just the lengths) is reproducible.
    torch.manual_seed(1)
    query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
    num_kv_heads = num_heads
    num_tokens = sum(query_lens)
    max_seqlens_q = max(query_lens)
    # Keys/values passed to the kernel are only the query part of each
    # sequence, so the key lengths equal the query lengths.
    max_seqlens_k = max(query_lens)

    cu_seqlens = [0]
    for q_len in query_lens:
        cu_seqlens.append(cu_seqlens[-1] + q_len)
    cu_seqlens_q = torch.tensor(cu_seqlens, dtype=torch.int, device=device)
    cu_seqlens_k = torch.tensor(cu_seqlens, dtype=torch.int, device=device)

    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    query.uniform_(-1e-3, 1e-3)
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype, device=device)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    # Pack the tokens after each sequence's context into contiguous k/v.
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype, device=device)
    v = torch.zeros_like(k)
    seq_start = 0
    for tok_start, q_len, ctx_len, seq_len in zip(cu_seqlens, query_lens,
                                                  ctx_lens, seq_lens):
        src = seq_start + ctx_len
        k[tok_start:tok_start + q_len].copy_(key[src:src + q_len])
        v[tok_start:tok_start + q_len].copy_(value[src:src + q_len])
        seq_start += seq_len

    triton_output, _ = triton_attention(query,
                                        k,
                                        v,
                                        None,
                                        cu_seqlens_q,
                                        cu_seqlens_k,
                                        max_seqlens_q,
                                        max_seqlens_k)

    softmax_scale = 1 / math.sqrt(head_size)
    attention = SelfAttention(causal=False, softmax_scale=softmax_scale)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        q_len = end - start
        # Single-sequence batch: no padding, so the reference runs at `dtype`.
        ref = attention(query[start:end].unsqueeze(0),
                        k[start:end].unsqueeze(0),
                        v[start:end].unsqueeze(0),
                        torch.tensor([0, q_len]), None, None)
        torch.testing.assert_close(triton_output[start:end], ref.squeeze(0))
# Rebinding HEAD_SIZES here does not affect test_contexted_kv_attention:
# its parametrize decorators were already evaluated when it was defined.
HEAD_SIZES = [24, 128]


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode()
def test_contexted_kv_attention_1(num_heads: int, num_queries_per_kv: int, head_size: int,
                                  dtype: torch.dtype) -> None:
    """Run ``test_contexted_kv_attention`` for head sizes 24 and 128.

    Kept as a separate pytest function from the head_size=96 case because
    running all head sizes in one function reportedly caused memory
    conflicts.
    """
    test_contexted_kv_attention(num_heads=num_heads,
                                num_queries_per_kv=num_queries_per_kv,
                                head_size=head_size,
                                dtype=dtype)