Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,7 @@
#pragma once
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"

View File

@@ -0,0 +1,65 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <stdint.h>
namespace vllm {
// A vector type to store Q, K, V elements.
template <typename T, int VEC_SIZE>
struct Vec {};
// A vector type to store FP32 accumulators.
template <typename T>
struct FloatVec {};
// Template vector operations.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template <typename T>
inline __device__ float sum(T v);
template <typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template <typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
template <typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
} // namespace vllm

View File

@@ -0,0 +1,670 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../cuda_compat.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale);
}
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum;
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i];
}
// Terminate the thread block.
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP

View File

@@ -0,0 +1,57 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../cuda_compat.h"
#include "attention_dtypes.h"
#include <float.h>
#include <type_traits>
namespace vllm {
// Q*K^T operation.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};
} // namespace vllm

View File

@@ -0,0 +1,463 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif
#include <stdint.h>
namespace vllm {
// Define custom BF16 vector data types.
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
// BF16 vector types for Q, K, V.
template <>
struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16;
};
template <>
struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template <>
struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
template <>
struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<__nv_bfloat16> {
using Type = float;
};
template <>
struct FloatVec<__nv_bfloat162> {
using Type = float2;
};
template <>
struct FloatVec<bf16_4_t> {
using Type = Float4_;
};
template <>
struct FloatVec<bf16_8_t> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __bfloat1622float2(val);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __bfloat162bfloat162(val);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
#ifndef USE_ROCM
return a + b;
#else
return __hadd(a, b);
#endif
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hadd2(a, b);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
float2 fa = bf1622float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hmul(a, b);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hmul2(a, b);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a);
float fb = __bfloat162float(b);
return fa * fb;
}
template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hfma2(a, b, c);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hfma2(bf162bf162(a), b, c);
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
return fma(bf162bf162(a), b, fc);
}
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template <>
inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v);
}
template <>
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
template <>
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}
template <>
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
dst = __float2bfloat16(src);
}
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
dst = __float22bfloat162_rn(src);
#endif
}
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#endif
}
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#endif
}
// From bfloat16 to float32.
inline __device__ float to_float(__nv_bfloat16 u) {
return __bfloat162float(u);
}
// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}
} // namespace vllm

View File

@@ -0,0 +1,504 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif
#include <stdint.h>
namespace vllm {
// FP16 vector types for Q, K, V.
template <>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template <>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template <>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template <>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<uint16_t> {
using Type = float;
};
template <>
struct FloatVec<uint32_t> {
using Type = float2;
};
template <>
struct FloatVec<uint2> {
using Type = Float4_;
};
template <>
struct FloatVec<uint4> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = a;
tmp.u16[1] = a;
return tmp.u32;
#endif
}
inline __device__ float half_to_float(uint16_t h) {
float f;
#ifndef USE_ROCM
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
return f;
}
inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = v;
float2 ret;
ret.x = half_to_float(tmp.u16[0]);
ret.y = half_to_float(tmp.u16[1]);
return ret;
#endif
}
inline __device__ uint16_t float_to_half(float f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#ifndef USE_ROCM
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
return tmp.u16[0];
}
inline __device__ uint32_t float2_to_half2(float2 f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
tmp.u16[0] = float_to_half(f.x);
tmp.u16[1] = float_to_half(f.y);
#endif
return tmp.u32;
}
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
inline __device__ uint2 add(uint2 a, uint2 b) {
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ uint4 add(uint4 a, uint4 b) {
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
template <>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
template <>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
template <>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif
return d;
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
return fma(h0_h0(a), b, c);
}
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb + fc;
}
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
return fma(h0_h0(a), b, fc);
}
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
uint32_t s = h0_h0(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
uint32_t s = h0_h0(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template <>
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}
template <>
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
template <>
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
template <>
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
return sum(c);
}
// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
dst = float_to_half(src);
}
inline __device__ void from_float(uint32_t& dst, float2 src) {
dst = float2_to_half2(src);
}
inline __device__ void from_float(uint2& dst, Float4_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
}
inline __device__ void from_float(uint4& dst, Float8_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
// From float16 to float32.
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
inline __device__ Float8_ to_float(uint4 u) {
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
} // namespace vllm

View File

@@ -0,0 +1,251 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
namespace vllm {
// Define custom FP32 vector data types.
struct Float4_ {
float2 x;
float2 y;
};
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// FP32 vector types for Q, K, V.
template <>
struct Vec<float, 1> {
using Type = float;
};
template <>
struct Vec<float, 2> {
using Type = float2;
};
template <>
struct Vec<float, 4> {
using Type = float4;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
struct FloatVec<float> {
using Type = float;
};
template <>
struct FloatVec<float2> {
using Type = float2;
};
template <>
struct FloatVec<float4> {
using Type = float4;
};
// Vector addition.
inline __device__ float add(float a, float b) { return a + b; }
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
// Vector multiplication.
template <>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template <>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template <>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
template <>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template <>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
// Vector sum.
template <>
inline __device__ float sum(float v) {
return v;
}
template <>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
template <>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
template <>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
template <>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
// Vector dot product.
inline __device__ float dot(float a, float b) { return a * b; }
inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
return c.x + c.y;
}
inline __device__ float dot(Float4_ a, Float4_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
return acc.x + acc.y;
}
inline __device__ float dot(Float8_ a, Float8_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
acc = fma(a.z, b.z, acc);
acc = fma(a.w, b.w, acc);
return acc.x + acc.y;
}
// From float to float.
inline __device__ void from_float(float& dst, float src) { dst = src; }
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
// From float to float.
inline __device__ float to_float(float u) { return u; }
inline __device__ float2 to_float(float2 u) { return u; }
inline __device__ float4 to_float(float4 u) { return u; }
inline __device__ Float4_ to_float(Float4_ u) { return u; }
inline __device__ Float8_ to_float(Float8_ u) { return u; }
// Zero-out a variable.
inline __device__ void zero(float& dst) { dst = 0.f; }
} // namespace vllm

View File

@@ -0,0 +1,41 @@
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8
namespace vllm {
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};
template <>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};
template <>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};
template <>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
} // namespace vllm

View File

@@ -0,0 +1,209 @@
#include <optional>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
if (global_idx >= token_head_threads) return;
// global_idx -> token_idx + head_idx + pack_idx
const uint token_head_idx = global_idx / threads_per_head;
const uint pack_idx = global_idx % threads_per_head;
const uint token_idx = token_head_idx / num_heads;
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
head_idx * prefix_head_stride;
const uint dst_head_offset = token_idx * num_heads * output_head_stride +
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
const float max_lse = fmaxf(p_lse, s_lse);
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that requests position is -inf, and at
this moment we merge cross-attention at first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
output_lse[head_idx * num_tokens + token_idx] = max_lse;
}
return;
}
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = expf(p_lse);
const float s_se = expf(s_lse);
const float out_se = p_se + s_se;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
float out_lse = logf(out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}
}
} // namespace vllm
// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(uint16_t); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(__nv_bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size;
const uint total_threads = num_tokens * num_heads * threads_per_head;
dim3 block(NUM_THREADS);
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
}
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}

View File

@@ -0,0 +1,385 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
/*!
\file
\brief An universal device layer for cutlass 3.x-style kernels.
*/
// clang-format off
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
using namespace cute;
using namespace cutlass::fmha::kernel;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class Kernel_
>
class MLA {
public:
using Kernel = Kernel_;
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
typename Kernel::ElementOut,
typename Kernel::ElementAcc,
typename Kernel::ElementAcc,
Kernel::TileShapeH::value,
Kernel::TileShapeL::value,
256 /*Max split*/
>;
/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;
using Arguments = KernelArguments;
/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;
struct Params {
KernelParams fmha_params;
ReductionParams reduction_params;
};
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
static ReductionArguments to_reduction_args(Arguments const& args) {
auto [H, K, D, B] = args.problem_shape;
return ReductionArguments{
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
args.ptr_split_kv, Kernel::TileShapeS::value
};
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
static void set_split_kv (KernelArguments& args) {
if (args.split_kv >= 1) return;
auto [H, K, D, B] = args.problem_shape;
int sm_count = args.hw_info.sm_count;
float seq_length_k = static_cast<float>(K) / 1024.0f;
int max_splits = 1;
if (B <= 4 && seq_length_k >= 16) {
max_splits = 16;
}
else if (B <= 8 && seq_length_k >= 4) {
max_splits = 8;
}
else if ((B <= 16 && seq_length_k >= 8) ||
(B == 48 && seq_length_k >= 32)) {
max_splits = 4;
}
else if ((B <= 32 && seq_length_k >= 16) ||
(B == 96 && seq_length_k >= 16)) {
max_splits = 2;
}
else {
max_splits = 1;
}
// Wave-aware scheduling: ensure integer number of waves in K dimension
int sms_per_batch = max(1, sm_count / B);
int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count);
int k_waves = ceil_div(max_splits, split_heur);
int split_wave_aware = ceil_div(max_splits, k_waves);
args.split_kv = split_wave_aware;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (! Kernel::can_implement(args)) {
return Status::kInvalid;
}
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
return Status::kInvalid;
}
return Status::kSuccess;
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
return workspace_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
if (status != Status::kSuccess) {
return status;
}
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {kernel_params, reduction_params};
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
// no dynamic smem is needed for reduction kernel
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {fmha_params, reduction_params};
return Status::kSuccess;
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params.fmha_params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess != result or Status::kSuccess != launch_result) {
//return Status::kSuccess;
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
if (params.reduction_params.split_kv > 1) {
// launch reduction kernel
dim3 const block = ReductionKernel::get_block_shape();
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
else {
return Status::kSuccess;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,203 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class ElementOut,
class ElementAcc,
class ElementScale,
size_t kNumHeads,
size_t kHeadDimLatent,
int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
static const int SharedStorageSize = 0;
static const int MaxThreadsPerBlock = 128;
static const int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm100;
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
struct Arguments {
ElementAcc* ptr_oaccum = nullptr;
ElementOut* ptr_o = nullptr;
ElementAcc* ptr_lseaccum = nullptr;
ElementAcc* ptr_lse = nullptr;
ElementScale scale = 1.f;
int num_batches = 0;
int split_kv = -1;
int dim_k = -1;
int* ptr_seq = nullptr;
int* ptr_split_kv = nullptr;
int tile_shape_s = 128;
};
using Params = Arguments;
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
args.ptr_split_kv, args.tile_shape_s};
}
static size_t get_workspace_size(Arguments const& /*args*/) {
return 0;
}
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
return Status::kSuccess;
}
static dim3 get_grid_shape(Params const& params) {
return dim3(kNumHeads, 1, params.num_batches);
}
static dim3 get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
static bool can_implement(Arguments const& args) {
if (args.num_batches <= 0) return false;
if (args.split_kv <= 0) return false;
return true;
}
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
ElementAcc local_lse[kNLsePerThread];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
}
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
if (split < local_split_kv) {
sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
}
__syncthreads();
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
ElementAcc local_val[Elements] = {0};
for (int split = 0; split < local_split_kv; ++split) {
ElementAcc lse_scale = sLseScale[split];
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
}
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
}
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
}
}
};
} // namespace cutlass::fmha::kernel

View File

@@ -0,0 +1,165 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaIndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler(Params const&) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaPersistentTileScheduler {
struct Params {
int num_blocks;
FastDivmod divmod_m_block;
FastDivmod divmod_b;
FastDivmod divmod_split_kv;
KernelHardwareInfo hw_info;
};
int block_idx = 0;
Params params;
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = hw_info.sm_count;
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
}
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;
int num_m_blocks = size<0>(cluster_shape);
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
num_blocks *= split_kv; /* Maximum Split KV*/
return Params {
num_blocks,
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
hw_info
};
}
static dim3 get_grid_shape(Params const& params) {
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
return grid;
}
CUTLASS_DEVICE
bool is_valid() {
return block_idx < params.num_blocks;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
int block_decode = block_idx;
int m_block, bidb, n_split_kv;
params.divmod_m_block(block_decode, m_block, block_decode);
params.divmod_b(block_decode, bidb, block_decode);
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
return make_coord(m_block, _0{}, bidb, n_split_kv);
}
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler& operator++() {
block_idx += gridDim.x;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel

View File

@@ -0,0 +1,291 @@
/*
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
#include "core/registration.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <iostream>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
void sm100_cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& lse,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
}
#else
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
using namespace cute;
using namespace cutlass::fmha::kernel;
template <bool v>
struct IsPersistent {
static const bool value = v;
};
template <typename T, typename TOut, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
using ElementOut = TOut;
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;
// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B
using TileScheduler =
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape,
Element,
ElementAcc,
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out,
at::Tensor const& lse,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
double sm_scale,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
int batches = q_nope.sizes()[0];
int page_count_per_seq = page_table.sizes()[1];
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
int page_size = kv_c_and_k_pe_cache.sizes()[1];
int max_seq_len = page_size * page_count_per_seq;
using TileShapeH = typename T::TileShapeH;
using TileShapeD = typename T::TileShapeD;
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
float scale = float(sm_scale);
using StrideQ = typename T::StrideQ;
using StrideK = typename T::StrideK;
using StrideO = typename T::StrideO;
using StrideLSE = typename T::StrideLSE;
StrideQ stride_Q_nope = cute::make_tuple(
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
StrideQ stride_Q_pe = cute::make_tuple(
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
StrideK stride_C = cute::make_tuple(
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
using Element = typename T::Element;
using ElementOut = typename T::ElementOut;
using ElementAcc = typename T::ElementAcc;
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
typename T::Fmha::Arguments arguments{
problem_shape,
{scale,
Q_nope_ptr,
stride_Q_nope,
Q_pe_ptr,
stride_Q_pe,
C_ptr,
stride_C,
C_ptr + D_latent,
stride_C,
static_cast<int*>(seq_lens.data_ptr()),
static_cast<int*>(page_table.data_ptr()),
stride_PT,
page_count_total,
page_size},
{static_cast<ElementOut*>(out.data_ptr()),
stride_O,
static_cast<ElementAcc*>(lse.defined() ? lse.data_ptr() : nullptr),
stride_LSE},
hw_info,
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
static_cast<int>(num_kv_splits), // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T::Fmha::set_split_kv(arguments);
return arguments;
}
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& lse,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
at::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
CUTLASS_CHECK(fmha.can_implement(arguments));
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
void sm100_cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& lse,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
// Maybe per batch split kv will fix this.
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
return true;
});
return true;
});
}
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t, cutlass::half_t, true>;
// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
using TileShapeH = typename MlaSm100Type::TileShapeH;
using TileShapeD = typename MlaSm100Type::TileShapeD;
arguments.problem_shape =
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
arguments.split_kv = static_cast<int>(num_kv_splits);
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
// clang-format on

View File

@@ -0,0 +1,186 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "attention_kernels.cuh"
#include "../cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V1(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V1(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V1(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V1(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP

View File

@@ -0,0 +1,196 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "attention_kernels.cuh"
#include "../cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
dim3 grid(num_heads, num_seqs, max_num_partitions);
int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 32:
LAUNCH_PAGED_ATTENTION_V2(32);
break;
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V2(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V2(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V2(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP

View File

@@ -0,0 +1,401 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
int64_t range_end, int64_t block_size,
int64_t input_block_count, int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* This function builds the index of each row of blocks from vertical indices
* and slash indices. The vertical indices are treated as points, while the
* slash indices are converted as ranges. The output consists of the merged
* ranges and separate column indices, where the ranges are represented by
* block indices.
*
* The implementation is referenced from the original MInference repo:
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
*/
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
causal);
}
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has it's unique max topk NNZ_VNNZ_S. (NNZ_VNNZ_S
// above is buffer size, use to compute offset)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv, int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* Like the above convert_vertical_slash_indexes, but with
* pre-computed vertical and slash counts.
*/
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, // [N_HEADS, ]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}