forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
This commit is contained in:
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
0
vllm/v1/attention/backends/mla/__init__.py
Normal file
1114
vllm/v1/attention/backends/mla/common.py
Normal file
1114
vllm/v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
250
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
Normal file
250
vllm/v1/attention/backends/mla/concatv3Tritonfinalv2.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from functools import reduce
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
|
||||
# All (decode) shapes: last dim 512 ++ 64 across 8/16/32 heads, plus the
# prefill-style 128 ++ 64 shapes for 32 heads. Identical to the original
# hand-written list, just generated.
_CONCAT_TEST_CASES = [
    (((m, h, 512), (m, h, 64)), 2)
    for h in (8, 16, 32)
    for m in (4, 8, 16, 32, 64, 128, 256, 512, 672, 768, 896, 1024)
] + [
    (((m, 32, 128), (m, 32, 64)), 2)
    for m in (4, 8, 16, 32, 64, 128, 256, 512, 672, 768, 896, 1024)
]


@pytest.mark.parametrize("shape_pair,dim", _CONCAT_TEST_CASES)
def test_concat_Acc(shape_pair, dim):
    """Check concat_helper against torch.cat on contiguous bf16 inputs."""
    torch.manual_seed(1)
    first_shape, second_shape = shape_pair
    lhs = torch.randn(*first_shape, device='cuda', dtype=torch.bfloat16)
    rhs = torch.randn(*second_shape, device='cuda', dtype=torch.bfloat16)
    reference = torch.cat([lhs, rhs], dim=dim)
    actual = concat_helper(lhs, rhs, dim=dim)
    assert torch.allclose(actual, reference, rtol=1e-5, atol=1e-5), "Mismatch"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.jit
def concat_kernel_prefill(
    A_ptr, B_ptr, C_ptr,
    A_section_numel, B_section_numel, C_section_numel,
    Per_block,
    section_num,
    BLOCK_SIZE: tl.constexpr
):
    # Prefill-specialized concat kernel. Each program copies `Per_block`
    # sections (one section = one row along the concat dim), handling TWO
    # adjacent sections per loop iteration. The hard-coded 256/128/64
    # extents assume A_section_numel == 128 and B_section_numel == 64,
    # which is exactly the shape gate applied by the caller (concat_helper).
    # NOTE(review): the loads/stores below always touch two sections; if
    # section_num is odd the final iteration reads/writes one section past
    # the end — callers appear to use even section counts, verify.
    block_idx = tl.program_id(0)  # index of the current program (CUDA block)
    for sub_section_index in range(Per_block//2):
        # This iteration covers sections [offset, offset + 1].
        sub_section_offset = block_idx * Per_block + sub_section_index * 2
        if sub_section_offset <= section_num-1:

            C_section_start = C_ptr + sub_section_offset * C_section_numel
            A_section_start = A_ptr + sub_section_offset * A_section_numel
            B_section_start = B_ptr + sub_section_offset * B_section_numel

            # Load 256 contiguous A elements = two 128-wide A sections.
            Arrange_doubleA = tl.arange(0, 256)
            mask = Arrange_doubleA < (256)
            # Arrange2[i] = i for i < 128, and i - 127 for i >= 128.
            # Combined with tensor_offsets below (0 for the first section,
            # C_section_numel - 1 for the second), element i >= 128 lands at
            # (i - 128) + C_section_numel — the start of the NEXT C section.
            Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
            val_from_A = tl.load(A_section_start + Arrange_doubleA)
            tensorAsn = tl.full((256,), 0, tl.int32)
            tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
            tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
            off = Arrange2 + tensor_offsets
            tl.store(C_section_start + off,val_from_A,mask=mask)

            # Same trick for B: 128 elements = two 64-wide B sections, each
            # stored right after the A data of its C section.
            Arrange_doubleB = tl.arange(0, 128)
            mask = Arrange_doubleB < (B_section_numel*2)
            val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)

            # Arrange3[i] = i for i < 64, and i - 63 for i >= 64.
            Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
            tensorAsn = tl.full((128,), A_section_numel, tl.int32)
            tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
            tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
            tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
|
||||
|
||||
@triton.jit
def concat_kernel(
    A_ptr, B_ptr, C_ptr,
    A_section_numel, B_section_numel, C_section_numel,
    Per_block,
    section_num,
    BLOCK_SIZE: tl.constexpr
):
    # Generic concat kernel for contiguous A and B. Each program copies
    # `Per_block` sections; every C section is A's section followed
    # immediately by B's section (i.e. torch.cat along the last dim).
    block_idx = tl.program_id(0)
    for sub_section_index in range(Per_block):
        sub_offset = block_idx * Per_block + sub_section_index
        if sub_offset <= section_num-1:  # guard the tail program
            C_ptr_block_start = C_ptr + sub_offset * C_section_numel
            A_ptr_block_start = A_ptr + sub_offset * A_section_numel
            B_ptr_block_start = B_ptr + sub_offset * B_section_numel

            # Copy A's section into the front of the C section.
            for offset in range(0, A_section_numel, BLOCK_SIZE):
                offset_idx = offset + tl.arange(0, BLOCK_SIZE)
                mask = offset_idx < A_section_numel
                val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
                tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)

            # Copy B's section right after A's data.
            for offset in range(0, B_section_numel, BLOCK_SIZE):
                offset_idx = offset + tl.arange(0, BLOCK_SIZE)
                mask = offset_idx < B_section_numel
                val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
                tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
|
||||
|
||||
def concat_helper(A: torch.Tensor, B: torch.Tensor, dim: int):
    """Concatenate ``A`` and ``B`` along ``dim`` with custom Triton kernels.

    Drop-in replacement for ``torch.cat([A, B], dim=dim)`` for the 3-D
    shapes dispatched below. Inputs are made contiguous first; the result
    is a fresh contiguous tensor.

    Args:
        A: first operand (made contiguous here).
        B: second operand (made contiguous here).
        dim: concat dimension; only ``dim != 0`` is supported.

    Returns:
        A new tensor equal to ``torch.cat([A, B], dim=dim)``.

    Raises:
        NotImplementedError: when ``dim == 0`` (no kernel for that case).
    """
    A = A.contiguous()
    B = B.contiguous()
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)

    if dim != 0:
        # One "section" per index of the leading dims; each C section is
        # A's slice followed by B's slice.
        block_num = reduce(lambda x, y: x * y, output_shape[:dim])
        Per_block = 1
        unit_offset_A, unit_offset_B, unit_offset_C = \
            A.stride(dim - 1), B.stride(dim - 1), C.stride(dim - 1)

        # Prefill case: the specialized kernel hard-codes 128/64 last dims.
        if A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16:
            Per_block = 8
            num_blocks = math.ceil(block_num / Per_block)
            concat_kernel_prefill[(num_blocks,)](
                A, B, C,
                unit_offset_A, unit_offset_B, unit_offset_C,
                Per_block,
                block_num,
                BLOCK_SIZE=1024)
            return C
        else:
            # Decode/general case: shape-tuned sections-per-program.
            if (A.shape[1] == 8 and A.shape[0] > 128) or \
                    (A.shape[1] == 16 and A.shape[0] > 96) or \
                    (A.shape[1] == 32 and A.shape[2] == 512 and A.shape[0] > 64):
                Per_block = 2
            num_blocks = math.ceil(block_num / Per_block)
            concat_kernel[(num_blocks,)](
                A, B, C,
                unit_offset_A, unit_offset_B, unit_offset_C,
                Per_block,
                block_num,
                BLOCK_SIZE=1024)
            return C

    # BUG FIX: this was `assert False, "not support"`, which is stripped
    # under `python -O` and would then silently return None.
    raise NotImplementedError("concat_helper does not support dim=0")
|
||||
|
||||
|
||||
# Benchmark sweep over the leading (batch) dimension for dim=2 concat.
configs = [
    triton.testing.Benchmark(
        x_names=['size'],
        x_vals=[4, 8, 16, 32, 64, 96, 128, 256, 512, 768, 1024],
        x_log=True,
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='s',
        plot_name='concat-dim2',
        args={"dim": 2},
    )
]
|
||||
|
||||
@triton.testing.perf_report(configs)
def benchmark(size, provider, dim):
    """Time torch.cat vs concat_helper on (size, 8, 512) ++ (size, 8, 64)."""
    lhs = torch.rand([size, 8, 512], device='cuda', dtype=torch.bfloat16)
    rhs = torch.rand([size, 8, 64], device='cuda', dtype=torch.bfloat16)
    # Quantiles requested from do_bench: (median, 20th, 80th percentile).
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([lhs, rhs], dim=dim), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_helper(lhs, rhs, dim=dim), quantiles=quantiles)
    return (ms * 1000), (max_ms * 1000), (min_ms * 1000)
|
||||
|
||||
@triton.testing.perf_report(configs)
def benchmark_16(size, provider, dim):
    """Time torch.cat vs concat_helper on (size, 16, 512) ++ (size, 16, 64)."""
    lhs = torch.rand([size, 16, 512], device='cuda', dtype=torch.bfloat16)
    rhs = torch.rand([size, 16, 64], device='cuda', dtype=torch.bfloat16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([lhs, rhs], dim=dim), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_helper(lhs, rhs, dim=dim), quantiles=quantiles)
    return (ms * 1000), (max_ms * 1000), (min_ms * 1000)
|
||||
|
||||
@triton.testing.perf_report(configs)
def benchmark_32(size, provider, dim):
    """Time torch.cat vs concat_helper on (size, 32, 512) ++ (size, 32, 64)."""
    lhs = torch.rand([size, 32, 512], device='cuda', dtype=torch.bfloat16)
    rhs = torch.rand([size, 32, 64], device='cuda', dtype=torch.bfloat16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([lhs, rhs], dim=dim), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_helper(lhs, rhs, dim=dim), quantiles=quantiles)
    return (ms * 1000), (max_ms * 1000), (min_ms * 1000)
|
||||
|
||||
@triton.testing.perf_report(configs)
def benchmark_prefill(size, provider, dim):
    """Time torch.cat vs concat_helper on prefill shapes (size, 32, 128) ++ (size, 32, 64)."""
    lhs = torch.rand([size, 32, 128], device='cuda', dtype=torch.bfloat16)
    rhs = torch.rand([size, 32, 64], device='cuda', dtype=torch.bfloat16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([lhs, rhs], dim=dim), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_helper(lhs, rhs, dim=dim), quantiles=quantiles)
    return (ms * 1000), (max_ms * 1000), (min_ms * 1000)
|
||||
|
||||
if __name__ == '__main__':
    # Run the prefill-shape sweep; the other sweeps are kept for reference
    # and can be re-enabled as needed.
    # benchmark.run(save_path="./triton_test_8",print_data=True)
    # benchmark_16.run(save_path="./triton_test_16",print_data=True)
    # benchmark_32.run(save_path="./triton_test_32",print_data=True)
    benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
|
||||
|
||||
|
||||
248
vllm/v1/attention/backends/mla/concatv4_decode_only.py
Normal file
248
vllm/v1/attention/backends/mla/concatv4_decode_only.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
from functools import reduce
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
|
||||
@pytest.mark.parametrize("shape_pair,dim", [
    # Decode shapes: last dim 512 ++ 64 across 8/16/32 heads. Identical set
    # to the original hand-written list, just generated.
    (((m, h, 512), (m, h, 64)), 2)
    for h in (8, 16, 32)
    for m in (4, 8, 16, 32, 64, 128, 256, 512, 672, 768, 896, 1024)
])
def test_concat_Acc(shape_pair, dim):
    """Accuracy check against torch.cat using non-contiguous strided inputs."""
    torch.manual_seed(1)
    shape_a, shape_b = shape_pair
    batch, heads = shape_a[0], shape_a[1]

    # x is a non-contiguous view: dim 1 carries the largest stride.
    x_storage = torch.arange(batch * heads * 512, device='cuda').bfloat16()
    x = torch.as_strided(x_storage, size=[batch, heads, 512],
                         stride=[512, 512 * batch, 1])
    # shape: [batch, heads, 512]; stride: (512, 512*batch, 1)

    # y mimics a padded row layout: 64 columns inside a 192-element pitch.
    y_storage = torch.arange(1536 * (heads // 8) * batch,
                             device='cuda').bfloat16()
    y = torch.as_strided(y_storage, size=[batch, heads, 64],
                         stride=[1536 * (heads // 8), 192, 1])
    # shape: [batch, heads, 64]; stride: (1536*(heads//8), 192, 1)

    expected = torch.cat([x, y], dim=dim)
    result = concat_helper(x, y, dim=dim)
    assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@triton.jit
def concat_kernel(
    A_ptr, B_ptr, C_ptr,
    A_section_numel, B_section_numel, C_section_numel,
    Per_block,
    section_num,
    M,
    N,
    Astride_0,
    Astride_1,
    Astride_2,
    Bstride_0,
    Bstride_1,
    Bstride_2,
    BLOCK_SIZE: tl.constexpr
):
    # Strided concat kernel: sections are indexed over the two leading dims
    # (section s -> indices s // N, s % N), and the A/B section base
    # addresses are computed from explicit dim-0/dim-1 strides so
    # non-contiguous inputs work. C is written contiguously.
    # NOTE(review): each row is still read linearly via offset_idx, which
    # assumes the innermost stride of A and B is 1 — confirm with callers.
    # NOTE(review): Astride_2 and Bstride_2 are passed but never used.
    block_idx = tl.program_id(0)
    for sub_section_index in range(Per_block):
        sub_offset = block_idx * Per_block + sub_section_index
        M_idx = sub_offset // N
        N_idx = sub_offset % N
        if sub_offset <= section_num-1:  # guard the tail program
            C_ptr_block_start = C_ptr + sub_offset * C_section_numel
            A_ptr_block_start = A_ptr + M_idx * Astride_0 + N_idx * Astride_1
            B_ptr_block_start = B_ptr + M_idx * Bstride_0 + N_idx * Bstride_1

            # Copy A's row into the front of the C section.
            for offset in range(0, A_section_numel, BLOCK_SIZE):
                offset_idx = offset + tl.arange(0, BLOCK_SIZE)
                mask = offset_idx < A_section_numel
                val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
                tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)

            # Copy B's row right after A's data.
            for offset in range(0, B_section_numel, BLOCK_SIZE):
                offset_idx = offset + tl.arange(0, BLOCK_SIZE)
                mask = offset_idx < B_section_numel
                val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
                tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
|
||||
|
||||
def concat_helper(A: torch.Tensor, B: torch.Tensor, dim: int):
    """Concatenate ``A`` and ``B`` along ``dim`` (decode path, strided inputs).

    Unlike the contiguous variant, this does not copy the inputs: the
    kernel addresses A and B through their dim-0/dim-1 strides.
    NOTE(review): the kernel still reads each innermost row linearly, so it
    assumes ``A.stride(-1) == B.stride(-1) == 1`` — verify with callers.
    NOTE(review): the M/N section decomposition passed to the kernel
    assumes 3-D inputs with ``dim == 2``.

    Args:
        A: first operand (may be non-contiguous in the leading dims).
        B: second operand (same layout assumptions as ``A``).
        dim: concat dimension; only ``dim != 0`` is supported.

    Returns:
        A new contiguous tensor equal to ``torch.cat([A, B], dim=dim)``.

    Raises:
        NotImplementedError: when ``dim == 0``.
    """
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)

    if dim != 0:
        block_num = reduce(lambda x, y: x * y, output_shape[:dim])
        Per_block = 1
        # Section lengths in elements; C is freshly allocated contiguous, so
        # shapes (not strides) give its section length.
        unit_offset_A, unit_offset_B, unit_offset_C = \
            A.shape[dim], B.shape[dim], C.shape[dim]
        # Shape-tuned sections-per-program heuristics.
        if (A.shape[1] == 8 and A.shape[0] > 512) or \
                (A.shape[1] == 16 and A.shape[0] > 256):
            Per_block = 2
        if A.shape[1] == 32 and A.shape[2] == 512 and A.shape[0] > 256:
            Per_block = 8
        num_blocks = math.ceil(block_num / Per_block)
        concat_kernel[(num_blocks,)](
            A, B, C,
            unit_offset_A, unit_offset_B, unit_offset_C,
            Per_block,
            block_num,
            output_shape[0],
            output_shape[1],
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            BLOCK_SIZE=1024)
        return C

    # BUG FIX: this was `assert False, "not support"`, which is stripped
    # under `python -O` and would then silently return None.
    raise NotImplementedError("concat_helper does not support dim=0")
|
||||
|
||||
|
||||
# Benchmark sweep over (batch M, heads N) pairs for dim=2 concat.
configs = [
    triton.testing.Benchmark(
        x_names=['M', 'N'],
        x_vals=[(m, h)
                for h in (8, 16, 32)
                for m in (4, 8, 16, 32, 64, 96, 128, 256, 512, 768, 1024)],
        x_log=True,
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'Torch'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='s',
        plot_name='concat-dim2',
        args={"dim": 2},
    )
]
|
||||
|
||||
@triton.testing.perf_report(configs)
def benchmark(M, N, provider, dim):
    """Time torch.cat vs concat_helper on strided (M, N, 512) ++ (M, N, 64)."""
    # x is a non-contiguous view: dim 1 carries the largest stride.
    x_storage = torch.arange(M * N * 512, device='cuda').bfloat16()
    x = torch.as_strided(x_storage, size=[M, N, 512],
                         stride=[512, 512 * M, 1])
    # shape: [M, N, 512]; stride: (512, 512*M, 1)

    # y mimics a padded row layout: 64 columns inside a 192-element pitch.
    y_storage = torch.arange(1536 * (N // 8) * M, device='cuda').bfloat16()
    y = torch.as_strided(y_storage, size=[M, N, 64],
                         stride=[1536 * (N // 8), 192, 1])
    # shape: [M, N, 64]; stride: (1536*(N//8), 192, 1)

    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.cat([x, y], dim=dim), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: concat_helper(x, y, dim=dim), quantiles=quantiles)
    return (ms * 1000), (max_ms * 1000), (min_ms * 1000)
|
||||
|
||||
# @triton.testing.perf_report(configs)
|
||||
# def benchmark_16(size, provider, dim):
|
||||
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
|
||||
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
|
||||
# quantiles = [0.5, 0.2, 0.8]
|
||||
# if provider == 'torch':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
# if provider == 'triton':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
# return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
# @triton.testing.perf_report(configs)
|
||||
# def benchmark_32(size, provider, dim):
|
||||
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
|
||||
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
|
||||
# quantiles = [0.5, 0.2, 0.8]
|
||||
# if provider == 'torch':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
# if provider == 'triton':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
# return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
# @triton.testing.perf_report(configs)
|
||||
# def benchmark_prefill(size, provider, dim):
|
||||
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
|
||||
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
|
||||
# quantiles = [0.5, 0.2, 0.8]
|
||||
# if provider == 'torch':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
|
||||
# if provider == 'triton':
|
||||
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
|
||||
# return (ms*1000), (max_ms*1000), (min_ms*1000)
|
||||
|
||||
if __name__ == '__main__':
    # Run the strided-input sweep; the other variants are kept for reference.
    benchmark.run(save_path="./triton_test",print_data=True)
    # benchmark_16.run(save_path="./triton_test_16",print_data=True)
    # benchmark_32.run(save_path="./triton_test_32",print_data=True)
    # benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
|
||||
|
||||
|
||||
98
vllm/v1/attention/backends/mla/cutlass_mla.py
Normal file
98
vllm/v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
    """MLA attention backend whose decode path uses the CUTLASS MLA kernel."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "CUTLASS_MLA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return CutlassMLAImpl
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation that decodes with ``ops.cutlass_mla_decode``.

    Everything except ``_forward_decode`` is inherited from MLACommonImpl.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # Fail fast on configurations the CUTLASS kernel cannot honor
        # instead of silently producing wrong attention output.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CutlassMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CutlassMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Run CUTLASS MLA decode and project the latent output to values.

        Args:
            q_nope: non-rotary query part; assumed (batch, num_heads, ...)
                given the shape[0]-based batch below — TODO confirm.
            q_pe: rotary (positional) query part.
            kv_c_and_k_pe_cache: paged latent KV cache.
            attn_metadata: metadata with a populated ``decode`` section.

        Returns:
            The value-up-projected attention output.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Cutlass MLA not yet supported")

        B = q_nope.shape[0]

        # Output buffer in the latent (kv_lora_rank) space.
        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
                        dtype=q_nope.dtype,
                        device=q_nope.device)

        # Run MLA
        # Clone q_nope and q_pe to make sure strides computation is correct.
        q_nope = q_nope.clone()
        q_pe = q_pe.clone()
        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                               attn_metadata.decode.seq_lens,
                               attn_metadata.decode.block_table, self.scale)

        return self._v_up_proj(o)
|
||||
195
vllm/v1/attention/backends/mla/flashmla.py
Normal file
195
vllm/v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm import envs
|
||||
from vllm.v1.attention.backends.mla.concatv4_decode_only import concat_helper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
    """MLA attention backend whose decode path uses the FlashMLA kernel."""

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["FlashMLAMetadata"]:
        """Return the metadata class carrying FlashMLA scheduling info."""
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the builder that produces FlashMLAMetadata."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashMLAImpl
|
||||
|
||||
|
||||
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    # Per-SM tile scheduling info produced by get_mla_metadata().
    tile_scheduler_metadata: torch.Tensor
    # Per-request split counts produced by get_mla_metadata().
    num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialized with FlashMLA decode metadata."""
    pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds FlashMLAMetadata, including FlashMLA's scheduling tensors."""

    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

        # Static buffers reused across full-CUDA-graph replays; allocated
        # lazily on the first _build_decode call (graph capture).
        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        """Compute FlashMLA scheduling metadata for a decode batch.

        When full CUDA graphs are enabled, the fresh results are copied
        into persistent buffers so a captured graph always reads the same
        storage across replays.
        """
        tile_scheduler_metadata, num_splits = \
            get_mla_metadata(
            seq_lens,
            self.num_q_heads,
            1, # MQA for the decode path
        )

        if self.runner.full_cuda_graph:
            # First time around (CUDAGraph capture), allocate the static buffer
            if self.cg_buf_tile_scheduler_metadata is None:
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
                self.cg_buf_num_splits = num_splits
            else:
                assert self.cg_buf_num_splits is not None

                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                        tile_scheduler_metadata.size())
                self.cg_buf_tile_scheduler_metadata.\
                    copy_(tile_scheduler_metadata)
                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

                # Num splits is per-batch, varying size (batch_size,)
                n = num_splits.size(0)
                # make sure static buffer is large enough
                assert n <= self.cg_buf_num_splits.size(0)
                num_splits_view = self.cg_buf_num_splits[:n]
                num_splits_view.copy_(num_splits)
                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                num_splits = num_splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
        )
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA implementation that decodes with ``flash_mla_with_kvcache``.

    NOTE(review): ``_forward_decode`` takes extra ``k_scale`` /
    ``kv_cache_dtype`` parameters and forwards them to the kernel — this
    looks like a fork-specific extension of the upstream signature; confirm
    against the base class contract.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        # Fail fast on configurations FlashMLA cannot honor.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            # Only plain "fp8" quantized caches are accepted here.
            if self.kv_cache_dtype != "fp8":
                raise NotImplementedError(
                    "FlashMLA with other KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        k_scale = None,
        kv_cache_dtype = "auto",
    ) -> torch.Tensor:
        """Concatenate the query parts and run the FlashMLA decode kernel.

        Args:
            q_nope: non-rotary query part.
            q_pe: rotary (positional) query part, concatenated onto
                ``q_nope`` along the last dim.
            kv_c_and_k_pe_cache: paged latent KV cache.
            attn_metadata: metadata with a populated ``decode`` section.
            k_scale: optional scale forwarded to the kernel (fork extension).
            kv_cache_dtype: cache dtype string forwarded to the kernel.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if envs.VLLM_USE_TRITON_CAT:
            # Use the custom Triton concat only for batch sizes <= 1024 —
            # presumably the range the kernel heuristics were tuned/tested
            # for (TODO confirm); fall back to torch.cat otherwise.
            if q_nope.shape[0] <= 1024:
                q = concat_helper(q_nope, q_pe, dim=-1)\
                    .unsqueeze(1)
            else:
                q = torch.cat([q_nope, q_pe], dim=-1)\
                    .unsqueeze(1) # Add seqlen dim of 1 (decode)
        else:
            q = torch.cat([q_nope, q_pe], dim=-1)\
                .unsqueeze(1) # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.
            tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            softmax_scale=self.scale,
            causal=True,
            k_scale = k_scale,
            kv_cache_dtype = kv_cache_dtype,
        )

        return self._v_up_proj(o)
|
||||
241
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
241
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
    """Report whether the ROCm AITER MLA path is switched on.

    Both environment toggles must be set: the global AITER switch and the
    MLA-specific one.
    """
    aiter_on = envs.VLLM_ROCM_USE_AITER
    aiter_mla_on = envs.VLLM_ROCM_USE_AITER_MLA
    return aiter_on and aiter_mla_on
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
    """MLA attention backend wired to the ROCm AITER decode kernels.

    Only the class plumbing lives here; the actual attention math is in
    :class:`AiterMLAImpl` and the metadata construction in
    :class:`AiterMLAMetadataBuilder`.
    """

    @staticmethod
    def get_name() -> str:
        """Registry name under which this backend is selected."""
        return "ROCM_AITER_MLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> type["AiterMLAMetadata"]:
        """Per-batch attention metadata type for this backend."""
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
        """Builder that produces :class:`AiterMLAMetadata` each step."""
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["AiterMLAImpl"]:
        """Implementation class performing the attention computation."""
        return AiterMLAImpl
|
||||
|
||||
|
||||
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode-phase metadata extending the common MLA fields with the
    paged-KV index tensors consumed by the AITER decode kernel.

    All fields default to None; they are populated by
    AiterMLAMetadataBuilder._build_decode.
    """
    # The indptr of the paged kv cache, shape: [batch_size + 1].
    # paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i+1]] are the
    # pages belonging to request i.
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache (flattened over all requests)
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The query indptr, shape : [num_decode + 1]
    qo_indptr: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
    """Common MLA metadata specialized to the AITER decode-metadata type.

    No extra fields or behavior; exists so the generic machinery carries
    AiterMLADecodeMetadata for the decode phase.
    """
    pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Builds AiterMLAMetadata, including the paged-KV index tensors the
    AITER decode kernel consumes.

    When full CUDA-graph capture is enabled, persistent fixed-size buffers
    are allocated once and sliced/refilled every step so the captured graph
    always sees the same tensor storage.
    """
    full_cudagraph_supported: ClassVar[bool] = True  # decode only

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
        # BUG FIX: the original message was built from two adjacent string
        # literals with no separating space ("AITER MLAonly supports...").
        assert self.kv_cache_spec.block_size == 1, (
            "AITER MLA only supports block size 1.")

        # Preparing persistent buffers (sized for the worst case) so CUDA
        # graph replay always addresses the same storage.
        if self.runner.full_cuda_graph:
            device = self.runner.device
            max_num_reqs = self.runner.max_num_reqs
            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device=device)
            self.paged_kv_indices = torch.zeros(
                block_table.get_device_tensor().numel(
                ),  # max num pages possible
                dtype=torch.int32,
                device=device)
            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                      dtype=torch.int32,
                                                      device=device)

            # Decode issues exactly one query token per request, so the
            # query indptr is simply 0..max_num_reqs and never changes.
            self.qo_indptr = torch.arange(0,
                                          max_num_reqs + 1,
                                          dtype=torch.int32,
                                          device=device)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
        """Derive paged-KV indptr/indices/last-page-len from the block table
        and per-request sequence lengths for the decode batch."""
        page_size = self.kv_cache_spec.block_size
        # Number of pages actually used by each request (ceil division).
        block_table_bounds = (seq_lens + page_size - 1) // page_size
        device = self.runner.device

        # Keep only the valid (in-use) page slots of each row, then
        # flatten row-major into one index vector.
        mask = (torch.arange(block_table_tensor.size(1),
                             dtype=block_table_tensor.dtype,
                             device=device).unsqueeze(0)
                < block_table_bounds.unsqueeze(1))
        paged_kv_indices = block_table_tensor[mask]

        # seq_len % page_size, except a full final page counts as page_size
        # rather than 0.
        paged_kv_last_page_len = seq_lens % page_size
        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                             page_size, paged_kv_last_page_len)

        # Exclusive prefix sum over per-request page counts: [0, c0, c0+c1, ...]
        paged_kv_indptr = torch.cat([
            torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
        ])

        if self.runner.full_cuda_graph:
            num_reqs = self._num_decodes

            num_actual_pages = paged_kv_indices.size(0)

            # Refill the persistent buffers in place and hand out views,
            # so the captured graph keeps referencing the same storage.
            self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
                                                           non_blocking=True)
            self.paged_kv_indices[num_actual_pages:].fill_(-1)
            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]

            self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
                                                      non_blocking=True)
            # Pad the tail with the final offset so padded requests read an
            # empty page range.
            self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
            paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]

            self.paged_kv_last_page_len[:num_reqs].copy_(
                paged_kv_last_page_len, non_blocking=True)
            self.paged_kv_last_page_len[num_reqs:].fill_(1)
            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

            qo_indptr = self.qo_indptr[:1 + num_reqs]

        else:
            # Eager path: one query token per decode request.
            qo_indptr = torch.arange(0,
                                     self._num_decodes + 1,
                                     step=1,
                                     dtype=torch.int32,
                                     device=device)

        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr)

        return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA implementation whose decode path runs the ROCm AITER MLA kernel
    and whose prefill varlen attention is routed to aiter's
    flash_attn_varlen_func. Everything else comes from MLACommonImpl.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)
        # The AITER MLA decode kernel is only available for these head
        # counts; other TP splits produce unsupported shapes.
        assert (num_heads == 16 or num_heads == 128), (
            f"Aiter MLA only supports 16 or 128 number of heads.\n"
            f"Provided {num_heads} number of heads.\n"
            "Try adjusting tensor_parallel_size value.")
        # Fail fast on features this backend has no kernel support for.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # Imported lazily so the module can be loaded on hosts without the
        # aiter package installed.
        from aiter import flash_attn_varlen_func
        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        # Adapter around aiter's varlen FA: its LSE flag is named
        # `return_lse` rather than `return_softmax_lse`.
        output = self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

        return output

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        """Run paged MLA decode attention via aiter_mla_decode_fwd and
        project the latent output back up to value space.

        q_nope/q_pe are the no-position and rotary parts of the query;
        they are concatenated along the head dim before the kernel call.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        # Output buffer in the KV latent (lora) dimension; the kernel
        # writes into it in place.
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        # Insert a singleton head dim the kernel expects on the KV buffer.
        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        # max_seqlen_qo must be 1 except for MTP
        # TODO: Find the best value for MTP
        max_seqlen_qo = 1
        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                             attn_metadata.decode.qo_indptr, max_seqlen_qo,
                             attn_metadata.decode.paged_kv_indptr,
                             attn_metadata.decode.paged_kv_indices,
                             attn_metadata.decode.paged_kv_last_page_len)

        return self._v_up_proj(o)
|
||||
177
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
177
vllm/v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
    """MLA backend whose decode attention runs as a Triton kernel.

    Pure plumbing: exposes the registry name and the implementation class;
    metadata handling is inherited unchanged from MLACommonBackend.
    """

    @staticmethod
    def get_impl_cls() -> type["TritonMLAImpl"]:
        """Implementation class performing the attention computation."""
        return TritonMLAImpl

    @staticmethod
    def get_name() -> str:
        """Registry name under which this backend is selected."""
        return "TRITON_MLA_VLLM_V1"
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation whose decode path uses the Triton
    decode_attention_fwd kernel; on ROCm the prefill varlen attention may
    also be routed to the Triton flash-attention kernel.
    """

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        # Fail fast on features the Triton kernels have no support for.
        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # MLA is decoder-only attention; other attention types are not
        # implemented here.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA V1 with FP8 KV cache not yet supported")

        # ROCm-only fast path toggle; triton_attention is unavailable
        # when Triton itself is not installed.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        self.triton_fa_func = triton_attention if HAS_TRITON else None

    def _flash_attn_varlen_diff_headdims_rocm(self,
                                              q,
                                              k,
                                              v,
                                              softmax_scale=None,
                                              **kwargs):
        # ROCm path: call the Triton FA kernel directly. Note this kernel
        # takes positional arguments, and cannot return the softmax LSE.
        assert self.triton_fa_func is not None

        # Triton Attention requires a padded V: pad the value head dim up
        # to the query head dim with zeros.
        padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)
        # The output of triton_attention is a tuple of
        # [output_tensor, encoded_softmax] where encoded_softmax is always None
        output_tensor, _ = self.triton_fa_func(
            q,
            k,
            padded_v,
            None,  # output
            kwargs["cu_seqlens_q"],
            kwargs["cu_seqlens_k"],
            kwargs["max_seqlen_q"],
            kwargs["max_seqlen_k"],
            kwargs["causal"],
            softmax_scale,
            None,  # bias
        )

        return output_tensor

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
                                         k,
                                         v,
                                         return_softmax_lse=False,
                                         softmax_scale=None,
                                         **kwargs):
        # Route to the ROCm Triton kernel only when the caller does not
        # need the LSE (that kernel cannot produce it); otherwise fall
        # back to the common implementation.
        if current_platform.is_rocm() \
            and self.use_triton_flash_attn \
            and not return_softmax_lse:
            return self._flash_attn_varlen_diff_headdims_rocm(
                q, k, v, softmax_scale=softmax_scale, **kwargs)
        else:
            return super()._flash_attn_varlen_diff_headdims(
                q,
                k,
                v,
                return_softmax_lse=return_softmax_lse,
                softmax_scale=softmax_scale,
                **kwargs)

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        """Run split-KV Triton decode attention (MQA over the latent cache)
        and project the latent output back up to value space.

        q_nope/q_pe are the no-position and rotary parts of the query;
        they are concatenated along the head dim before the kernel call.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        # Output buffer in the KV latent (lora) dimension; the kernel
        # writes into it in place.
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        num_kv_splits = 4  # TODO: heuristic

        # TODO(lucas) Allocate ahead of time
        # Scratch space for the split-KV partial results.
        attn_logits = torch.empty(
            (
                B,
                self.num_heads,
                num_kv_splits,
                # NOTE(lucas) idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        # "Values" are the latent part of the cache (first kv_lora_rank dims).
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # Run MQA
        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                             attn_metadata.decode.block_table,
                             attn_metadata.decode.seq_lens, attn_logits,
                             num_kv_splits, self.scale, PAGE_SIZE)

        return self._v_up_proj(o)
|
||||
Reference in New Issue
Block a user