init src 0.9.2

This commit is contained in:
2026-01-09 15:09:53 +08:00
parent 0eb2c0a4b3
commit 41d98d4359
1438 changed files with 417605 additions and 683 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,250 @@
import triton
import triton.language as tl
import torch
from functools import reduce
import pytest
import torch
import math
@pytest.mark.parametrize("shape_pair,dim", [
(((4, 8, 512), (4, 8, 64)), 2),
(((8, 8, 512), (8, 8, 64)), 2),
(((16, 8, 512), (16, 8, 64)), 2),
(((32, 8, 512), (32, 8, 64)), 2),
(((64, 8, 512), (64, 8, 64)), 2),
(((128, 8, 512), (128, 8, 64)), 2),
(((256, 8, 512), (256, 8, 64)), 2),
(((512, 8, 512), (512, 8, 64)), 2),
(((672, 8, 512), (672, 8, 64)), 2),
(((768, 8, 512), (768, 8, 64)), 2),
(((896, 8, 512), (896, 8, 64)), 2),
(((1024, 8, 512), (1024, 8, 64)), 2),
(((4, 16, 512), (4, 16, 64)), 2),
(((8, 16, 512), (8, 16, 64)), 2),
(((16, 16, 512), (16, 16, 64)), 2),
(((32, 16, 512), (32, 16, 64)), 2),
(((64, 16, 512), (64, 16, 64)), 2),
(((128, 16, 512), (128, 16, 64)), 2),
(((256, 16, 512), (256, 16, 64)), 2),
(((512, 16, 512), (512, 16, 64)), 2),
(((672, 16, 512), (672, 16, 64)), 2),
(((768, 16, 512), (768, 16, 64)), 2),
(((896, 16, 512), (896, 16, 64)), 2),
(((1024, 16, 512), (1024, 16, 64)), 2),
(((4, 32, 512), (4, 32, 64)), 2),
(((8, 32, 512), (8, 32, 64)), 2),
(((16, 32, 512), (16, 32, 64)), 2),
(((32, 32, 512), (32, 32, 64)), 2),
(((64, 32, 512), (64, 32, 64)), 2),
(((128, 32, 512), (128, 32, 64)), 2),
(((256, 32, 512), (256, 32, 64)), 2),
(((512, 32, 512), (512, 32, 64)), 2),
(((672, 32, 512), (672, 32, 64)), 2),
(((768, 32, 512), (768, 32, 64)), 2),
(((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2),
(((4, 32, 128), (4, 32, 64)), 2),
(((8, 32, 128), (8, 32, 64)), 2),
(((16, 32, 128), (16, 32, 64)), 2),
(((32, 32, 128), (32, 32, 64)), 2),
(((64, 32, 128), (64, 32, 64)), 2),
(((128, 32, 128), (128, 32, 64)), 2),
(((256, 32, 128), (256, 32, 64)), 2),
(((512, 32, 128), (512, 32, 64)), 2),
(((672, 32, 128), (672, 32, 64)), 2),
(((768, 32, 128), (768, 32, 64)), 2),
(((896, 32, 128), (896, 32, 64)), 2),
(((1024, 32, 128), (1024, 32, 64)), 2),
])
def test_concat_Acc(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
x = torch.randn(*shape1, device='cuda', dtype=torch.bfloat16)
y = torch.randn(*shape2, device='cuda', dtype=torch.bfloat16)
expected = torch.cat([x,y], dim=dim)
result = concat_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel_prefill(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)# 获取当前block的索引
for sub_section_index in range(Per_block//2):
sub_section_offset = block_idx * Per_block + sub_section_index * 2
if sub_section_offset <= section_num-1:
C_section_start = C_ptr + sub_section_offset * C_section_numel
A_section_start = A_ptr + sub_section_offset * A_section_numel
B_section_start = B_ptr + sub_section_offset * B_section_numel
Arrange_doubleA = tl.arange(0, 256)
mask = Arrange_doubleA < (256)
Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
val_from_A = tl.load(A_section_start + Arrange_doubleA)
tensorAsn = tl.full((256,), 0, tl.int32)
tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
off = Arrange2 + tensor_offsets
tl.store(C_section_start + off,val_from_A,mask=mask)
Arrange_doubleB = tl.arange(0, 128)
mask = Arrange_doubleB < (B_section_numel*2)
val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)
Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
tensorAsn = tl.full((128,), A_section_numel, tl.int32)
tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
@triton.jit
def concat_kernel(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)
for sub_section_index in range(Per_block):
sub_offset = block_idx * Per_block + sub_section_index
if sub_offset <= section_num-1:
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + sub_offset * A_section_numel
B_ptr_block_start = B_ptr + sub_offset * B_section_numel
for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < A_section_numel
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
for offset in range(0, B_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < B_section_numel
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
A = A.contiguous()
B = B.contiguous()
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
Per_block = 1
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
#case prefill
if (A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block)
concat_kernel_prefill[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
else:
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 64):
Per_block = 2
num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
assert False, "not support"
configs = []
configs.append(
triton.testing.Benchmark(
x_names=['size'],
x_vals=[4,8,16,32,64,96,128,256,512,768,1024],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
@triton.testing.perf_report(configs)
def benchmark(size, provider, dim):
x = torch.rand([size,8,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,8,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_16(size, provider, dim):
x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_32(size, provider, dim):
x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_prefill(size, provider, dim):
x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
# benchmark.run(save_path="./triton_test_8",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)

View File

@@ -0,0 +1,248 @@
import triton
import triton.language as tl
import torch
from functools import reduce
import pytest
import torch
import math
@pytest.mark.parametrize("shape_pair,dim", [
(((4, 8, 512), (4, 8, 64)), 2),
(((8, 8, 512), (8, 8, 64)), 2),
(((16, 8, 512), (16, 8, 64)), 2),
(((32, 8, 512), (32, 8, 64)), 2),
(((64, 8, 512), (64, 8, 64)), 2),
(((128, 8, 512), (128, 8, 64)), 2),
(((256, 8, 512), (256, 8, 64)), 2),
(((512, 8, 512), (512, 8, 64)), 2),
(((672, 8, 512), (672, 8, 64)), 2),
(((768, 8, 512), (768, 8, 64)), 2),
(((896, 8, 512), (896, 8, 64)), 2),
(((1024, 8, 512), (1024, 8, 64)), 2),
(((4, 16, 512), (4, 16, 64)), 2),
(((8, 16, 512), (8, 16, 64)), 2),
(((16, 16, 512), (16, 16, 64)), 2),
(((32, 16, 512), (32, 16, 64)), 2),
(((64, 16, 512), (64, 16, 64)), 2),
(((128, 16, 512), (128, 16, 64)), 2),
(((256, 16, 512), (256, 16, 64)), 2),
(((512, 16, 512), (512, 16, 64)), 2),
(((672, 16, 512), (672, 16, 64)), 2),
(((768, 16, 512), (768, 16, 64)), 2),
(((896, 16, 512), (896, 16, 64)), 2),
(((1024, 16, 512), (1024, 16, 64)), 2),
(((4, 32, 512), (4, 32, 64)), 2),
(((8, 32, 512), (8, 32, 64)), 2),
(((16, 32, 512), (16, 32, 64)), 2),
(((32, 32, 512), (32, 32, 64)), 2),
(((64, 32, 512), (64, 32, 64)), 2),
(((128, 32, 512), (128, 32, 64)), 2),
(((256, 32, 512), (256, 32, 64)), 2),
(((512, 32, 512), (512, 32, 64)), 2),
(((672, 32, 512), (672, 32, 64)), 2),
(((768, 32, 512), (768, 32, 64)), 2),
(((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2),
])
def test_concat_Acc(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
M = shape1[0]
N = shape1[1]
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape) # [4, 8, 512]
# print("步幅:", x.stride()) # (1536, 192, 1)
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
expected = torch.cat([x,y], dim=dim)
result = concat_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
M,
N,
Astride_0,
Astride_1,
Astride_2,
Bstride_0,
Bstride_1,
Bstride_2,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)
for sub_section_index in range(Per_block):
sub_offset = block_idx * Per_block + sub_section_index
M_idx = sub_offset // N
N_idx = sub_offset % N
if sub_offset <= section_num-1:
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + M_idx * Astride_0 + N_idx * Astride_1
B_ptr_block_start = B_ptr + M_idx * Bstride_0 + N_idx * Bstride_1
for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < A_section_numel
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
for offset in range(0, B_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < B_section_numel
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
Per_block = 1
unit_offset_A, unit_offset_B, unit_offset_C = A.shape[dim],B.shape[dim],C.shape[dim]
if (A.shape[1]==8 and A.shape[0] > 512) or ( A.shape[1]==16 and A.shape[0] > 256):
Per_block = 2
if ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 256):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
output_shape[0],
output_shape[1],
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
BLOCK_SIZE=1024)
return C
assert False, "not support"
configs = []
configs.append(
triton.testing.Benchmark(
x_names=['M','N'],
x_vals=[(4,8),(8,8),(16,8),(32,8),(64,8),(96,8),(128,8),(256,8),(512,8),(768,8),(1024,8), \
(4,16),(8,16),(16,16),(32,16),(64,16),(96,16),(128,16),(256,16),(512,16),(768,16),(1024,16), \
(4,32),(8,32),(16,32),(32,32),(64,32),(96,32),(128,32),(256,32),(512,32),(768,32),(1024,32)],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
@triton.testing.perf_report(configs)
def benchmark(M, N, provider, dim):
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape) # [M, 8, 512]
# print("步幅:", x.stride()) # (512, 512*M, 1)
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
# print("形状:", y.shape) # [M, 8, 64]
# print("步幅:", y.stride()) # (1536, 192, 1)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_16(size, provider, dim):
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_32(size, provider, dim):
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_prefill(size, provider, dim):
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
benchmark.run(save_path="./triton_test",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class CutlassMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CutlassMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
B = q_nope.shape[0]
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
dtype=q_nope.dtype,
device=q_nope.device)
# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope = q_nope.clone()
q_pe = q_pe.clone()
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)

View File

@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm import envs
from vllm.v1.attention.backends.mla.concatv4_decode_only import concat_helper
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
pass
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
self.num_q_heads,
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
self.cg_buf_num_splits = num_splits
else:
assert self.cg_buf_num_splits is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata.size() ==
tile_scheduler_metadata.size())
self.cg_buf_tile_scheduler_metadata.\
copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
num_splits = num_splits_view
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
assert is_flashmla_supported(), \
"FlashMLA is not supported on this device"
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
k_scale = None,
kv_cache_dtype = "auto",
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if envs.VLLM_USE_TRITON_CAT:
if q_nope.shape[0] <= 1024:
q = concat_helper(q_nope, q_pe, dim=-1)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o)

View File

@@ -0,0 +1,241 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Any, ClassVar, Optional
import torch
import vllm.envs as envs
from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
# yapf conflicts with isort for this docstring
# yapf: disable
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
# yapf: enable
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_metadata_cls() -> type["AiterMLAMetadata"]:
return AiterMLAMetadata
@staticmethod
def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_len: Optional[torch.Tensor] = None
# The query indptr, shape : [num_decode + 1]
qo_indptr: Optional[torch.Tensor] = None
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1."
# Preparing persistent buffers
if self.runner.full_cuda_graph:
device = self.runner.device
max_num_reqs = self.runner.max_num_reqs
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.paged_kv_indices = torch.zeros(
block_table.get_device_tensor().numel(
), # max num pages possible
dtype=torch.int32,
device=device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=device)
self.qo_indptr = torch.arange(0,
max_num_reqs + 1,
dtype=torch.int32,
device=device)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
paged_kv_indptr = torch.cat([
torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
if self.runner.full_cuda_graph:
num_reqs = self._num_decodes
num_actual_pages = paged_kv_indices.size(0)
self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
non_blocking=True)
self.paged_kv_indices[num_actual_pages:].fill_(-1)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
non_blocking=True)
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
self.paged_kv_last_page_len[:num_reqs].copy_(
paged_kv_last_page_len, non_blocking=True)
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
qo_indptr = self.qo_indptr[:1 + num_reqs]
else:
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr)
return attn_metadata
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
assert (num_heads == 16 or num_heads == 128), (
f"Aiter MLA only supports 16 or 128 number of heads.\n"
f"Provided {num_heads} number of heads.\n"
"Try adjusting tensor_parallel_size value.")
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(self,
q,
k,
v,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
**kwargs,
)
return output
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
# max_seqlen_qo must be 1 except for MTP
# TODO: Find the best value for MTP
max_seqlen_qo = 1
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.decode.qo_indptr, max_seqlen_qo,
attn_metadata.decode.paged_kv_indptr,
attn_metadata.decode.paged_kv_indices,
attn_metadata.decode.paged_kv_last_page_len)
return self._v_up_proj(o)

View File

@@ -0,0 +1,177 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["TritonMLAImpl"]:
return TritonMLAImpl
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported")
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
self.triton_fa_func = triton_attention if HAS_TRITON else None
def _flash_attn_varlen_diff_headdims_rocm(self,
q,
k,
v,
softmax_scale=None,
**kwargs):
assert self.triton_fa_func is not None
# Triton Attention requires a padded V
padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
# The output of triton_attention is a tuple of
# [output_tensor, encoded_softmax] where encoded_softmax is always None
output_tensor, _ = self.triton_fa_func(
q,
k,
padded_v,
None, # output
kwargs["cu_seqlens_q"],
kwargs["cu_seqlens_k"],
kwargs["max_seqlen_q"],
kwargs["max_seqlen_k"],
kwargs["causal"],
softmax_scale,
None, # bias
)
return output_tensor
def _flash_attn_varlen_diff_headdims(self,
q,
k,
v,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
if current_platform.is_rocm() \
and self.use_triton_flash_attn \
and not return_softmax_lse:
return self._flash_attn_varlen_diff_headdims_rocm(
q, k, v, softmax_scale=softmax_scale, **kwargs)
else:
return super()._flash_attn_varlen_diff_headdims(
q,
k,
v,
return_softmax_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs)
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
num_kv_splits = 4 # TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
attn_metadata.decode.block_table,
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj(o)