This commit is contained in:
2026-01-09 13:34:11 +08:00
parent dfa6476b58
commit b2ef04d792
538 changed files with 105693 additions and 2 deletions

View File

@@ -0,0 +1,161 @@
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include <torch/extension.h>
#include "torch_musa/csrc/core/MUSAGuard.h"
#include <cmath>
#include "musa_compat.h"
#include "dispatch_utils.h"
namespace vllm {
// Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x) * y;
}
}
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x)));
}
template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x;
constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
template<typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
}
} // namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::musa::OptionalMUSAGuard device_guard(device_of(input)); \
const musaStream_t stream = at::musa::getCurrentMUSAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"act_and_mul_kernel", \
[&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void gelu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
void gelu_tanh_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::musa::OptionalMUSAGuard device_guard(device_of(input)); \
const musaStream_t stream = at::musa::getCurrentMUSAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm {
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
} // namespace vllm
void gelu_new(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}

View File

@@ -0,0 +1,7 @@
#pragma once
#include "attention_generic.muh"
#include "dtype_float16.muh"
#include "dtype_float32.muh"
#include "dtype_bfloat16.muh"
#include "dtype_fp8.muh"

View File

@@ -0,0 +1,65 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <stdint.h>
namespace vllm {
// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};
// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template<typename T>
inline __device__ float sum(T v);
template<typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template<typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
template<typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
} // namespace vllm

View File

@@ -0,0 +1,981 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include "attention_dtypes.h"
#include "attention_utils.muh"
#if defined(ENABLE_FP8_E5M2)
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __mt_bfloat16;
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
// Utility function for attention softmax.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
// because int32 can lead to overflow when this variable is multiplied by large numbers
// (e.g., kv_block_stride).
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#elif defined(ENABLE_FP8_E4M3)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
#else
assert(false);
#endif
} else {
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+ head_idx * max_num_partitions
+ partition_idx;
*max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+ head_idx * max_num_partitions
+ partition_idx;
*exp_sums_ptr = exp_sum;
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
// because int32 can lead to overflow when this variable is multiplied by large numbers
// (e.g., kv_block_stride).
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (IS_FP8_KV_CACHE) {
#if defined(ENABLE_FP8_E5M2)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#elif defined(ENABLE_FP8_E4M3)
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
#else
assert(false);
#endif
} else {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE
+ partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template<
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE>
__global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
typename scalar_t,
typename cache_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS,
bool IS_FP8_KV_CACHE,
int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride,
const float kv_scale) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride, kv_scale);
}
// Grid: (num_heads, num_seqs).
template<
typename scalar_t,
int HEAD_SIZE,
int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i];
}
// Terminate the thread block.
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+ head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+ head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+ head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_KV_CACHE>), shared_mem_size); \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs, 1);
dim3 block(NUM_THREADS);
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 64:
LAUNCH_PAGED_ATTENTION_V1(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V1(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V1(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, __mt_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V1_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
IS_FP8_KV_CACHE, PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
num_kv_heads, \
scale, \
block_tables_ptr, \
seq_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride, \
kv_scale); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, \
exp_sums_ptr, \
max_logits_ptr, \
tmp_out_ptr, \
seq_lens_ptr, \
max_num_partitions);
template<
typename T,
typename CACHE_T,
int BLOCK_SIZE,
bool IS_FP8_KV_CACHE,
int NUM_THREADS = 128,
int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
float kv_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
dim3 grid(num_heads, num_seqs, max_num_partitions);
int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we only compile for the
// head sizes that we use in the model. However, we can easily extend this
// to support any head size which is a multiple of 16.
case 64:
LAUNCH_PAGED_ATTENTION_V2(64);
break;
case 80:
LAUNCH_PAGED_ATTENTION_V2(80);
break;
case 96:
LAUNCH_PAGED_ATTENTION_V2(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
out, \
exp_sums, \
max_logits, \
tmp_out, \
query, \
key_cache, \
value_cache, \
num_kv_heads, \
scale, \
block_tables, \
seq_lens, \
max_seq_len, \
alibi_slopes, \
kv_scale);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
break; \
case 16: \
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
break; \
case 32: \
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
int num_kv_heads, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, __mt_bfloat16, false);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8") {
if (query.dtype() == at::ScalarType::Float) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_V2_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, uint8_t, true);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP

View File

@@ -0,0 +1,57 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "../musa_compat.h"
#include "attention_dtypes.h"
#include <float.h>
#include <type_traits>
namespace vllm {
// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template<typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};
} // namespace vllm

View File

@@ -0,0 +1,452 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.muh"
#include "dtype_float32.muh"
#ifndef USE_ROCM
#include <musa_bf16.h>
#include <musa_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat162 __mt_bfloat162;
typedef __hip_bfloat16 __mt_bfloat16;
#endif
#include <stdint.h>
namespace vllm {
// Define custom BF16 vector data types.
struct bf16_4_t {
__mt_bfloat162 x;
__mt_bfloat162 y;
};
struct bf16_8_t {
__mt_bfloat162 x;
__mt_bfloat162 y;
__mt_bfloat162 z;
__mt_bfloat162 w;
};
// BF16 vector types for Q, K, V.
template<>
struct Vec<__mt_bfloat16, 1> {
using Type = __mt_bfloat16;
};
template<>
struct Vec<__mt_bfloat16, 2> {
using Type = __mt_bfloat162;
};
template<>
struct Vec<__mt_bfloat16, 4> {
using Type = bf16_4_t;
};
template<>
struct Vec<__mt_bfloat16, 8> {
using Type = bf16_8_t;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__mt_bfloat16> {
using Type = float;
};
template<>
struct FloatVec<__mt_bfloat162> {
using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __mt_bfloat162 val) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ __mt_bfloat162 bf162bf162(const __mt_bfloat16 val) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __bfloat162bfloat162(val);
#endif
}
// Vector addition.
inline __device__ __mt_bfloat16 add(__mt_bfloat16 a, __mt_bfloat16 b) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
#ifndef USE_ROCM
return a + b;
#else
return __hadd(a, b);
#endif
#endif
}
inline __device__ __mt_bfloat162 add(__mt_bfloat162 a, __mt_bfloat162 b) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __hadd2(a, b);
#endif
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(__mt_bfloat162 a, float2 fb) {
float2 fa = bf1622float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template<>
inline __device__ __mt_bfloat16 mul(__mt_bfloat16 a, __mt_bfloat16 b) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __hmul(a, b);
#endif
}
template<>
inline __device__ __mt_bfloat162 mul(__mt_bfloat162 a, __mt_bfloat162 b) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __hmul2(a, b);
#endif
}
template<>
inline __device__ __mt_bfloat162 mul(__mt_bfloat16 a, __mt_bfloat162 b) {
return mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(bf162bf162(a), b);
}
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
return c;
}
template<>
inline __device__ bf16_4_t mul(__mt_bfloat16 a, bf16_4_t b) {
__mt_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.x);
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.y);
return c;
}
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
c.z = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.z, b.z);
c.w = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.w, b.w);
return c;
}
template<>
inline __device__ bf16_8_t mul(__mt_bfloat16 a, bf16_8_t b) {
__mt_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.x);
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.y);
c.z = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.z);
c.w = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.w);
return c;
}
template<>
inline __device__ float mul(__mt_bfloat16 a, __mt_bfloat16 b) {
float fa = __bfloat162float(a);
float fb = __bfloat162float(b);
return fa * fb;
}
template<>
inline __device__ float2 mul(__mt_bfloat162 a, __mt_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template<>
inline __device__ float2 mul(__mt_bfloat16 a, __mt_bfloat162 b) {
return mul<float2, __mt_bfloat162, __mt_bfloat162>(bf162bf162(a), b);
}
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
return fc;
}
template<>
inline __device__ Float4_ mul(__mt_bfloat16 a, bf16_4_t b) {
__mt_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.x);
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.y);
return fc;
}
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
fc.z = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.z, b.z);
fc.w = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.w, b.w);
return fc;
}
template<>
inline __device__ Float8_ mul(__mt_bfloat16 a, bf16_8_t b) {
__mt_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.x);
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.y);
fc.z = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.z);
fc.w = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ __mt_bfloat162 fma(__mt_bfloat162 a, __mt_bfloat162 b, __mt_bfloat162 c) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __hfma2(a, b, c);
#endif
}
inline __device__ __mt_bfloat162 fma(__mt_bfloat16 a, __mt_bfloat162 b, __mt_bfloat162 c) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
return __hfma2(bf162bf162(a), b, c);
#endif
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ bf16_4_t fma(__mt_bfloat16 a, bf16_4_t b, bf16_4_t c) {
__mt_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ bf16_8_t fma(__mt_bfloat16 a, bf16_8_t b, bf16_8_t c) {
__mt_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(__mt_bfloat16 a, __mt_bfloat16 b, float fc) {
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
inline __device__ float2 fma(__mt_bfloat162 a, __mt_bfloat162 b, float2 fc) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(__mt_bfloat16 a, __mt_bfloat162 b, float2 fc) {
return fma(bf162bf162(a), b, fc);
}
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(__mt_bfloat16 a, bf16_4_t b, Float4_ fc) {
__mt_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(__mt_bfloat16 a, bf16_8_t b, Float8_ fc) {
__mt_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template<>
inline __device__ float sum(__mt_bfloat16 v) {
return __bfloat162float(v);
}
template<>
inline __device__ float sum(__mt_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
template<>
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}
template<>
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
// From float32 to bfloat16.
inline __device__ void from_float(__mt_bfloat16& dst, float src) {
dst = __float2bfloat16(src);
}
inline __device__ void from_float(__mt_bfloat162& dst, float2 src) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
dst = __float22bfloat162_rn(src);
#endif
}
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#endif
}
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#endif
}
// From bfloat16 to float32.
inline __device__ float to_float(__mt_bfloat16 u) {
return __bfloat162float(u);
}
// Zero-out a variable.
inline __device__ void zero(__mt_bfloat16& dst) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}
} // namespace vllm

View File

@@ -0,0 +1,503 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.muh"
#include "dtype_float32.muh"
#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif
#include <stdint.h>
namespace vllm {
// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = a;
tmp.u16[1] = a;
return tmp.u32;
#endif
}
inline __device__ float half_to_float(uint16_t h) {
float f;
#ifndef USE_ROCM
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
return f;
}
inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = v;
float2 ret;
ret.x = half_to_float(tmp.u16[0]);
ret.y = half_to_float(tmp.u16[1]);
return ret;
#endif
}
inline __device__ uint16_t float_to_half(float f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#ifndef USE_ROCM
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
return tmp.u16[0];
}
inline __device__ uint32_t float2_to_half2(float2 f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#ifndef USE_ROCM
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
tmp.u16[0] = float_to_half(f.x);
tmp.u16[1] = float_to_half(f.y);
#endif
return tmp.u32;
}
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
inline __device__ uint2 add(uint2 a, uint2 b) {
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ uint4 add(uint4 a, uint4 b) {
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
return c;
}
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
return d;
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
return fma(h0_h0(a), b, c);
}
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb + fc;
}
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
return fma(h0_h0(a), b, fc);
}
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
uint32_t s = h0_h0(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
uint32_t s = h0_h0(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}
template<>
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
template<>
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
template<>
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
return sum(c);
}
// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
dst = float_to_half(src);
}
inline __device__ void from_float(uint32_t& dst, float2 src) {
dst = float2_to_half2(src);
}
inline __device__ void from_float(uint2& dst, Float4_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
}
inline __device__ void from_float(uint4& dst, Float8_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) {
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
inline __device__ Float8_ to_float(uint4 u) {
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
} // namespace vllm

View File

@@ -0,0 +1,274 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.muh"
#include <stdint.h>
namespace vllm {
// Define custom FP32 vector data types.
struct Float4_ {
float2 x;
float2 y;
};
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
using Type = float;
};
template<>
struct Vec<float, 2> {
using Type = float2;
};
template<>
struct Vec<float, 4> {
using Type = float4;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
using Type = float;
};
template<>
struct FloatVec<float2> {
using Type = float2;
};
template<>
struct FloatVec<float4> {
using Type = float4;
};
// Vector addition.
inline __device__ float add(float a, float b) {
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template<>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template<>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
template<>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template<>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
// Vector sum.
template<>
inline __device__ float sum(float v) {
return v;
}
template<>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
template<>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
template<>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
template<>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
// Vector dot product.
inline __device__ float dot(float a, float b) {
return a * b;
}
inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
return c.x + c.y;
}
inline __device__ float dot(Float4_ a, Float4_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
return acc.x + acc.y;
}
inline __device__ float dot(Float8_ a, Float8_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
acc = fma(a.z, b.z, acc);
acc = fma(a.w, b.w, acc);
return acc.x + acc.y;
}
// From float to float.
inline __device__ void from_float(float& dst, float src) {
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) {
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) {
dst = src;
}
// From float to float.
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float2 to_float(float2 u) {
return u;
}
inline __device__ float4 to_float(float4 u) {
return u;
}
inline __device__ Float4_ to_float(Float4_ u) {
return u;
}
inline __device__ Float8_ to_float(Float8_ u) {
return u;
}
// Zero-out a variable.
inline __device__ void zero(float& dst) {
dst = 0.f;
}
} // namespace vllm

View File

@@ -0,0 +1,35 @@
#pragma once
#include "attention_generic.muh"
#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif
namespace vllm {
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};
template<>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};
template<>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};
template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2
} // namespace vllm

38
csrc_musa/cache.h Normal file
View File

@@ -0,0 +1,38 @@
#pragma once
#include <torch/extension.h>
#include <map>
#include <vector>
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
void reshape_and_cache_flash(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache);

419
csrc_musa/cache_kernels.mu Normal file
View File

@@ -0,0 +1,419 @@
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include "musa_compat.h"
#include "dispatch_utils.h"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
#include <cassert>
#include <map>
#include <vector>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __mt_bfloat16;
#endif
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
musaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = musaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = musaMemcpyDeviceToHost;
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
memcpy_type = musaMemcpyHostToDevice;
} else {
TORCH_CHECK(false, "Invalid device combination");
}
char *src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::musa::OptionalMUSAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
const musaStream_t stream = at::musa::getCurrentMUSAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
musaMemcpyAsync(
dst_ptr + dst_offset,
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
}
}
namespace vllm {
// Grid: (num_layers, num_pairs)
template<typename scalar_t>
__global__ void copy_blocks_kernel(
int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping,
const int numel_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
const int64_t src_block_offset = src_block_number * numel_per_block;
const int64_t dst_block_offset = dst_block_number * numel_per_block;
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int64_t src_offset = src_block_offset + i;
int64_t dst_offset = dst_block_offset + i;
key_cache[dst_offset] = key_cache[src_offset];
}
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
int64_t src_offset = src_block_offset + i;
int64_t dst_offset = dst_block_offset + i;
value_cache[dst_offset] = value_cache[src_offset];
}
}
} // namespace vllm
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
torch::Device cache_device = key_caches[0].device();
TORCH_CHECK(cache_device.is_cuda());
// Create data structures for the kernel.
// Create an array of pointers to the key and value caches.
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const at::musa::OptionalMUSAGuard device_guard(cache_device);
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(),
numel_per_block);
}));
}
namespace vllm {
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
}
}
}
template<typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride,
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride
+ block_offset * num_heads * head_size
+ head_idx * head_size
+ head_offset;
k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx];
}
}
} // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype,
const float kv_scale)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::musa::OptionalMUSAGuard device_guard(device_of(key));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, float, false);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__mt_bfloat16, __mt_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__mt_bfloat16, uint8_t, true);
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
}
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype)
{
// FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = k_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = k_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::musa::OptionalMUSAGuard device_guard(device_of(key));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"reshape_and_cache_flash",
[&] {
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(),
v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(),
block_stride,
key_stride,
value_stride,
num_heads,
head_size,
block_size);
});
}
namespace vllm {
template<typename Tout, typename Tin>
__global__ void convert_fp8_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
#if defined(ENABLE_FP8_E5M2)
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
}
}
} // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
block_stride);
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
{
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
at::musa::OptionalMUSAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0);
dim3 grid(num_blocks);
dim3 block(std::min(block_stride, int64_t(512)));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __mt_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__mt_bfloat16, uint8_t);
}
}

View File

@@ -0,0 +1,148 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
scalar_t *__restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(d % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
int start = i * d;
if constexpr (is_gated) {
start *= 2;
}
const scalar_vec_t x(input + start + j);
const vec_op::FP32Vec8 f32_x(x);
vec_op::FP32Vec8 f32_ans = func(f32_x);
if constexpr (is_gated) {
const scalar_vec_t y(input + start + d + j);
const vec_op::FP32Vec8 f32_y(y);
f32_ans = f32_y * f32_ans;
}
const scalar_vec_t result(f32_ans);
result.save(output + i * d + j);
}
}
}
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 x3 = x * x * x;
const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
const vec_op::FP32Vec8 w3(0.5);
const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5);
const vec_op::FP32Vec8 w3(0.044715);
const vec_op::FP32Vec8 x_3 = x * x * x;
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh());
}
}; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
}
void gelu_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
}
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
activation_kernel<scalar_t, gelu_tanh_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
});
}
void gelu_new(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_new_impl)
activation_kernel<scalar_t, gelu_new_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_new_impl)
});
}
void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_fast_impl)
activation_kernel<scalar_t, gelu_fast_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}

746
csrc_musa/cpu/attention.cpp Normal file
View File

@@ -0,0 +1,746 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t> struct KernelVecType {
using q_load_vec_type = void;
using q_vec_type = void;
using k_load_vec_type = void;
using k_vec_type = void;
using qk_acc_vec_type = void;
using v_load_vec_type = void;
};
template <> struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::FP32Vec16;
};
#ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32;
using k_vec_type = vec_op::BF16Vec32;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
template <> struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
using k_vec_type = vec_op::FP32Vec16;
using qk_acc_vec_type = vec_op::FP32Vec16;
using v_load_vec_type = vec_op::BF16Vec16;
};
#endif
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
const int capacity) {
T max = data[0];
for (int i = 1; i < size; ++i) {
max = max >= data[i] ? max : data[i];
}
T sum = 0;
for (int i = 0; i < size; ++i) {
data[i] = std::exp(data[i] - max);
sum += data[i];
}
int i = 0;
for (; i < size; ++i) {
data[i] /= sum;
}
for (; i < capacity; ++i) {
data[i] = 0;
}
return {max, sum};
}
template <typename T>
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk;
max = max >= qk ? max : qk;
}
T sum = 0;
for (int i = 0; i < size; ++i) {
data[i] = std::exp(data[i] - max);
sum += data[i];
}
int i = 0;
for (; i < size; ++i) {
data[i] /= sum;
}
for (; i < capacity; ++i) {
data[i] = 0;
}
return {max, sum};
}
template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
const int size) {
T max = max_data[0];
for (int i = 1; i < size; ++i) {
max = max >= max_data[i] ? max : max_data[i];
}
T rescaled_sum = 0;
for (int i = 0; i < size; ++i) {
T rescale_factor = std::exp(max_data[i] - max);
rescaled_sum += rescale_factor * sum_data[i];
sum_data[i] *= rescale_factor;
}
for (int i = 0; i < size; ++i) {
sum_data[i] /= rescaled_sum + 1e-8;
}
}
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
struct reduceQKBlockKernel {
using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
const scalar_t *__restrict__ k_block,
float *__restrict__ logits, float scale,
const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
qk_acc_vec_type group_accums[MAX_GROUP_NUM];
if (token_num == BLOCK_SIZE) {
for (int q_offset = 0; q_offset < HEAD_SIZE;
q_offset += x, k_block += x * BLOCK_SIZE) {
q_load_vec_type q_load_group_vec(q + q_offset);
q_vec_type q_group_vec(q_load_group_vec);
vec_op::unroll_loop<int, MAX_GROUP_NUM>(
[k_block, &q_group_vec, &group_accums](int token_group_idx) {
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
TOKEN_PER_GROUP);
k_vec_type k_group_vec(k_load_group_vec);
vec_op::fma(group_accums[token_group_idx], q_group_vec,
k_group_vec);
vec_op::prefetch(k_block + x * BLOCK_SIZE +
token_group_idx * x * TOKEN_PER_GROUP);
});
}
} else {
for (int q_offset = 0; q_offset < HEAD_SIZE;
q_offset += x, k_block += x * BLOCK_SIZE) {
q_load_vec_type q_load_group_vec(q + q_offset);
q_vec_type q_group_vec(q_load_group_vec);
for (int token_group_start = 0; token_group_start < group_num;
token_group_start += UNROLL_GROUP_NUM) {
vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
[token_group_start, k_block, &q_group_vec,
&group_accums](int token_group_idx) {
token_group_idx += token_group_start;
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
TOKEN_PER_GROUP);
k_vec_type k_group_vec(k_load_group_vec);
vec_op::fma(group_accums[token_group_idx], q_group_vec,
k_group_vec);
vec_op::prefetch(k_block + x * BLOCK_SIZE +
token_group_idx * x * TOKEN_PER_GROUP);
});
}
}
}
for (int token_group_idx = 0; token_group_idx < group_num;
++token_group_idx) {
vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
[&group_accums, logits, scale, token_group_idx](int token_idx) {
float dot_v =
group_accums[token_group_idx]
.template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
TOKEN_PER_GROUP>(token_idx);
logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
dot_v * scale;
});
}
}
};
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc_t &&acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM);
vec_op::FP32Vec16 prob_vec(prob);
vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
vec_op::FP32Vec16 fp32_v_vec(v_vec);
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
});
}
}; // namespace
// Paged attention v1
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl {
static void
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx];
const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num =
seq_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded;
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
}
// Compute softmax
if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len);
} else {
reduceSoftmax(thread_block_logits, seq_len,
block_num * BLOCK_SIZE);
}
// Compute value
constexpr int head_elem_num_per_partition = 16;
constexpr int head_partition_num =
HEAD_SIZE / head_elem_num_per_partition;
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
head_elem_num_per_partition>(
prob_vec_ptr, v_block_cache_ptr, accums);
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
if (head_elem_idx % 2 == 0) {
vec_op::prefetch(next_v_block_cache_ptr +
BLOCK_SIZE * head_elem_idx);
}
});
}
}
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
float value = accums[head_elem_idx].reduce_sum();
vec_op::storeFP32(value, out_ptr + head_elem_idx);
});
}
}
}
std::free(logits);
}
};
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
: nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale,
torch::Tensor &block_tables,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
});
}
// Paged attention v2
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len)
continue;
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1);
const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num =
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
}
std::pair<float, float> max_and_sum;
if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi(
logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len);
} else {
max_and_sum = reduceSoftmax(logits, token_num,
block_num * BLOCK_SIZE);
}
auto &&[max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr;
if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
max_logits[idx] = max_logit;
exp_sums[idx] = exp_sum;
output_buffer =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
} else {
output_buffer =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
}
// Compute value
constexpr int head_elem_num_per_partition = 16;
constexpr int head_partition_num =
HEAD_SIZE / head_elem_num_per_partition;
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
head_elem_num_per_partition>(
prob_vec_ptr, v_block_cache_ptr, accums);
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
if (head_elem_idx % 2 == 0) {
vec_op::prefetch(next_v_block_cache_ptr +
BLOCK_SIZE * head_elem_idx);
}
});
}
}
vec_op::unroll_loop<int, head_elem_num_per_partition>(
[&](int head_elem_idx) {
float value = accums[head_elem_idx].reduce_sum();
vec_op::storeFP32(value, out_ptr + head_elem_idx);
});
}
}
}
}
// Rescale partition softmax and store the factors to exp_sums
#pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions,
exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions,
partition_num);
}
}
// Reduce values
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
// didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int seq_len = seq_lens[seq_idx];
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1)
continue;
const float *__restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group;
vec_op::FP32Vec16 acc;
for (int i = 0; i < partition_num; ++i) {
vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
vec_op::FP32Vec16 fp32_value(value);
acc = acc + fp32_value * rescale_factor;
}
v_load_vec_type cast_acc(acc);
cast_acc.save(seq_head_output);
}
}
}
}
};
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
: nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, \
max_seq_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &seq_lens, int block_size,
int max_seq_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
});
}

141
csrc_musa/cpu/cache.cpp Normal file
View File

@@ -0,0 +1,141 @@
#include <map>
#include <vector>
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(
std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
const int element_num_per_block, const int layer_num) {
const size_t pair_num = mapping_pairs.size();
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second;
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
}
}
}
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx >= 0) {
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx =
token_idx * value_stride + head_idx * head_size;
const scalar_t *src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
scalar_t *target_key_head_ptr = key_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
scalar_t *target_value_head_ptr = value_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
target_key_head_ptr[target_offset + i] =
src_key_head_ptr[src_key_idx + i];
}
}
for (int src_value_idx = 0; src_value_idx < head_size;
++src_value_idx) {
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
src_value_head_ptr[src_value_idx];
}
}
}
}
}
}; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
torch::Tensor &key_cache, torch::Tensor &value_cache,
torch::Tensor &slot_mapping,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
reshape_and_cache_cpu_impl<scalar_t>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
value_stride, num_heads, head_size, block_size, x);
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
});
}
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
const std::map<int64_t, int64_t> &block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}

352
csrc_musa/cpu/cpu_types.hpp Normal file
View File

@@ -0,0 +1,352 @@
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP
#include <immintrin.h>
#include <torch/extension.h>
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T> struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128h reg;
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
explicit FP16Vec8(__m128h data) : reg(data) {}
FP16Vec8 operator*(const FP16Vec8 &b) const {
return FP16Vec8(_mm_mul_ph(reg, b.reg));
}
FP16Vec8 operator+(const FP16Vec8 &b) const {
return FP16Vec8(_mm_add_ph(reg, b.reg));
}
FP16Vec8 operator-(const FP16Vec8 &b) const {
return FP16Vec8(_mm_sub_ph(reg, b.reg));
}
FP16Vec8 operator/(const FP16Vec8 &b) const {
return FP16Vec8(_mm_div_ph(reg, b.reg));
}
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128i reg;
explicit BF16Vec8(const void *ptr)
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
explicit BF16Vec8(const FP32Vec8 &);
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
explicit BF16Vec16(const void *ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
explicit BF16Vec16(const FP32Vec16 &);
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m512i reg;
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(__m512i data) : reg(data) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg((__m512i)_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
(__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1),
(__m128i)vec8_data.reg, 2),
(__m128i)vec8_data.reg, 3)) {}
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__m128 reg;
float values[VEC_ELEM_NUM];
};
__m128 reg;
explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
explicit FP32Vec4(__m128 data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
__m256 reg;
float values[VEC_ELEM_NUM];
};
__m256 reg;
explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
explicit FP32Vec8(__m256 data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
#ifdef __AVX512FP16__
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif
explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
expf(ar.values[5]), expf(ar.values[4]),
expf(ar.values[3]), expf(ar.values[2]),
expf(ar.values[1]), expf(ar.values[0])));
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
tanhf(ar.values[5]), tanhf(ar.values[4]),
tanhf(ar.values[3]), tanhf(ar.values[2]),
tanhf(ar.values[1]), tanhf(ar.values[0])));
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
erf(ar.values[5]), erf(ar.values[4]),
erf(ar.values[3]), erf(ar.values[2]),
erf(ar.values[1]), erf(ar.values[0])));
}
FP32Vec8 operator*(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
}
FP32Vec8 operator+(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_add_ps(reg, b.reg));
}
FP32Vec8 operator-(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
}
FP32Vec8 operator/(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_div_ps(reg, b.reg));
}
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512 reg;
float values[VEC_ELEM_NUM];
};
__m512 reg;
explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
(__m128i)data.reg, 1),
(__m128i)data.reg, 2),
(__m128i)data.reg, 3)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg((__m512)_mm512_inserti32x8(
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
explicit FP32Vec16(const BF16Vec16 &v)
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_add_ps(reg, b.reg));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
return _mm512_mask_reduce_add_ps(mask, reg);
}
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
template <typename T> struct VecType { using vec_type = void; };
template <typename T> using vec_t = typename VecType<T>::vec_type;
template <> struct VecType<float> { using vec_type = FP32Vec8; };
#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::BFloat16 *>(&v);
*ptr = *(v_ptr + 1);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#endif
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
}; // namespace vec_op
#endif

117
csrc_musa/cpu/layernorm.cpp Normal file
View File

@@ -0,0 +1,117 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out,
const scalar_t *__restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
vec_op::FP32Vec8 variance(0.0);
auto input_p = input + i * hidden_size;
auto output_p = out + i * hidden_size;
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
vec_op::FP32Vec8 fp32_x(x);
variance = variance + fp32_x * fp32_x;
}
float s_variance =
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
vec_op::FP32Vec8 fp32_s_variance(s_variance);
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
scalar_vec_t w(weight + j);
vec_op::FP32Vec8 fp32_x(x);
vec_op::FP32Vec8 fp32_w(w);
vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;
scalar_vec_t out(fp32_out);
out.save(output_p + j);
}
}
}
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
scalar_t *__restrict__ residual,
const scalar_t *__restrict__ weight,
const float epsilon, const int num_tokens,
const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
vec_op::FP32Vec8 variance(0.0);
auto input_p = input + i * hidden_size;
auto residual_p = residual + i * hidden_size;
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t x(input_p + j);
scalar_vec_t res(residual_p + j);
vec_op::FP32Vec8 fp32_x(x);
vec_op::FP32Vec8 fp32_res(res);
fp32_x = fp32_x + fp32_res;
variance = variance + fp32_x * fp32_x;
scalar_vec_t out(fp32_x);
out.save(residual_p + j);
}
float s_variance =
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
vec_op::FP32Vec8 fp32_s_variance(s_variance);
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
scalar_vec_t w(weight + j);
scalar_vec_t res(residual_p + j);
vec_op::FP32Vec8 fp32_w(w);
vec_op::FP32Vec8 fp32_res(res);
vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;
scalar_vec_t out(fp32_out);
out.save(input_p + j);
}
}
}
} // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input,
torch::Tensor &weight, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl)
});
}
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
torch::Tensor &weight, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "fused_add_rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
fused_add_rms_norm_impl(
input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
});
}

View File

@@ -0,0 +1,199 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t>
void rotary_embedding_impl(
const int64_t
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
/// [num_tokens, num_heads, head_size]
scalar_t
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
const scalar_t
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(query + out_x);
const scalar_vec_t q_y(query + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(query + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(query + out_y);
}
}
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t k_x(key + out_x);
const scalar_vec_t k_y(key + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_k_x(k_x);
vec_op::FP32Vec8 fp32_k_y(k_y);
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
scalar_vec_t(out1).save(key + out_x);
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
scalar_vec_t(out2).save(key + out_y);
}
}
}
}
template <typename scalar_t>
void rotary_embedding_gptj_impl(
const int64_t
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
/// [num_tokens, num_heads, head_size]
scalar_t
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
const scalar_t
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
const int embed_dim = rot_dim / 2;
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
const int y_index = 2 * rot_offset + 1;
const float cos = cos_cache_ptr[rot_offset];
const float sin = sin_cache_ptr[rot_offset];
const float x = head_query[x_index];
const float y = head_query[y_index];
head_query[x_index] = x * cos - y * sin;
head_query[y_index] = y * cos + x * sin;
}
}
}
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
const int y_index = 2 * rot_offset + 1;
const float cos = cos_cache_ptr[rot_offset];
const float sin = sin_cache_ptr[rot_offset];
const float x = head_key[x_index];
const float y = head_key[y_index];
head_key[x_index] = x * cos - y * sin;
head_key[y_index] = y * cos + x * sin;
}
}
}
}
}; // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
torch::Tensor &key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t key_stride = key.stride(-2);
int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "rotary_embedding_impl", [&] {
CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
if (is_neox) {
rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
} else {
rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
}
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
});
}

73
csrc_musa/cpu/pybind.cpp Normal file
View File

@@ -0,0 +1,73 @@
#include "cache.h"
#include "cuda_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
}

View File

@@ -0,0 +1,148 @@
#include "torch_musa/csrc/core/MUSAException.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include "torch_musa/csrc/core/MUSAStream.h"
#include <torch/extension.h>
#include "custom_all_reduce.muh"
// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
musaIpcMemHandle_t ipc_handles[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(musaIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
/**
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
* because it allows transpose of contiguous slice (i.e. slicing the first
* dimension). Currently, we require this because stride information is not
* passed into the kernels and we treat input tensors as flat.
*
* Examples
* A = torch.zeros(3, 3, 3)
* 1. A: OK
* 2. A[1:]: OK
* 3. A.permute(2, 0, 1): OK
* 4. A[1:].permute(2, 0, 1): OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor &t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16
if (inp_size % 16 != 0) return false;
if (!_is_weak_contiguous(inp)) return false;
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return false;
}
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
musaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()),
out.numel());
break;
}
#if (__MUSA_ARCH__ >= 800 || !defined(__MUSA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<mt_bfloat16>(
stream, reinterpret_cast<mt_bfloat16 *>(inp.data_ptr()),
reinterpret_cast<mt_bfloat16 *>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
const at::musa::OptionalMUSAGuard device_guard(device_of(inp));
auto stream = c10::musa::getCurrentMUSAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
torch::Tensor &out) {
const at::musa::OptionalMUSAGuard device_guard(device_of(inp));
auto stream = c10::musa::getCurrentMUSAStream().stream();
auto input_size = inp.numel() * inp.element_size();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
"registered buffer is too small to contain the input");
C10_MUSA_CHECK(musaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
input_size, musaMemcpyDeviceToDevice, stream));
_all_reduce(_fa, reg_buffer, out, stream);
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
delete fa;
}
int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor &t,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
return fa->get_graph_buffer_ipc_meta();
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
fa->register_graph_buffers(handles, offsets);
}

View File

@@ -0,0 +1,485 @@
#pragma once
#include <musa.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#include <musa_runtime.h>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
musaError_t e = cmd; \
if (e != musaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
musaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace vllm {
constexpr int kMaxBlocks = 64;
// note: we don't want to use atomics for signals because peer atomics are no
// supported on PCIe links
struct Signal {
alignas(128) uint32_t start[kMaxBlocks][8];
alignas(128) uint32_t end[kMaxBlocks][8];
};
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; RankData& operator=(const RankData& ){return *this;} };
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }
#if (__MUSA_ARCH__ >= 800 || !defined(__MUSA_ARCH__))
DINLINE float upcast_s(mt_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE mt_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE mt_bfloat16 &assign_add(mt_bfloat16 &a, mt_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x])
;
}
__syncthreads();
}
// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->start[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
}
if constexpr (!final_sync) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P *)result)[idx] =
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, self_sg, rank);
}
template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
return (P *)(((Signal *)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P *ptrs[ngpus];
P *tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
end_sync<ngpus>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(musaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(musaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(musaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
// below are device pointers
RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_;
Signal *self_sg_;
// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
*
* There's a total of sizeof(Signal) of prefix before the actual data,
* so meta + 1 points to actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
const musaIpcMemHandle_t *handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
Signal *rank_sg;
if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_sg = (Signal *)handle;
} else {
rank_sg = self_sg_;
}
sg_.signals[i] = rank_sg;
}
}
char *open_ipc_handle(const void *ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
if (new_handle) {
char *ipc_ptr;
CUDACHECK(musaIpcOpenMemHandle((void **)&ipc_ptr,
*((const musaIpcMemHandle_t *)ipc_handle),
musaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(musaIpcMemHandle_t);
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void *base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (muPointerGetAttribute(&base_ptr,
MU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
(MUdeviceptr)ptr) != MUSA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(musaIpcGetMemHandle(
(musaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " +
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets, void *self) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char *handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
data.ptrs[i] = self;
}
}
auto d_data = d_rank_data_base_++;
CUDACHECK(
musaMemcpy(d_data, &data, sizeof(RankData), musaMemcpyHostToDevice));
buffers_[self] = d_data;
}
// note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto &rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char *handle =
open_ipc_handle(&handles[j][i * sizeof(musaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CUDACHECK(musaMemcpy(d_rank_data_base_, rank_data.data(),
sizeof(RankData) * num_buffers,
musaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* This is the result after careful grid search. Using 36 blocks give the best
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
* Not quite sure the underlying reason, but my guess is that too many SMs
* will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(musaStream_t stream, T *input, T *output, int size,
int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
RankData *ptrs;
musaStreamCaptureStatus status;
CUDACHECK(at::musa::musaStreamIsCapturing(stream, &status));
if (status == musaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(musaIpcCloseMemHandle(ptr));
}
}
};
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* template void vllm::CustomAllreduce::allreduce<half>(musaStream_t, half *,
half *, int, int, int);
*/
} // namespace vllm

View File

@@ -0,0 +1,316 @@
/**
* This is a standalone test for custom allreduce.
* To compile, make sure you have MPI and NCCL installed in your system.
* export MPI_HOME=XXX
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
*
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
*
* To run:
* mpirun -np 8 ./custom_all_reduce_test
*/
#include <musa.h>
#include <murand_kernel.h>
#include <stdio.h>
#include <stdlib.h>
#include <limits>
#include <vector>
#include "musa_profiler_api.h"
#include "custom_all_reduce.muh"
#include "mpi.h"
#include "nccl.h"
#define MPICHECK(cmd) \
do { \
int e = cmd; \
if (e != MPI_SUCCESS) { \
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define NCCLCHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
__global__ void dummy_kernel() {
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
}
template <typename T>
__global__ void set_data(T *data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f;
}
}
template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
double *fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx];
fdata2[idx] = data2[idx];
}
}
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) {
curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
}
}
}
template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
double sum = 0.0;
for (int i = 0; i < nRanks; i++) {
double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
T hval = val; // downcast first
sum += static_cast<double>(hval);
if (i == myRank) data[idx] = hval;
}
ground_truth[idx] = sum;
}
}
template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
int data_size, bool performance_test) {
T *result;
musaStream_t stream;
CUDACHECK(musaStreamCreateWithFlags(&stream, musaStreamNonBlocking));
CUDACHECK(musaMalloc(&result, data_size * sizeof(T)));
CUDACHECK(musaMemset(result, 0, data_size * sizeof(T)));
musaIpcMemHandle_t self_data_handle;
musaIpcMemHandle_t data_handles[8];
vllm::Signal *buffer;
T *self_data_copy;
/**
* Allocate IPC buffer
*
* The first section is a temporary buffer for storing intermediate allreduce
* results, if a particular algorithm requires it. The second section is for
* the input to the allreduce. The actual API takes the input pointer as an
* argument (that is, they can and usually should be allocated separately).
* But since the input pointers and the temporary buffer all require IPC
* registration, they are allocated and registered together in the test for
* convenience.
*/
CUDACHECK(
musaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
CUDACHECK(
musaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
CUDACHECK(musaMalloc(&self_data_copy, data_size * sizeof(T)));
CUDACHECK(musaIpcGetMemHandle(&self_data_handle, buffer));
MPICHECK(MPI_Allgather(&self_data_handle, sizeof(musaIpcMemHandle_t),
MPI_BYTE, data_handles, sizeof(musaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD));
void *rank_data;
size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(musaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank);
auto *self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration
{
std::vector<std::string> handles;
handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) {
char *begin = (char *)&data_handles[i];
char *end = (char *)&data_handles[i + 1];
handles.emplace_back(begin, end);
}
std::vector<int64_t> offsets(nRanks,
sizeof(vllm::Signal) + data_size * sizeof(T));
fa.register_buffer(handles, offsets, self_data);
}
double *ground_truth;
CUDACHECK(musaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t *states;
CUDACHECK(musaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
nRanks, data_size);
CUDACHECK(musaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
musaMemcpyDeviceToDevice, stream));
musaEvent_t start, stop;
CUDACHECK(musaEventCreate(&start));
CUDACHECK(musaEventCreate(&stop));
ncclDataType_t ncclDtype;
if (std::is_same<T, half>::value) {
ncclDtype = ncclFloat16;
} else if (std::is_same<T, mt_bfloat16>::value) {
ncclDtype = ncclBfloat16;
} else {
ncclDtype = ncclFloat;
}
double *nccl_result, *my_result;
CUDACHECK(musaMallocHost(&nccl_result, data_size * sizeof(double)));
CUDACHECK(musaMallocHost(&my_result, data_size * sizeof(double)));
if (performance_test) {
dummy_kernel<<<1, 1, 0, stream>>>();
constexpr int warmup_iters = 5;
constexpr int num_iters = 100;
// warmup
for (int i = 0; i < warmup_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
comm, stream));
}
CUDACHECK(musaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) {
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
comm, stream));
}
CUDACHECK(musaEventRecord(stop, stream));
CUDACHECK(musaStreamSynchronize(stream));
float allreduce_ms = 0;
musaEventElapsedTime(&allreduce_ms, start, stop);
dummy_kernel<<<1, 1, 0, stream>>>();
// warm up
for (int i = 0; i < warmup_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads,
block_limit);
}
CUDACHECK(musaEventRecord(start, stream));
for (int i = 0; i < num_iters; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads,
block_limit);
}
CUDACHECK(musaEventRecord(stop, stream));
CUDACHECK(musaStreamSynchronize(stream));
float duration_ms = 0;
musaEventElapsedTime(&duration_ms, start, stop);
if (myRank == 0)
printf(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
"time:%.2fus\n",
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
// And wait for all the queued up work to complete
CUDACHECK(musaStreamSynchronize(stream));
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
ncclSum, comm, stream));
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
my_result, data_size);
CUDACHECK(musaStreamSynchronize(stream));
for (unsigned long j = 0; j < data_size; j++) {
auto diff = abs(nccl_result[j] - my_result[j]);
if (diff >= 4e-2) {
printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
break;
}
}
long double nccl_diffs = 0.0;
long double my_diffs = 0.0;
for (int j = 0; j < data_size; j++) {
nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
my_diffs += abs(my_result[j] - ground_truth[j]);
}
if (myRank == 0)
std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
<< " me: " << my_diffs / data_size << std::endl;
} else {
for (int i = 0; i < 100; i++) {
fa.allreduce<T>(stream, self_data, result, data_size, threads,
block_limit);
CUDACHECK(musaStreamSynchronize(stream));
NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
ncclSum, comm, stream));
convert_data<T><<<108, 1024, 0, stream>>>(
self_data_copy, result, nccl_result, my_result, data_size);
CUDACHECK(musaStreamSynchronize(stream));
for (unsigned long j = 0; j < data_size; j++) {
auto diff = abs(nccl_result[j] - my_result[j]);
if (diff >= 4e-2) {
printf(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
break;
}
}
}
if (myRank == 0)
printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks,
data_size * sizeof(T) / 1024, threads, block_limit);
// long double nccl_diffs = 0.0;
// long double my_diffs = 0.0;
// for (int j = 0; j < data_size; j++) {
// nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
// my_diffs += abs(my_result[j] - ground_truth[j]);
// }
// if (myRank == 0)
// std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
// << " me: " << my_diffs / data_size << std::endl;
}
CUDACHECK(musaFree(result));
CUDACHECK(musaFree(self_data_copy));
CUDACHECK(musaFree(rank_data));
CUDACHECK(musaFree(buffer));
CUDACHECK(musaFree(states));
CUDACHECK(musaFreeHost(ground_truth));
CUDACHECK(musaFreeHost(nccl_result));
CUDACHECK(musaFreeHost(my_result));
CUDACHECK(musaStreamDestroy(stream));
}
int main(int argc, char **argv) {
int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
CUDACHECK(musaSetDevice(myRank));
ncclUniqueId id;
ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
bool performance_test = true;
cudaProfilerStart();
// for (int threads : {256, 512}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
}
cudaProfilerStop();
return EXIT_SUCCESS;
}

View File

@@ -0,0 +1,37 @@
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#pragma once
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

View File

@@ -0,0 +1,352 @@
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include "dispatch_utils.h"
#include "reduction_utils.muh"
#ifndef USE_ROCM
#include <musa_bf16.h>
#include <musa_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using __mt_bfloat16 = __hip_bfloat16;
using __mt_bfloat162 = __hip_bfloat162;
#endif
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template<typename torch_type>
struct _typeConvert { static constexpr bool exists = false; };
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template<>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
};
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template<>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __mt_bfloat16;
using packed_hip_type = __mt_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
};
#endif // defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template<typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]};
temp += T2{other.data[i], other.data[i+1]};
data[i] = temp.x;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]};
temp *= T2{other.data[i], other.data[i+1]};
data[i] = temp.x;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template<typename scalar_t, int width>
__global__ std::enable_if_t<
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template<typename scalar_t, int width>
__global__ std::enable_if_t<
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
} // namespace vllm
void rms_norm(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"rms_norm_kernel",
[&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"fused_add_rms_norm_kernel", \
[&] { \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});
void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
&& wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}

View File

@@ -0,0 +1,7 @@
#include "moe_ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}

9
csrc_musa/moe/moe_ops.h Normal file
View File

@@ -0,0 +1,9 @@
#pragma once
#include <torch/extension.h>
void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);

View File

@@ -0,0 +1,500 @@
/*
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
namespace vllm {
namespace moe {
static constexpr int WARP_SIZE = 32;
/// Aligned array type
template <
typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
float data[N];
};
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
__launch_bounds__(TPB) __global__
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX);
// Don't touch finished rows.
if ((finished != nullptr) && finished[blockIdx.x])
{
return;
}
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData = max(static_cast<float>(input[idx]), threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
if (threadIdx.x == 0)
{
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
threadData += exp((static_cast<float>(input[idx]) - float_max));
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
if (threadIdx.x == 0)
{
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
output[idx] = val;
}
}
template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{
using cub_kvp = cub::KeyValuePair<int, float>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int num_rows = gridDim.x;
const int block_row = blockIdx.x;
const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
{
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
{
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert)
{
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0)
{
// Ignore experts the node isn't responsible for with expert parallelism
const int expert = result_kvp.key;
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
const int idx = k * block_row + k_idx;
output[idx] = result_kvp.value;
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row;
}
__syncthreads();
}
}
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
are a small power of 2. This allows us to cleanly share the rows among the threads in
a single warp and eliminate communication between warps (so no need to use shared mem).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is intended for when the number of experts is a small power of 2.
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
// This, each block processes a chunk of rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows)
{
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
// this can support all powers of 2 up to 16.
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
{
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
// convert to float afterwards for the exp + sum reduction.
float thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii)
{
thread_max = max(thread_max, row_chunk[ii]);
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float row_sum = 0;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
row_sum += row_chunk[ii];
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
// argmax after computing the softmax.
const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
// with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
{
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
{
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if (val > max_val)
{
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
{
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0)
{
// Add a guard to ignore experts not included by this node
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single) thread per row of the input/output matrices.
const int idx = k * thread_row + k_idx;
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
if (k_idx + 1 < k)
{
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group)
{
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
}
}
}
}
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, musaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
gating_output, nullptr, topk_weights, topk_indicies, \
token_expert_indices, num_tokens, topk, 0, num_experts, \
stream);
void topkGatingSoftmaxKernelLauncher(
const float* gating_output,
float* topk_weights,
int* topk_indicies,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
const int num_experts,
const int topk,
musaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
break;
case 2:
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
break;
case 4:
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
break;
case 8:
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
break;
case 16:
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
break;
case 32:
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
break;
case 64:
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
break;
case 128:
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
break;
case 256:
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
break;
default: {
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2.");
static constexpr int TPB = 256;
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
num_experts, topk, 0, num_experts);
}
}
}
} // namespace moe
} // namespace vllm
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output) // [num_tokens, num_experts]
{
const int num_experts = gating_output.size(-1);
const int num_tokens = gating_output.numel() / num_experts;
const int topk = topk_weights.size(-1);
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
const bool needs_workspace = !is_pow_2 || num_experts > 256;
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
const at::musa::OptionalMUSAGuard device_guard(device_of(gating_output));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
vllm::moe::topkGatingSoftmaxKernelLauncher(
gating_output.data_ptr<float>(),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens,
num_experts,
topk,
stream);
}

View File

@@ -0,0 +1,125 @@
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include <ATen/ATen.h>
#include <THC/THCAtomics.muh>
#include "musa_compat.h"
#include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
namespace vllm {
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
// don't worry about overflow because num_experts is relatively small
return row * total_col + col;
}
}
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
int32_t *sorted_token_ids,
int32_t *expert_ids,
int32_t *total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
* to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
}
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
// set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_MUSA_CHECK(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel());
});
}

38
csrc_musa/musa_compat.h Normal file
View File

@@ -0,0 +1,38 @@
#pragma once
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
musaFuncSetAttribute(FUNC, musaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

10
csrc_musa/musa_utils.h Normal file
View File

@@ -0,0 +1,10 @@
#pragma once
#include <torch/extension.h>
int get_device_attribute(
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute(
int device_id);

View File

@@ -0,0 +1,35 @@
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
int attribute,
int device_id)
{
int device, value;
if (device_id < 0) {
musaGetDevice(&device);
}
else {
device = device_id;
}
musaDeviceGetAttribute(&value, static_cast<musaDeviceAttr>(attribute), device);
return value;
}
int get_max_shared_memory_per_block_device_attribute(
int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
attribute = musaDevAttrMaxSharedMemoryPerBlockOptin;
#endif
return get_device_attribute(attribute, device_id);
}

206
csrc_musa/ops.h Normal file
View File

@@ -0,0 +1,206 @@
#pragma once
#include <torch/extension.h>
void paged_attention_v1(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void paged_attention_v2(
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& seq_lens,
int block_size,
int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void rms_norm(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm(
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias
);
torch::Tensor aqlm_dequant(
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes
);
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
torch::Tensor awq_dequantize(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy);
torch::Tensor marlin_gemm(
torch::Tensor& a,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
torch::Tensor& workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(
torch::Tensor &a,
torch::Tensor &b_q_weight,
torch::Tensor &b_scales,
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full);
torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n,
int64_t num_bits);
#endif
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
torch::Tensor gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama,
int bit);
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm,
int bit);
void static_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
void dynamic_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
torch::Tensor &out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets);
#endif

View File

@@ -0,0 +1,226 @@
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include "musa_compat.h"
#include "dispatch_utils.h"
namespace vllm {
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* cache_ptr,
const int head_size,
const int num_heads,
const int num_kv_heads,
const int rot_dim,
const int token_idx,
const int64_t query_stride,
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
}
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
} // namespace vllm
void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
/*
Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) {
int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) {
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}

217
csrc_musa/punica/.LICENSE Normal file
View File

@@ -0,0 +1,217 @@
Contains code from https://github.com/punica-ai/punica
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer
BSD-3-Clause:
* third_party/cutlass

View File

@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, mt_bfloat16, mt_bfloat16, mt_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, mt_bfloat16, mt_bfloat16, mt_bfloat16)

View File

@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, mt_bfloat16, float, mt_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, mt_bfloat16, float, mt_bfloat16)

View File

@@ -0,0 +1,162 @@
#pragma once
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 640) \
f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1536) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2304) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4608) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on

View File

@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

View File

@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

View File

@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, mt_bfloat16, mt_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, mt_bfloat16, mt_bfloat16)

View File

@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)

View File

@@ -0,0 +1,297 @@
#pragma once
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <musa_runtime.h>
#include <iostream>
#include <stdio.h>
#include "vec_dtypes.cuh"
namespace cg = cooperative_groups;
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t j = blockIdx.x;
constexpr size_t num_pipeline_stages = 2;
constexpr size_t tile_size = tx * ty * vec_size;
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
__shared__ float y_warpwise[ty];
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
auto pipe = cuda::make_pipeline();
// pipeline load W/X and compute WX;
pipe.producer_acquire();
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
pipe.producer_commit();
size_t copy_idx, compute_idx;
float y = 0.f;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
++tile_idx) {
copy_idx = tile_idx % num_pipeline_stages;
// pipeline stage: async copy W fragment
pipe.producer_acquire();
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size,
X + (batch_idx * feat_in) + tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size,
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
}
pipe.producer_commit();
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// pipeline stage: compute WX
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] = sum;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
}
compute_idx = (tile_idx - 1) % num_pipeline_stages;
// final pipeline stage
pipe.consumer_wait();
block.sync();
x_vec.load(X_shared + X_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W_shared + W_shared_offset[compute_idx] +
(threadIdx.y * tx + threadIdx.x) * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += __shfl_down_sync(0xffffffff, sum, offset);
}
y_warpwise[threadIdx.y] =
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
? sum
: 0.f;
block.sync();
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y += y_warpwise[i];
}
block.sync();
pipe.consumer_release();
// write Y;
if (block.thread_rank() == 0) {
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
}
}
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
auto block = cg::this_thread_block();
size_t tile_idx = blockIdx.x;
// load X;
vec_t<in_T, vec_size> x_vec;
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
// load W;
vec_t<W_T, vec_size> w_vec;
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
block.thread_rank() * vec_size);
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += g.shfl_down(sum, offset);
}
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
}
}
template <int feat_in, int feat_out, typename in_T, typename out_T,
typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
constexpr size_t vec_size = 8;
constexpr int tz = 4;
const musaStream_t stream = at::musa::getCurrentMUSAStream();
if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
constexpr int ty = 32 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
constexpr int ty = 16 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else {
constexpr int ty = 8 / tx;
dim3 nblks(feat_out / (ty * tz), batch_size);
dim3 nthrs(tx, ty, tz);
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
} else {
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
if constexpr (feat_in % (vec_size * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
vec_size * sizeof(W_T), tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
constexpr int tx = 32;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
constexpr int tx = 16;
constexpr int ty = 4;
dim3 nblks(feat_out, batch_size);
dim3 nthrs(tx, ty);
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
vec_size * sizeof(in_T) / 2,
vec_size * sizeof(W_T) / 2, tx, ty, tz>
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
full_y_size, num_layers, layer_idx,
scale);
}
}
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)

View File

@@ -0,0 +1,48 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
"fp16": "nv_half",
"bf16": "mt_bfloat16",
"fp32": "float",
}
TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501
for input_dtype in DTYPES:
for output_dtype in DTYPES:
for weight_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
if output_dtype == "fp32":
# LoRA A matrix.
if input_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# input and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif input_dtype == "fp32":
# LoRA B matrix.
if output_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# output and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif not (input_dtype == output_dtype == weight_dtype):
# NOTE(woosuk): While Punica supports mixed data types for
# input, output, and weight, we only generate the kernels with
# the same data types to reduce the binary size.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],
weight_dtype=DTYPE_MAP[weight_dtype])
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
with open(filename, "w") as f:
f.write(kernel_definition)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,582 @@
#include <musa_bf16.h>
#include <musa_fp16.h>
#include <torch/extension.h>
#include "torch_musa/csrc/core/MUSAGuard.h"
#include <cstdint>
#include "bgmv/bgmv_config.h"
namespace {
//====== utils ======
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
const char *a_name, const char *b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
".size(", i, ")");
}
}
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
return (uint64_t(a) << 32) | uint64_t(b);
}
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices,
uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
// NOTE(woosuk): While Punica supports various combinations of input/output
// data types, we limit the supported data types to reduce the binary size.
constexpr bool is_input_float = std::is_same<in_T, float>::value;
constexpr bool is_output_float = std::is_same<out_T, float>::value;
if (is_input_float) {
if (!std::is_same<out_T, W_T>::value) {
return false;
}
} else if (is_output_float) {
if (!std::is_same<in_T, W_T>::value) {
return false;
}
} else if (!(std::is_same<in_T, W_T>::value &&
std::is_same<out_T, W_T>::value)) {
return false;
}
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u32(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \
layer_idx, scale); \
break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t h_in = x.size(1);
int64_t h_out = y.size(1);
int64_t num_layers = w.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
bool ok = false;
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
h_out, B, num_layers, layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
int64_t y_offset) {
CHECK_INPUT(y);
CHECK_INPUT(x);
CHECK_INPUT(w);
CHECK_INPUT(indicies);
CHECK_DIM(2, y);
CHECK_DIM(2, x);
CHECK_DIM(4, w);
CHECK_DIM(1, indicies);
int64_t B = x.size(0);
int64_t num_layers = w.size(1);
int64_t full_y_size = y.size(1);
CHECK_EQ(w.size(3), h_in);
CHECK_EQ(w.size(2), h_out);
CHECK_EQ(indicies.size(0), x.size(0));
CHECK_EQ(y.size(0), x.size(0));
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
bool ok = false;
if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) {
case at::ScalarType::Half:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<nv_half *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<mt_bfloat16 *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (y.scalar_type()) {
case at::ScalarType::Half:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::BFloat16:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
case at::ScalarType::Float:
switch (w.scalar_type()) {
case at::ScalarType::Half:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<nv_half *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
case at::ScalarType::BFloat16:
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
static_cast<float *>(x.data_ptr()),
static_cast<mt_bfloat16 *>(w.data_ptr()),
indicies.data_ptr<int64_t>(), h_in, h_out,
y_offset, full_y_size, B, num_layers,
layer_idx, scale);
break;
default:
break;
}
break;
default:
break;
}
break;
default:
break;
}
}
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}

136
csrc_musa/pybind.cpp Normal file
View File

@@ -0,0 +1,136 @@
#include "cache.h"
#include "musa_utils.h"
#include "ops.h"
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// vLLM custom ops
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def(
"batched_rotary_embedding",
&batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops
#ifndef USE_ROCM
// ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
// ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
// ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
// ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
// ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
// ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
// ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
// ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
// ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
// ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
// ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
// ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
// ops.def(
// "moe_align_block_size",
// &moe_align_block_size,
// "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"reshape_and_cache_flash",
&reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
custom_ar.def("dispose", &dispose, "dispose");
custom_ar.def("meta_size", &meta_size, "meta_size");
custom_ar.def("register_buffer", &register_buffer, "register_buffer");
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
}

View File

@@ -0,0 +1,712 @@
/*
* Modified by Neural Magic
* Adapted from https://github.com/Vahe1994/AQLM
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <musa.h>
#include <musa_fp16.h>
#include <musa_runtime.h>
#include <torch/extension.h>
#include "torch_musa/csrc/core/MUSAStream.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include <iostream>
#include <cstdlib>
namespace vllm {
namespace aqlm {
__global__ void Code1x16MatVec(
const int4* __restrict__ A,
const int4* __restrict__ B,
int4* __restrict__ C,
const int4* __restrict__ codebook,
const int prob_m,
const int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
__shared__ int4 sh_b[32 * 9];
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8)
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
uint32_t dec[4];
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm volatile (
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]])
);
half2* a = reinterpret_cast<half2*>(&dec);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++)
res2 = __hfma2(a[j], b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2)
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code2x8MatVec(
const int4* __restrict__ A,
const int4* __restrict__ B,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
extern __shared__ int4 sh[];
int4* sh_b = sh;
int4* sh_code = sh_b + 32 * 9;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++)
sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8)
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++)
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2)
res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code1x16Dequant(
const int4* __restrict__ A,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm volatile (
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*) &codebook[enc[i]])
);
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
__global__ void Code2x8Dequant(
const int4* __restrict__ A,
int4* __restrict__ C,
const int4* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred)
{
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size)
{
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
extern __shared__ int4 sh[];
int4* sh_code = sh;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++)
sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
inline int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
const int THREAD_M = 16;
void code1x16_matvec_cuda(
const void* __restrict__ A,
const void* __restrict__ B,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms;
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
(const int4*) A,
(const int4*) B,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
}
void code2x8_matvec_cuda(
const void* __restrict__ A,
const void* __restrict__ B,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes,
const int codebook_stride
) {
int sms;
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
musaFuncSetAttribute(
Code2x8MatVec, musaFuncAttributeMaxDynamicSharedMemorySize, shared
);
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*) A,
(const int4*) B,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
}
void code1x16_dequant_cuda(
const void* __restrict__ A,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int sms;
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*) A,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
codebook_stride // as int4.
);
}
// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
const void* __restrict__ A,
void* __restrict__ C,
const void* __restrict__ codebook,
int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int sms;
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
musaFuncSetAttribute(
Code2x8Dequant, musaFuncAttributeMaxDynamicSharedMemorySize, shared
);
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*) A,
(int4*) C,
(const int4*) codebook,
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride
);
}
int codebook_stride(const torch::Tensor& codebooks)
{
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}
void code1x16_matvec(
const torch::Tensor& A,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
) {
const at::musa::OptionalMUSAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code1x16_matvec_cuda(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
codebook_stride(codebook)
);
}
torch::Tensor code1x16_matmat(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features},
torch::TensorOptions()
.dtype(input.dtype())
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code1x16_matvec(
codes.squeeze(2),
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
void code2x8_matvec(
const torch::Tensor& A,
const torch::Tensor& B,
torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes
) {
const at::musa::OptionalMUSAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code2x8_matvec_cuda(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
codebook.data_ptr(),
prob_m,
prob_k,
codebook_a_sizes,
2 * codebook_stride(codebook)
);
}
torch::Tensor code2x8_matmat(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias
) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty({flat_input.size(0), out_features},
torch::TensorOptions()
.dtype(input.dtype())
.device(input.device())
);
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code2x8_matvec(
codes.squeeze(2),
input_vec,
output_vec,
codebooks,
codebook_a_sizes
);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
{
int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x;
int i = 0;
int last = 0;
assert(codebook_partition_sizes.size(0) <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
{
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
last = *cumulative_size;
}
// fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size)
{
*cumulative_size = last*10;
}
return cumulative_sizes;
}
} // namespace aqlm
} // namespace vllm
torch::Tensor aqlm_gemm(
const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias
)
{
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16))
{
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
}
if (nbooks == 2 && entries == (1 << 8))
{
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
return {};
}
torch::Tensor aqlm_dequant(
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes
)
{
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1);
const at::musa::OptionalMUSAGuard device_guard(device_of(codes));
int rows = codes.size(1);
int cols = codes.size(0);
auto in_features = codes.size(1) * 8;
auto out_features = codes.size(0);
assert(out_features = codebook_partition_sizes.sum().item<int>());
auto weights = torch::empty({out_features, in_features},
torch::TensorOptions()
.dtype(codebooks.dtype())
.device(codebooks.device())
);
if (nbooks == 1 && entries == (1 << 16))
{
vllm::aqlm::code1x16_dequant_cuda(
codes.data_ptr(),
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
// weights *= scales.index({"...", 0, 0});
return weights;
}
if (nbooks == 2 && entries == (1 << 8))
{
vllm::aqlm::code2x8_dequant_cuda(
codes.data_ptr(),
weights.data_ptr(),
codebooks.data_ptr(),
out_features,
in_features,
cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
// weights *= scales.index({"...", 0, 0});
return weights;
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
return {};
}

View File

@@ -0,0 +1,87 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
namespace vllm {
namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 750
assert(false);
#else
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
#endif
}
} // namespace awq
} // namespace vllm

View File

@@ -0,0 +1,446 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <torch/extension.h>
#include "torch_musa/csrc/core/MUSAGuard.h"
#include "dequantize.cuh"
#include <musa_fp16.h>
namespace vllm {
namespace awq {
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
template<int N>
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int G,
int split_k_iters,
half* __restrict__ A,
int* __restrict__ B,
half* __restrict__ scaling_factors,
int* __restrict__ zeros,
int M,
int IC,
int OC,
half* __restrict__ C)
{
// Only support matrix n = 64 or 128
assert(N == 64 || N == 128);
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 750
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (N + 8)];
__shared__ half scaling_factors_shared[N];
__shared__ half zeros_shared[N];
int j_factors1 = ((OC + N - 1) / N);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[N / 4];
for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
+ (((int)threadIdx.x) % (N / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
+ (((int)threadIdx.x) % (N / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
+ ((int)threadIdx.x) % (N / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * N
+ (((int)threadIdx.x) % (N / 8)) * 8;
half* C_ptr = C
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * N
+ ((int)threadIdx.y) * (N / 2)
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ == 750
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#else
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
#endif
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
#endif
}
__global__ void __launch_bounds__(64) dequantize_weights(
int* __restrict__ B,
half* __restrict__ scaling_factors,
int* __restrict__ zeros,
half* __restrict__ C,
int G
)
{
int j_factors1 = 4;
int row_stride2 = 4;
int split_k_iters = 1;
static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)];
half* B_shared_ptr2 = B_shared;
half B_shared_warp[32];
int OC = 512;
int N = blockDim.x * gridDim.x; // 2
int col = (blockIdx.x * blockDim.x + threadIdx.x);
int row = blockIdx.y * blockDim.y + threadIdx.y;
int index1 = 8 * col + 8 * row * N;
half* C_ptr2 = C + index1;
int index2 = col + row * N;
int* B_ptr2 = B + index2;
int index3 = col + (int)(row / G) * N;
int* zeros_ptr2 = zeros + index3;
int index4 = 8 * col + (int)(row / G) * N * 8;
half* scaling_factors_ptr2 = scaling_factors + index4;
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
uint32_t B_loaded = *(uint32_t*)B_ptr2;
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
for (int i = 0; i < 8; ++i) {
*(C_ptr2 + i) = B_shared[i];
}
}
} // namespace awq
} // namespace vllm
torch::Tensor awq_dequantize(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy)
{
int in_c = _kernel.size(0);
int qout_c = _kernel.size(1);
int out_c = qout_c * 8;
int G = in_c / _scaling_factors.size(0);
int x_thread = thx;
int y_thread = thy;
int x_blocks = 1;
int y_blocks = 1;
if (thx==0) {
x_thread = qout_c;
}
if (thy==0) {
y_thread = in_c;
}
if (thx==0 && thy==0) {
x_thread = 8;
y_thread = 8;
x_blocks = (int)(qout_c / 8);
y_blocks = (int)(in_c / 8);
}
const at::musa::OptionalMUSAGuard device_guard(device_of(_scaling_factors));
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_thread, y_thread);
const musaStream_t stream = at::musa::getCurrentMUSAStream();
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
kernel, scaling_factors, zeros, de_kernel, G);
return _de_kernel;
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::musa::OptionalMUSAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
const musaStream_t stream = at::musa::getCurrentMUSAStream();
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
num_out_channels, out_feats);
}
return _out_feats.sum(0);
}

View File

@@ -0,0 +1,167 @@
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif
#include "hip_float8_impl.h"
struct alignas(1) hip_fp8
{
struct from_bits_t
{
};
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
uint8_t data;
hip_fp8() = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
: data(v)
{
}
#ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias
explicit HIP_FP8_DEVICE hip_fp8(float v)
: data(hip_fp8_impl::to_fp8_from_fp32(v))
{
}
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
: hip_fp8(static_cast<float>(v))
{
}
// Host only implementation using s/w simulation
explicit HIP_FP8_HOST
#else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation
explicit HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8(float v)
{
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
}
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
: hip_fp8(static_cast<float>(v))
{
}
#ifdef __HIP__MI300__
// upcast using device specific intrinsic
explicit inline HIP_FP8_DEVICE operator float() const
{
float fval;
uint32_t i32val = static_cast<uint32_t>(data);
// upcast
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
}
explicit inline HIP_FP8_HOST operator float() const
#else // __HIP__MI300__
explicit inline HIP_FP8_HOST_DEVICE operator float() const
#endif // __HIP__MI300__
{
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
}
};
namespace std
{
inline hip_fp8 sin(hip_fp8 a)
{
return hip_fp8(sinf(float(a)));
}
inline hip_fp8 cos(hip_fp8 a)
{
return hip_fp8(cosf(float(a)));
}
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
{
return a;
}
} // namespace std
// Special operator overloading
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
{
return os << float(f8);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns float
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
{
return (fa + float(b));
}
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
{
return (float(a) + fb);
}
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
{
return hip_fp8(float(a) + float(b));
}
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
{
return a = hip_fp8(float(a) + float(b));
}
// overloading multiplication, always returns float,
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
{
return float(a) * float(b);
}
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
{
return (a * float(b));
}
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
{
return (float(a) * b);
}
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
{
return ((float)a * float(b));
}
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
{
return ((float)a * float(b));
}
// overloading for compare
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
{
return (a.data == b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
{
return (a.data != b.data);
}
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
{
return static_cast<float>(a) >= static_cast<float>(b);
}
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
{
return static_cast<float>(a) > static_cast<float>(b);
}

View File

@@ -0,0 +1,316 @@
#pragma once
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif
#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif
namespace hip_fp8_impl
{
#ifdef __HIP__MI300__
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
{
uint8_t i8data;
union {
float fval;
uint32_t i32val;
uint8_t i8val[4]; // NOTE: not endian independent
} val;
uint32_t ival = 0;
val.fval = v;
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
return i8data;
}
#endif // __HIP__MI300__
HIP_FP8_HOST inline int clz(uint32_t x)
{
return __builtin_clz(x);
}
#if defined(__HIPCC__) || defined(__MUSA_ARCH__)
HIP_FP8_DEVICE inline int clz(uint32_t x)
{
return __clz(x);
}
#endif
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
{
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
constexpr bool is_half = false;
#endif
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(wm + we == 7, "wm+we==7");
static_assert(is_half || is_float, "Only half and float can be cast to f8");
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
uint32_t x;
if (sizeof(T) == 4) {
x = reinterpret_cast<uint32_t&>(_x);
} else {
x = reinterpret_cast<uint16_t&>(_x);
}
uint32_t head, mantissa;
int exponent, bias;
uint32_t sign;
if (sizeof(T) == 4) {
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
} else {
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
}
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
// Deal with inf and NaNs
if (negative_zero_nan) {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return 0x80;
}
} else {
// if(__hisinf(x) || __hisnan(x))
if ((x & 0x7C00) == 0x7C00) {
return 0x80;
}
}
} else {
if (sizeof(T) == 4) {
if ((x & 0x7F800000) == 0x7F800000) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
} else {
if ((x & 0x7C00) == 0x7C00) {
return signed_inf + (mantissa != 0 ? 1 : 0);
}
}
}
if (x == 0) {
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if (exponent == 0) { // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
} else { // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if (act_exponent <= f8_denormal_act_exponent) {
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
} else { // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
// for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint.
*/
if (exponent_diff > 0) {
mantissa >>= exponent_diff;
} else if (exponent_diff == -1) {
mantissa <<= -exponent_diff;
}
bool implicit_one = mantissa & (1 << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
// is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
// Now we deal with overflow
if (f8_exponent == 0) {
if ((1 << mfmt) & mantissa) {
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
} else {
if ((1 << (mfmt + 1)) & mantissa) {
mantissa >>= 1;
f8_exponent++;
}
}
mantissa >>= (mfmt - wm);
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
if (f8_exponent > max_exp) {
if (clip) {
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
} else {
return signed_inf;
}
}
if (f8_exponent == 0 && mantissa == 0) {
return negative_zero_nan ? 0 : (sign << 7);
}
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
{
#ifdef __HIPCC__
constexpr bool is_half = std::is_same<T, _Float16>::value;
#else
constexpr bool is_half = false;
#endif
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported");
constexpr int weo = is_half ? 5 : 8;
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
T fInf, fNegInf, fNaN, fNeg0;
#ifdef __HIPCC__
if (is_half) {
const uint16_t ihInf = 0x7C00;
const uint16_t ihNegInf = 0xFC00;
const uint16_t ihNaN = 0x7C01;
const uint16_t ihNeg0 = 0x8000;
fInf = reinterpret_cast<const _Float16&>(ihInf);
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
} else
#endif
if (is_float) {
const uint32_t ifInf = 0x7F800000;
const uint32_t ifNegInf = 0xFF800000;
const uint32_t ifNaN = 0x7F800001;
const uint32_t ifNeg0 = 0x80000000;
fInf = reinterpret_cast<const float&>(ifInf);
fNegInf = reinterpret_cast<const float&>(ifNegInf);
fNaN = reinterpret_cast<const float&>(ifNaN);
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
}
if (x == 0) {
return 0;
}
uint32_t sign = x >> 7;
uint32_t mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm;
if (negative_zero_nan) {
if (x == 0x80) {
return fNaN;
}
} else {
if (x == 0x80) {
return fNeg0;
}
if (exponent == ((1 << we) - 1)) {
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
}
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
if (we == 5 && is_half && !negative_zero_nan) {
retval = x << 8;
return reinterpret_cast<const T&>(retval);
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
// subnormal input
if (exponent == 0) {
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - wm);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if (exponent <= 0) {
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if (sizeof(T) == 2) {
retval = (sign << 15) | (exponent << 10) | mantissa;
} else {
retval = (sign << 31) | (exponent << 23) | mantissa;
}
return reinterpret_cast<const T&>(retval);
}
} // namespace hip_fp8_impl

View File

@@ -0,0 +1,517 @@
#pragma once
#include "hip_float8.h"
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
namespace vllm
{
namespace fp8_e4m3 {
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
{
return x;
}
// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8);
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r.x.data = f2[0];
tmp.h2r.y.data = f2[1];
return tmp.ui32;
#else
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
return tmp.u32;
#endif
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
using __mt_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __mt_bfloat16 vec_conversion<__mt_bfloat16, uint8_t>(const uint8_t& a)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f);
}
using __mt_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, uint16_t>(const uint16_t& a)
{
__mt_bfloat162 res;
res.x = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
bf16_4_t res;
res.x = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8);
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0];
res.y = f2[1];
return res;
#else
float2 res;
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
return res;
#endif
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
__half_raw tmp;
tmp.x = a;
hip_fp8 f8{static_cast<float>(tmp.data)};
return f8.data;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __mt_bfloat16>(const __mt_bfloat16& a)
{
hip_fp8 res{__bfloat162float(a)};
return res.data;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
hip_fp8 f8(a);
return f8.data;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
// float2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
// Float4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
// Float4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
// Float8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
// float2 -> bfloat162
template <>
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, float2>(const float2& a)
{
__mt_bfloat162 b = __float22bfloat162_rn(a);
return b;
}
// Float4 -> bfloat162x2
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
{
bf16_4_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
return b;
}
// Float8 -> bfloat162x4
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
{
bf16_8_t b;
b.x = __float22bfloat162_rn(a.x);
b.y = __float22bfloat162_rn(a.y);
b.z = __float22bfloat162_rn(a.z);
b.w = __float22bfloat162_rn(a.w);
return b;
}
/* Scaled and vectorized conversions, for data exchange between high and low precision domains
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
s.t.
Quantize(HP / scale) => FP8
Dequant(FP8) * scale => HP
*/
// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
__half_raw res;
res.data = static_cast<float>(f8) * scale;
return res.x;
}
// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
union {
__half2_raw h2r;
uint32_t ui32;
} tmp;
tmp.h2r.x.data = f2[0] * scale;
tmp.h2r.y.data = f2[1] * scale;
return tmp.ui32;
#else
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
return tmp.u32;
#endif
}
// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
return tmp.u32x2;
}
// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
return tmp.u64x2;
}
using __mt_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __mt_bfloat16 scaled_vec_conversion<__mt_bfloat16, uint8_t>(const uint8_t& a, const float scale)
{
hip_fp8 f8{a, hip_fp8::from_bits()};
float f{f8};
return __float2bfloat16(f * scale);
}
using __mt_bfloat162 = __hip_bfloat162;
// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __mt_bfloat162 scaled_vec_conversion<__mt_bfloat162, uint16_t>(const uint16_t& a, const float scale)
{
__mt_bfloat162 res;
res.x = scaled_vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)a, scale);
res.y = scaled_vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
return res;
}
// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
{
bf16_4_t res;
res.x = scaled_vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
return res;
}
// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
{
bf16_4_t tmp1, tmp2;
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
{
hip_fp8 fp8{a, hip_fp8::from_bits()};
return static_cast<float>(fp8) * scale;
}
// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
{
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
float2 res;
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
res.x = f2[0] * scale;
res.y = f2[1] * scale;
return res;
#else
float2 res;
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
return res;
#endif
}
// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
{
Float4_ res;
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
return res;
}
// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
{
Float4_ tmp1, tmp2;
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
/* Quantize(HP / scale) => FP8 */
// TODO(Hai): vectorized to add
// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
{
__half_raw tmp;
tmp.x = a;
hip_fp8 f8{static_cast<float>(tmp.data)/scale};
return f8.data;
}
// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __mt_bfloat16>(const __mt_bfloat16& a, const float scale)
{
hip_fp8 res{__bfloat162float(a)/scale};
return res.data;
}
// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
{
hip_fp8 f8(a/scale);
return f8.data;
}
// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
{
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
}
} // namespace vllm

View File

@@ -0,0 +1,126 @@
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include <torch/extension.h>
#include "torch_musa/csrc/core/MUSAGuard.h"
#include <cmath>
#include "musa_compat.h"
#include "dispatch_utils.h"
namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
__syncthreads();
// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
}
}
template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
while (i < num_elems) {
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
i += blockDim.x * gridDim.x;
}
}
} // namespace vllm
void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}
void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
scale.data_ptr<float>(),
input.data_ptr<scalar_t>(),
num_elems);
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}

View File

@@ -0,0 +1,277 @@
#pragma once
#include <assert.h>
#include <stdint.h>
#include <float.h>
#include <type_traits>
#include "../../attention/attention_dtypes.h"
#include "../../attention/dtype_float32.cuh"
#include "../../attention/dtype_float16.cuh"
#include "../../attention/dtype_bfloat16.cuh"
namespace vllm {
#ifdef ENABLE_FP8_E5M2
namespace fp8_e5m2_unscaled {
template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
return x;
}
// fp8 -> half
template<>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
{
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
return res.x;
}
// fp8x2 -> half2
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
{
union {
uint16_t u16[2];
uint32_t u32;
} tmp;
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
tmp.u16[0] = res.x;
tmp.u16[1] = res.y;
return tmp.u32;
}
// fp8x4 -> half2x2
template<>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
{
union {
uint2 u32x2;
uint32_t u32[2];
} tmp;
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
return tmp.u32x2;
}
// fp8x8 -> half2x4
template<>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
{
union {
uint4 u64x2;
uint2 u64[2];
} tmp;
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
return tmp.u64x2;
}
// fp8 -> __nv_bfloat16
template<>
__inline__ __device__ __mt_bfloat16 vec_conversion<__mt_bfloat16, uint8_t>(const uint8_t& a)
{
// Note there is no direct convert function from fp8 to bf16.
// fp8 -> half
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
// half -> float -> bf16
float tmp = half_to_float(res.x);
return __float2bfloat16(tmp);
}
// fp8x2 -> __nv_bfloat162
template<>
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, uint16_t>(const uint16_t& a)
{
__mt_bfloat162 res;
res.x = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)a);
res.y = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)(a >> 8U));
return res;
}
// fp8x4 -> bf16_4_t
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
{
bf16_4_t res;
res.x = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)a);
res.y = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> bf16_8_t
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
{
bf16_4_t tmp1, tmp2;
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
bf16_8_t res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// fp8 -> float
template<>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
{
// fp8 -> half
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
// half -> float
return half_to_float(tmp);
}
// fp8x2 -> float2
template<>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
{
// fp8x2 -> half2
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
// half2 -> float2
return half2_to_float2(tmp);
}
// fp8x4 -> float4
template<>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
{
Float4_ res;
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
return res;
}
// fp8x8 -> float8
template<>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
{
Float4_ tmp1, tmp2;
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
Float8_ res;
res.x = tmp1.x;
res.y = tmp1.y;
res.z = tmp2.x;
res.w = tmp2.y;
return res;
}
// half -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
{
__half_raw tmp;
tmp.x = a;
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// bf16 -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __mt_bfloat16>(const __mt_bfloat16& a)
{
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
assert(false);
#else
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__mt_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
#endif
}
// float -> fp8
template<>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
{
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
return (uint8_t)res;
}
// fp8x4 -> float4
template<>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
{
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
return res;
}
template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
union {
half2 float16;
uint32_t uint32;
};
float16 = __float22half2_rn(a);
return uint32;
}
template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
uint2 b;
float2 val;
val.x = a.x.x;
val.y = a.x.y;
b.x = vec_conversion<uint32_t, float2>(val);
val.x = a.y.x;
val.y = a.y.y;
b.y = vec_conversion<uint32_t, float2>(val);
return b;
}
template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
float4 b;
b.x = a.x.x;
b.y = a.x.y;
b.z = a.y.x;
b.w = a.y.y;
return b;
}
template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
uint4 b;
b.x = vec_conversion<uint32_t, float2>(a.x);
b.y = vec_conversion<uint32_t, float2>(a.y);
b.z = vec_conversion<uint32_t, float2>(a.z);
b.w = vec_conversion<uint32_t, float2>(a.w);
return b;
}
template<>
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, float2>(const float2 &a) {
__mt_bfloat162 b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
bf16_4_t b;
from_float(b, a);
return b;
}
template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
bf16_8_t b;
from_float(b, a);
return b;
}
} // namespace fp8_e5m2_unscaled
#endif // ENABLE_FP8_E5M2
} // namespace vllm

View File

@@ -0,0 +1,64 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__MUSA_ARCH__) || defined(USE_ROCM)
#if __MUSA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __MUSA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
} // namespace gptq
} // namespace vllm
#endif

View File

@@ -0,0 +1,274 @@
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <musa_runtime.h>
#include <musa_fp16.h>
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
{
half2* ptr = (half2*) item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
{
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
class MatrixView_q2_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x0f) * 2;
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x0f) * 2;
uint32_t d = data[row * width / 16 + column / 16] >> shift;
items[0] = d & 0x03;
items[1] = (d >> 2) & 0x03;
items[2] = (d >> 4) & 0x03;
items[3] = (d >> 6) & 0x03;
}
};
class MatrixView_q3_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int z_w = column * 3 / 32;
int z_mod = column & 0x1f;
if (z_mod == 10) {
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
} else if (z_mod == 21) {
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
} else if (z_mod < 10) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
} else if (z_mod < 21) {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
} else {
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
}
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x1f);
uint32_t d;
if (shift <= 4) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
} else if (shift == 8) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
} else if (shift <= 16) {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
} else if (shift == 20) {
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
} else {
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
}
items[0] = d & 0x07;
items[1] = (d >> 3) & 0x07;
items[2] = (d >> 6) & 0x07;
items[3] = (d >> 9) & 0x07;
}
};
class MatrixView_q8_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x03) * 8;
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x03) * 8;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x03) * 2;
uint32_t d = data[row * width / 4 + column / 4] >> shift;
items[0] = d & 0xff;
items[1] = (d >> 8) & 0xff;
items[2] = (d >> 16) & 0xff;
items[3] = (d >> 24) & 0xff;
}
};
} // namespace gptq
} // namespace vllm
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,87 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_2_cuh
#define _qdq_2_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// ffddbb99 77553311 eeccaa88 66442200
__forceinline__ __device__ void shuffle_2bit_16
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 8; i++)
{
uint32_t qa0 = qa & 0x03;
uint32_t qa1 = (qa & 0x0c) >> 2;
qa >>= 4;
qb |= (qa1 << (i * 2 + 16));
qb |= (qa0 << (i * 2));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_2bit_16
(
const uint32_t q_0,
half2 (&dq)[8],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y4_ = __float2half_rn(1.0f / 4.0f);
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y4 = __halves2half2(y4_, y4_);
const half2 y16 = __halves2half2(y16_, y16_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z4 = __half2half2(z4_);
const half2 z16 = __half2half2(z16_);
const half2 z64 = __half2half2(z64_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
qa >>= 8;
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y4, z4);
dq[2] = __hfma2(q2.as_half2, y16, z16);
dq[3] = __hfma2(q3.as_half2, y64, z64);
dq[4] = __hadd2(q4.as_half2, z1);
dq[5] = __hfma2(q5.as_half2, y4, z4);
dq[6] = __hfma2(q6.as_half2, y16, z16);
dq[7] = __hfma2(q7.as_half2, y64, z64);
}
} // namespace gptq
} // namespace vllm
#endif

View File

@@ -0,0 +1,141 @@
#ifndef _qdq_3_cuh
#define _qdq_3_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
// vtttrrrp ppnnnlll usssqqqo oommmkkk
__forceinline__ __device__ void shuffle_3bit_32
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0 * stride];
uint32_t qb = q[1 * stride];
uint32_t qc = q[2 * stride];
// qa: aa999888 77766655 54443332 22111000
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
uint32_t qd = qc >> 26;
qc <<= 4;
qc |= qb >> 28;
qb <<= 2;
qb |= qa >> 30;
// qa: ..999888 77766655 54443332 22111000
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
// qd: vvvuuu
uint32_t za = 0;
uint32_t zb = 0;
uint32_t zc = 0;
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
// za: 9997775 55333111 8886664 44222000
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
// qd: vvvuuu
za |= ((qd & 0x01) >> 0) << 15;
zb |= ((qd & 0x02) >> 1) << 15;
zc |= ((qd & 0x04) >> 2) << 15;
za |= ((qd & 0x08) >> 3) << 31;
zb |= ((qd & 0x10) >> 4) << 31;
zc |= ((qd & 0x20) >> 5) << 31;
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
q[0 * stride] = za;
q[1 * stride] = zb;
q[2 * stride] = zc;
}
__forceinline__ __device__ void dequant_3bit_32
(
const uint32_t q_0,
const uint32_t q_1,
const uint32_t q_2,
half2 (&dq)[16],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y8_ = __float2half_rn(1.0f / 8.0f);
const half y64_ = __float2half_rn(1.0f / 64.0f);
const half2 y8 = __halves2half2(y8_, y8_);
const half2 y64 = __halves2half2(y64_, y64_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
const half2 z8 = __halves2half2(z8_, z8_);
const half2 z64 = __halves2half2(z64_, z64_);
uint32_t qa = q_0;
uint32_t qb = q_1;
uint32_t qc = q_2;
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
qa >>= 6;
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
qa >>= 9;
qa &= 0x00010001;
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
qb >>= 6;
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
qb >>= 8;
qb &= 0x00020002;
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
qc >>= 6;
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
qc >>= 7;
qc &= 0x00040004;
half2_uint32 q15((qa | qb | qc) | c0);
dq[ 0] = __hadd2( q0.as_half2, z1);
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
dq[ 2] = __hadd2( q2.as_half2, z1);
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
dq[ 5] = __hadd2( q5.as_half2, z1);
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
dq[ 7] = __hadd2( q7.as_half2, z1);
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
dq[10] = __hadd2(q10.as_half2, z1);
dq[11] = __hfma2(q11.as_half2, y8, z8);
dq[12] = __hadd2(q12.as_half2, z1);
dq[13] = __hfma2(q13.as_half2, y8, z8);
dq[14] = __hfma2(q14.as_half2, y64, z64);
dq[15] = __hadd2(q15.as_half2, z1);
}
} // namespace gptq
} // namespace vllm
#endif

View File

@@ -0,0 +1,147 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
const half2 z1 = __half2half2(z1_.as_half);
const half2 z16 = __half2half2(z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
} // namespace gptq
} // namespace vllm
#endif

View File

@@ -0,0 +1,40 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_8_cuh
#define _qdq_8_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_8bit_4
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_8bit_8
(
const uint32_t q_0,
const uint32_t q_1,
half2 (&dq)[4],
int stride,
const uint32_t zero
)
{
half dqh[8];
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
} // namespace gptq
} // namespace vllm
#endif

View File

@@ -0,0 +1,60 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace vllm {
namespace gptq {
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
} // namespace gptq
} // namespace vllm
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,70 @@
#pragma once
#include <torch/extension.h>
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
#include "torch_musa/csrc/core/MUSAGuard.h"
#include <musa.h>
#include <musa_fp16.h>
#include <musa_runtime.h>
#include <iostream>
namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have
// many registers per warp and small tiles.
static constexpr int default_threads = 256;
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
// No support for async
#else
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
} // namespace gptq_marlin

View File

@@ -0,0 +1,352 @@
#include "gptq_marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
} // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
int4 *sh_perm_ptr = sh;
int4 *sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) {
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr);
if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
}
__syncthreads();
};
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
uint32_t const *sh_perm_int_ptr =
reinterpret_cast<uint32_t const *>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
}
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
uint32_t vals[8];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}
} else {
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
if constexpr (has_perm) {
load_perm_to_shared(k_tile_id);
}
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
musaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
musaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", pack_factor = ", pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
// Alloc buffers
const at::musa::OptionalMUSAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
// Get ptrs
uint32_t const *b_q_weight_ptr =
reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
uint32_t const *perm_ptr =
reinterpret_cast<uint32_t const *>(perm.data_ptr());
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
musaStream_t stream = at::cuda::getCurrentMUSAStream(dev);
int blocks;
musaDeviceGetAttribute(&blocks, musaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
musaDeviceGetAttribute(&max_shared_mem,
musaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}
return out;
}
#endif

View File

@@ -0,0 +1,209 @@
Contains code from https://github.com/IST-DASLab/marlin
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
------------------------------------------------------------------------------------
This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,225 @@
#include <torch/all.h>
#include <torch/python.h>
#include <musa.h>
#include <musa_runtime.h>
#include <musa_fp16.h>
// half-tensor
#include "torch_musa/csrc/core/MUSAStream.h"
#include <ATen/musa/MUSA_PORT_TensorMethods.muh>
#include "torch_musa/csrc/core/MUSAGuard.h"
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
namespace vllm {
namespace squeezellm {
__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
const half2* __restrict__ vec,
#else
const __half2* __restrict__ vec,
#endif
const int* __restrict__ mat,
#ifndef USE_ROCM
half2* __restrict__ mul,
#else
float2* __restrict__ mul,
#endif
const __half* __restrict__ lookup_table,
int height,
int width,
int batch,
int vec_height
) {
const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
#ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2];
#else
__shared__ __half2 blockvec[blockwidth2];
#endif
__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
int column_offset = col * 16;
for (int val = 0; val < 16; val += 1) {
int lut_index = column_offset + val;
deq2[val][off] = lookup_table[lut_index];
}
__half res;
#ifndef USE_ROCM
half2 res2;
half2 tmp2;
#else
__half2 res2;
__half2 tmp2;
#endif
int i;
int k;
unsigned int tmp1;
unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b){
i = width * row + col;
res = __int2half_rd(0);
k = 0;
__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
__syncthreads();
while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);
#ifndef USE_ROCM
res2 = {};
tmp2 = {};
#else
res2.x = __half_as_ushort(__float2half(0));
res2.y = __half_as_ushort(__float2half(0));
tmp2.x = __half_as_ushort(__float2half(0));
tmp2.y = __half_as_ushort(__float2half(0));
#endif
lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
#else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
#endif
i += width;
k += 4;
}
// col%2 -> only set one of the two values
#ifndef USE_ROCM
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}
#else
__half2 res3;
res3.x = __half_as_ushort(__float2half(0));
res3.y = __half_as_ushort(__float2half(0));
if (col % 2 == 0) {
res3.x = __half_as_ushort(res);
} else {
res3.y = __half_as_ushort(res);
}
#endif
#ifndef USE_ROCM
atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
int tmp_addr = b * width / 2 + col / 2;
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
}
}
} // namespace squeezellm
} // namespace vllm
// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table
) {
int height = mat.size(0);
int width = mat.size(1);
int batch = vec.size(0);
int vec_height = vec.size(1);
dim3 blocks(
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
);
dim3 threads(BLOCKWIDTH);
const at::musa::OptionalMUSAGuard device_guard(device_of(vec));
const musaStream_t stream = at::musa::getCurrentMUSAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
(half2*) vec.data<at::Half>(),
#else
(__half2*) vec.data_ptr<at::Half>(),
#endif
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*) mul.data<at::Half>(),
(__half*) lookup_table.data<at::Half>(),
#else
(float2*) mul.data_ptr<float>(),
(__half*) lookup_table.data_ptr<at::Half>(),
#endif
height, width, batch, vec_height
);
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4

View File

@@ -0,0 +1,66 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "musa_compat.h"
namespace vllm {
template<typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
}
// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
/* Calculate the sum of all elements in a block */
template<typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val);
// Calculates max number of lanes that need to participate in the last warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0)
shared[wid] = val;
__syncthreads();
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
} else {
// A single warpReduce is equal to blockReduce
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
}
return val;
}
} // namespace vllm