add qwen3
This commit is contained in:
233
vllm-v0.6.2/tests/kernels/test_flash_attention.py
Normal file
233
vllm-v0.6.2/tests/kernels/test_flash_attention.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm_mlu.attention.ops.triton_flash_attention import triton_attention
|
||||
class SelfAttention(torch.nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
def __init__(self, causal=False, softmax_scale=None):
|
||||
super().__init__()
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
|
||||
# special alibi only support causal
|
||||
def build_alibi(self, slopes, block_size, n_heads, dtype):
|
||||
device ='mlu'
|
||||
tril = torch.tril(torch.ones(1,1 , block_size, block_size, device = device))
|
||||
bias_rows = torch.arange( block_size, device=device).view(1, -1)
|
||||
bias_cols = torch.arange( block_size, device=device).view(-1, 1)
|
||||
bias = - torch.sqrt(bias_cols - bias_rows)
|
||||
bias = bias.view(1, block_size, block_size) * slopes.view(-1, 1, 1)
|
||||
bias = bias.masked_fill(tril == 0, float('-inf'))
|
||||
return bias.type(dtype)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cur_seq_len_t:torch.Tensor, alibi_slope:torch.Tensor, attn_bias:torch.Tensor):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
q: The tensor containing the query. (B, T, H, D)
|
||||
k: The tensor containing the key. (B, T, H, D)
|
||||
v: The tensor containing the value. (B, T, H, D)
|
||||
cur_seq_len_t: true_seq_lens. (B+1)
|
||||
alibi_slope: (H) or (B, H)
|
||||
attn_bias: (B,H,T,T) or (B,T,T)
|
||||
"""
|
||||
batch = q.shape[0]
|
||||
seq_q = q.shape[1]
|
||||
seq_k = k.shape[1]
|
||||
head = q.shape[2]
|
||||
scores = torch.einsum('bthd,bshd->bhts', q, k )* self.softmax_scale
|
||||
# mask
|
||||
if alibi_slope is not None:
|
||||
slope = torch.zeros((batch, head)).mlu()
|
||||
if len(alibi_slope.shape) == 1 :
|
||||
slope[:,]=alibi_slope
|
||||
else:
|
||||
slope=alibi_slope
|
||||
slope = slope.reshape(batch, head, 1, 1)
|
||||
slope_bias = torch.zeros(batch, head, seq_q, seq_k).mlu()
|
||||
if self.causal:
|
||||
relative_pos = torch.arange(-seq_k + 1, 1, dtype=torch.float32).mlu()
|
||||
slope_bias = relative_pos * slope
|
||||
else:
|
||||
row_idx = torch.arange(seq_q, dtype=torch.long).reshape(-1, 1)
|
||||
col_idx = torch.arange(seq_k, dtype=torch.long)
|
||||
relative_pos = torch.abs(row_idx + seq_k - seq_q - col_idx).mlu()
|
||||
slope_bias = -slope * relative_pos.to(dtype=slope.dtype)
|
||||
# if use special alibi
|
||||
# slope_bias = self.build_alibi(alibi_slope, seq_k, head, dtype=torch.float32)
|
||||
|
||||
scores += slope_bias
|
||||
if attn_bias is not None:
|
||||
if len(attn_bias.shape) == 3:
|
||||
scores += attn_bias.unsqueeze(1)
|
||||
else:
|
||||
scores +=attn_bias
|
||||
if self.causal:
|
||||
causal_mask = torch.triu(torch.full((seq_q, seq_k), -10000.0, device=scores.device), 1)
|
||||
scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
else: # fill -inf in pad_area
|
||||
for b in range(batch):
|
||||
true_seq_len = cur_seq_len_t[b + 1] - cur_seq_len_t[b]
|
||||
scores[b, ..., true_seq_len:] = -10000.0
|
||||
scores[b, :, true_seq_len:, :] = -10000.0
|
||||
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
||||
output = torch.einsum('bhts,bshd->bthd', attention, v)
|
||||
return output.contiguous()
|
||||
|
||||
|
||||
NUM_HEADS = [64, 256]
|
||||
NUM_QUERIES_PER_KV = [1]
|
||||
HEAD_SIZES = [96]
|
||||
DTYPES = [torch.float16, torch.float32]
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(num_heads: int, num_queries_per_kv: int,
|
||||
head_size: int, dtype: torch.dtype) -> None:
|
||||
"""
|
||||
split test case head_size 96 cause multi tests in one pytest will conflict memory.
|
||||
"""
|
||||
device="cuda"
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
cache_size = 640
|
||||
block_size = 32
|
||||
max_block_per_request = 64
|
||||
random.seed(1)
|
||||
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
|
||||
ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads
|
||||
|
||||
num_tokens = sum(query_lens)
|
||||
max_seqlens_q = max(query_lens)
|
||||
max_seqlens_k = max(query_lens)
|
||||
cu_seqlens = [0]
|
||||
for value in query_lens:
|
||||
cu_seqlens.append(cu_seqlens[-1] + value)
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens, dtype=torch.int, device=device)
|
||||
cu_seqlens_k = torch.tensor(cu_seqlens, dtype=torch.int, device=device)
|
||||
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype,device=device)
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
triton_output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype,device=device)
|
||||
|
||||
kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype,device=device)
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size, num_kv_heads, head_size, dtype=dtype, device=device)
|
||||
v_cache = torch.zeros(cache_size, block_size, num_kv_heads, head_size,
|
||||
dtype=dtype, device=device)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype,device=device)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype,device=device)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long,device=device)
|
||||
values = values[torch.randperm(cache_size,device=device)]
|
||||
block_table = values[:BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long,device=device)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long,device=device)
|
||||
b_start_loc = torch.cumsum(torch.tensor(
|
||||
[0] + query_lens[:-1], dtype=torch.long,device=device), dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long,device=device), dim=0)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
start_loc = b_seq_start_loc[i] + cur_ctx
|
||||
if cur_ctx + block_size > b_ctx_len[i]:
|
||||
end_loc = b_seq_start_loc[i] + b_ctx_len[i]
|
||||
else:
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(key[start_loc:end_loc])
|
||||
v_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(value[start_loc:end_loc])
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
|
||||
8).permute(0, 2, 3, 1, 4).contiguous()
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
|
||||
triton_output,_ = triton_attention(query,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k)
|
||||
triton_output_cpu = triton_output.to(device='cpu')
|
||||
def copy_pack_data_to_pad_data(pad_input: torch.Tensor,
|
||||
packed_input_list: list,
|
||||
t_len_sequence: list,
|
||||
max_sequence_len: int):
|
||||
end_index1 = 0
|
||||
for index in range(len(t_len_sequence)):
|
||||
start_index1 = end_index1
|
||||
end_index1 = end_index1 + t_len_sequence[index]
|
||||
start_index = index * max_sequence_len
|
||||
end_index = start_index + t_len_sequence[index]
|
||||
pad_input[start_index:end_index, ...] = packed_input_list[start_index1:end_index1, ...]
|
||||
|
||||
pad_input_q = torch.zeros((MAX_SEQ_LEN * BS, num_heads, head_size)).mlu().half()
|
||||
pad_input_k = torch.zeros((MAX_SEQ_LEN * BS, num_kv_heads, head_size)).mlu().half()
|
||||
pad_input_v = torch.zeros((MAX_SEQ_LEN * BS, num_kv_heads, head_size)).mlu().half()
|
||||
copy_pack_data_to_pad_data(pad_input_q, query, query_lens, MAX_SEQ_LEN)
|
||||
copy_pack_data_to_pad_data(pad_input_k, k, query_lens, MAX_SEQ_LEN)
|
||||
copy_pack_data_to_pad_data(pad_input_v, v, query_lens, MAX_SEQ_LEN)
|
||||
softmax_scale = 1 / math.sqrt(head_size)
|
||||
attention = SelfAttention(causal = False, softmax_scale=softmax_scale)
|
||||
torch_output = attention(pad_input_q.view(BS, MAX_SEQ_LEN, num_heads, head_size),
|
||||
pad_input_k.view(BS, MAX_SEQ_LEN, num_kv_heads, head_size),
|
||||
pad_input_v.view(BS, MAX_SEQ_LEN, num_kv_heads, head_size),
|
||||
cu_seqlens_q,None,None)
|
||||
pad_triton_output = torch_output.clone().view(BS * MAX_SEQ_LEN, num_heads, head_size)
|
||||
copy_pack_data_to_pad_data(pad_triton_output, triton_output_cpu, query_lens, MAX_SEQ_LEN)
|
||||
view_triton_output = pad_triton_output.view(BS, MAX_SEQ_LEN, num_heads, head_size)
|
||||
torch.testing.assert_close(view_triton_output, torch_output)
|
||||
|
||||
|
||||
HEAD_SIZES = [24, 128]
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_1(num_heads: int, num_queries_per_kv: int, head_size: int,
|
||||
dtype: torch.dtype) -> None:
|
||||
"""
|
||||
split test case ihead_size 24, 128 cause multi tests in one pytest will conflict memory.
|
||||
"""
|
||||
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, dtype)
|
||||
Reference in New Issue
Block a user