init
This commit is contained in:
161
csrc_musa/activation_kernels.mu
Normal file
161
csrc_musa/activation_kernels.mu
Normal file
@@ -0,0 +1,161 @@
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "musa_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x) * y;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'none' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
||||
const float f = (float) x;
|
||||
constexpr float ALPHA = M_SQRT1_2;
|
||||
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'tanh' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
|
||||
const float f = (float) x;
|
||||
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
|
||||
constexpr float KAPPA = 0.044715;
|
||||
float x_cube = f * f * f;
|
||||
float inner = BETA * (f + KAPPA * x_cube);
|
||||
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation and gating kernel.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(input)); \
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"act_and_mul_kernel", \
|
||||
[&] { \
|
||||
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
}
|
||||
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Element-wise activation kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch element-wise activation kernel.
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
int d = input.size(-1); \
|
||||
int64_t num_tokens = input.numel() / d; \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(input)); \
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"activation_kernel", \
|
||||
[&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||
const float x3 = (float) (x * x * x);
|
||||
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||
const float f = (float) x;
|
||||
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||
}
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||
}
|
||||
7
csrc_musa/attention/attention_dtypes.h
Normal file
7
csrc_musa/attention/attention_dtypes.h
Normal file
@@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.muh"
|
||||
#include "dtype_float16.muh"
|
||||
#include "dtype_float32.muh"
|
||||
#include "dtype_bfloat16.muh"
|
||||
#include "dtype_fp8.muh"
|
||||
65
csrc_musa/attention/attention_generic.muh
Normal file
65
csrc_musa/attention/attention_generic.muh
Normal file
@@ -0,0 +1,65 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// A vector type to store Q, K, V elements.
|
||||
template<typename T, int VEC_SIZE>
|
||||
struct Vec {};
|
||||
|
||||
// A vector type to store FP32 accumulators.
|
||||
template<typename T>
|
||||
struct FloatVec {};
|
||||
|
||||
// Template vector operations.
|
||||
template<typename Acc, typename A, typename B>
|
||||
inline __device__ Acc mul(A a, B b);
|
||||
|
||||
template<typename T>
|
||||
inline __device__ float sum(T v);
|
||||
|
||||
template<typename T>
|
||||
inline __device__ float dot(T a, T b) {
|
||||
return sum(mul<T, T, T>(a, b));
|
||||
}
|
||||
|
||||
template<typename A, typename T>
|
||||
inline __device__ float dot(T a, T b) {
|
||||
return sum(mul<A, T, T>(a, b));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline __device__ void zero(T& dst) {
|
||||
constexpr int WORDS = sizeof(T) / 4;
|
||||
union {
|
||||
T raw;
|
||||
uint32_t words[WORDS];
|
||||
} tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < WORDS; ++ii) {
|
||||
tmp.words[ii] = 0u;
|
||||
}
|
||||
dst = tmp.raw;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
981
csrc_musa/attention/attention_kernels.mu
Normal file
981
csrc_musa/attention/attention_kernels.mu
Normal file
@@ -0,0 +1,981 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.muh"
|
||||
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __mt_bfloat16;
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Utility function for attention softmax.
|
||||
template<int NUM_WARPS>
|
||||
inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Decompose the thread index into warp / lane.
|
||||
int warp = threadIdx.x / WARP_SIZE;
|
||||
int lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// Compute the sum per warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Warp leaders store the data to shared memory.
|
||||
if (lane == 0) {
|
||||
red_smem[warp] = sum;
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
// The warps compute the final sums.
|
||||
if (lane < NUM_WARPS) {
|
||||
sum = red_smem[lane];
|
||||
}
|
||||
|
||||
// Parallel reduction inside the warp.
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Broadcast to other threads.
|
||||
return VLLM_SHFL_SYNC(sum, 0);
|
||||
}
|
||||
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
const int max_num_partitions = gridDim.z;
|
||||
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
|
||||
// No work to do. Terminate the thread block.
|
||||
return;
|
||||
}
|
||||
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
|
||||
|
||||
// [start_block_idx, end_block_idx) is the range of blocks to process.
|
||||
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
||||
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
|
||||
const int num_blocks = end_block_idx - start_block_idx;
|
||||
|
||||
// [start_token_idx, end_token_idx) is the range of tokens to process.
|
||||
const int start_token_idx = start_block_idx * BLOCK_SIZE;
|
||||
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
|
||||
const int num_tokens = end_token_idx - start_token_idx;
|
||||
|
||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
const int warp_idx = thread_idx / WARP_SIZE;
|
||||
const int lane = thread_idx % WARP_SIZE;
|
||||
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
// The vector size is configured in such a way that the threads in a thread group
|
||||
// fetch or compute 16 bytes at a time.
|
||||
// For example, if the size of a thread group is 4 and the data type is half,
|
||||
// then the vector size is 16 / (4 * sizeof(half)) == 2.
|
||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||
#endif
|
||||
|
||||
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||
|
||||
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
|
||||
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
|
||||
|
||||
// Load the query to registers.
|
||||
// Each thread in a thread group has a different part of the query.
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||
// th vectors of the query, and so on.
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
}
|
||||
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||
|
||||
// Memory planning.
|
||||
extern __shared__ char shared_mem[];
|
||||
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
|
||||
float* logits = reinterpret_cast<float*>(shared_mem);
|
||||
// Workspace for reduction.
|
||||
__shared__ float red_smem[2 * NUM_WARPS];
|
||||
|
||||
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||
// Each thread group fetches x elements from the key at a time.
|
||||
constexpr int x = 16 / sizeof(cache_t);
|
||||
float qk_max = -FLT_MAX;
|
||||
|
||||
// Iterate over the key blocks.
|
||||
// Each warp fetches a block of keys for each iteration.
|
||||
// Each thread group in a warp fetches a key from the block, and computes
|
||||
// dot product with the query.
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
|
||||
// Load a key to registers.
|
||||
// Each thread in a thread group has a different part of the key.
|
||||
// For example, if the the thread group size is 4, then the first thread in the group
|
||||
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
|
||||
// vectors of the key, and so on.
|
||||
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
|
||||
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
|
||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||
if constexpr (IS_FP8_KV_CACHE) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec.
|
||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
|
||||
// cache vec to k vec in higher precision (FP16, BFloat16, etc.)
|
||||
k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute dot product.
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
// Add the ALiBi bias if slopes are given.
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
|
||||
|
||||
if (thread_group_offset == 0) {
|
||||
// Store the partial reductions to shared memory.
|
||||
// NOTE(woosuk): It is required to zero out the masked logits.
|
||||
const bool mask = token_idx >= seq_len;
|
||||
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
||||
// Update the max value.
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform reduction across the threads in the same warp to get the
|
||||
// max qk value for each "warp" (not across the thread block yet).
|
||||
// The 0-th thread of each thread group already has its max qk value.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TODO(woosuk): Refactor this part.
|
||||
// Get the max qk value for the sequence.
|
||||
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
// Broadcast the max qk value to all threads.
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
|
||||
// Get the sum of the exp values.
|
||||
float exp_sum = 0.f;
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
float val = __expf(logits[i] - qk_max);
|
||||
logits[i] = val;
|
||||
exp_sum += val;
|
||||
}
|
||||
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
|
||||
|
||||
// Compute softmax.
|
||||
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
logits[i] *= inv_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// If partitioning is enabled, store the max logit and exp_sum.
|
||||
if (USE_PARTITIONING && thread_idx == 0) {
|
||||
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
*max_logits_ptr = qk_max;
|
||||
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
*exp_sums_ptr = exp_sum;
|
||||
}
|
||||
|
||||
// Each thread will fetch 16 bytes from the value cache at a time.
|
||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||
#endif
|
||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||
|
||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
||||
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
||||
|
||||
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
||||
float accs[NUM_ROWS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
accs[i] = 0.f;
|
||||
}
|
||||
|
||||
scalar_t zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||
|
||||
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec;
|
||||
if constexpr (IS_FP8_KV_CACHE) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
|
||||
// FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
|
||||
v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
}
|
||||
if (block_idx == num_seq_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < V_VEC_SIZE; j++) {
|
||||
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
|
||||
}
|
||||
}
|
||||
accs[i] += dot(logits_vec, v_vec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform reduction within each warp.
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
float acc = accs[i];
|
||||
#pragma unroll
|
||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||
}
|
||||
accs[i] = acc;
|
||||
}
|
||||
|
||||
// NOTE(woosuk): A barrier is required because the shared memory space for logits
|
||||
// is reused for the output.
|
||||
__syncthreads();
|
||||
|
||||
// Perform reduction across warps.
|
||||
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
||||
#pragma unroll
|
||||
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
||||
int mid = i / 2;
|
||||
// Upper warps write to shared memory.
|
||||
if (warp_idx >= mid && warp_idx < i) {
|
||||
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||
dst[row_idx] = accs[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Lower warps update the output.
|
||||
if (warp_idx < mid) {
|
||||
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||
accs[i] += src[row_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Write the final output.
|
||||
if (warp_idx == 0) {
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE
|
||||
+ partition_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||
from_float(*(out_ptr + row_idx), accs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, 1).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_KV_CACHE>
|
||||
__global__ void paged_attention_v1_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride,
|
||||
const float kv_scale) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride, kv_scale);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs).
|
||||
template<
|
||||
typename scalar_t,
|
||||
int HEAD_SIZE,
|
||||
int NUM_THREADS,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_reduce_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_partitions) {
|
||||
const int num_heads = gridDim.x;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
|
||||
if (num_partitions == 1) {
|
||||
// No need to reduce. Only copy tmp_out to out.
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
||||
out_ptr[i] = tmp_out_ptr[i];
|
||||
}
|
||||
// Terminate the thread block.
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warp_idx = threadIdx.x / WARP_SIZE;
|
||||
const int lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// Size: 2 * num_partitions.
|
||||
extern __shared__ char shared_mem[];
|
||||
// Workspace for reduction.
|
||||
__shared__ float red_smem[2 * NUM_WARPS];
|
||||
|
||||
// Load max logits to shared memory.
|
||||
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
||||
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions;
|
||||
float max_logit = -FLT_MAX;
|
||||
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||
const float l = max_logits_ptr[i];
|
||||
shared_max_logits[i] = l;
|
||||
max_logit = fmaxf(max_logit, l);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Get the global max logit.
|
||||
// Reduce within the warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = max_logit;
|
||||
}
|
||||
__syncthreads();
|
||||
// Reduce across warps.
|
||||
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
// Broadcast the max value to all threads.
|
||||
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||
|
||||
// Load rescaled exp sums to shared memory.
|
||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions;
|
||||
float global_exp_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||
float l = shared_max_logits[i];
|
||||
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
shared_exp_sums[i] = rescaled_exp_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
|
||||
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
|
||||
|
||||
// Aggregate tmp_out to out.
|
||||
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
|
||||
float acc = 0.0f;
|
||||
for (int j = 0; j < num_partitions; ++j) {
|
||||
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
|
||||
}
|
||||
from_float(out_ptr[i], acc);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_KV_CACHE>), shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
seq_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride, \
|
||||
kv_scale);
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int NUM_THREADS = 128>
|
||||
void paged_attention_v1_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float kv_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_seq_len * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||
// Keep that in sync with the logic here!
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
|
||||
dim3 grid(num_heads, num_seqs, 1);
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
// head sizes that we use in the model. However, we can easily extend this
|
||||
// to support any head size which is a multiple of 16.
|
||||
case 64:
|
||||
LAUNCH_PAGED_ATTENTION_V1(64);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_PAGED_ATTENTION_V1(80);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_PAGED_ATTENTION_V1(96);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V1(112);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V1(128);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_PAGED_ATTENTION_V1(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
seq_lens, \
|
||||
max_seq_len, \
|
||||
alibi_slopes, \
|
||||
kv_scale);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, __mt_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_KV_CACHE, PARTITION_SIZE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
tmp_out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
seq_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride, \
|
||||
kv_scale); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
tmp_out_ptr, \
|
||||
seq_lens_ptr, \
|
||||
max_num_partitions);
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_KV_CACHE,
|
||||
int NUM_THREADS = 128,
|
||||
int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
float kv_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
int logits_size = PARTITION_SIZE * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
|
||||
// For paged attention v2 kernel.
|
||||
dim3 grid(num_heads, num_seqs, max_num_partitions);
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
// For paged attention v2 reduce kernel.
|
||||
dim3 reduce_grid(num_heads, num_seqs);
|
||||
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
// head sizes that we use in the model. However, we can easily extend this
|
||||
// to support any head size which is a multiple of 16.
|
||||
case 64:
|
||||
LAUNCH_PAGED_ATTENTION_V2(64);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_PAGED_ATTENTION_V2(80);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_PAGED_ATTENTION_V2(96);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V2(112);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V2(128);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_PAGED_ATTENTION_V2(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
tmp_out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
seq_lens, \
|
||||
max_seq_len, \
|
||||
alibi_slopes, \
|
||||
kv_scale);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, __mt_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
57
csrc_musa/attention/attention_utils.muh
Normal file
57
csrc_musa/attention/attention_utils.muh
Normal file
@@ -0,0 +1,57 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../musa_compat.h"
|
||||
#include "attention_dtypes.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Q*K^T operation.
|
||||
template<int THREAD_GROUP_SIZE, typename Vec, int N>
|
||||
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
using A_vec = typename FloatVec<Vec>::Type;
|
||||
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
||||
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < N; ++ii) {
|
||||
qk_vec = fma(q[ii], k[ii], qk_vec);
|
||||
}
|
||||
|
||||
// Finalize the reduction across lanes.
|
||||
float qk = sum(qk_vec);
|
||||
#pragma unroll
|
||||
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
|
||||
}
|
||||
return qk;
|
||||
}
|
||||
|
||||
template<typename T, int THREAD_GROUP_SIZE>
|
||||
struct Qk_dot {
|
||||
template<typename Vec, int N>
|
||||
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
452
csrc_musa/attention/dtype_bfloat16.muh
Normal file
452
csrc_musa/attention/dtype_bfloat16.muh
Normal file
@@ -0,0 +1,452 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.muh"
|
||||
#include "dtype_float32.muh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <musa_bf16.h>
|
||||
#include <musa_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
typedef __hip_bfloat162 __mt_bfloat162;
|
||||
typedef __hip_bfloat16 __mt_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Define custom BF16 vector data types.
|
||||
struct bf16_4_t {
|
||||
__mt_bfloat162 x;
|
||||
__mt_bfloat162 y;
|
||||
};
|
||||
|
||||
struct bf16_8_t {
|
||||
__mt_bfloat162 x;
|
||||
__mt_bfloat162 y;
|
||||
__mt_bfloat162 z;
|
||||
__mt_bfloat162 w;
|
||||
};
|
||||
|
||||
// BF16 vector types for Q, K, V.
|
||||
template<>
|
||||
struct Vec<__mt_bfloat16, 1> {
|
||||
using Type = __mt_bfloat16;
|
||||
};
|
||||
template<>
|
||||
struct Vec<__mt_bfloat16, 2> {
|
||||
using Type = __mt_bfloat162;
|
||||
};
|
||||
template<>
|
||||
struct Vec<__mt_bfloat16, 4> {
|
||||
using Type = bf16_4_t;
|
||||
};
|
||||
template<>
|
||||
struct Vec<__mt_bfloat16, 8> {
|
||||
using Type = bf16_8_t;
|
||||
};
|
||||
|
||||
// FP32 accumulator vector types corresponding to Vec.
|
||||
template<>
|
||||
struct FloatVec<__mt_bfloat16> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<__mt_bfloat162> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<bf16_4_t> {
|
||||
using Type = Float4_;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<bf16_8_t> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ float2 bf1622float2(const __mt_bfloat162 val) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __bfloat1622float2(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __mt_bfloat162 bf162bf162(const __mt_bfloat16 val) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __bfloat162bfloat162(val);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Vector addition.
|
||||
inline __device__ __mt_bfloat16 add(__mt_bfloat16 a, __mt_bfloat16 b) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
#ifndef USE_ROCM
|
||||
return a + b;
|
||||
#else
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __mt_bfloat162 add(__mt_bfloat162 a, __mt_bfloat162 b) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hadd2(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
|
||||
bf16_4_t c;
|
||||
c.x = add(a.x, b.x);
|
||||
c.y = add(a.y, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
|
||||
bf16_8_t c;
|
||||
c.x = add(a.x, b.x);
|
||||
c.y = add(a.y, b.y);
|
||||
c.z = add(a.z, b.z);
|
||||
c.w = add(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ float2 add(__mt_bfloat162 a, float2 fb) {
|
||||
float2 fa = bf1622float2(a);
|
||||
return add(fa, fb);
|
||||
}
|
||||
|
||||
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
|
||||
Float4_ fc;
|
||||
fc.x = add(a.x, fb.x);
|
||||
fc.y = add(a.y, fb.y);
|
||||
return fc;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
|
||||
Float8_ fc;
|
||||
fc.x = add(a.x, fb.x);
|
||||
fc.y = add(a.y, fb.y);
|
||||
fc.z = add(a.z, fb.z);
|
||||
fc.w = add(a.w, fb.w);
|
||||
return fc;
|
||||
}
|
||||
|
||||
// Vector multiplication.
|
||||
template<>
|
||||
inline __device__ __mt_bfloat16 mul(__mt_bfloat16 a, __mt_bfloat16 b) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hmul(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ __mt_bfloat162 mul(__mt_bfloat162 a, __mt_bfloat162 b) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hmul2(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ __mt_bfloat162 mul(__mt_bfloat16 a, __mt_bfloat162 b) {
|
||||
return mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(bf162bf162(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
|
||||
bf16_4_t c;
|
||||
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
|
||||
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ bf16_4_t mul(__mt_bfloat16 a, bf16_4_t b) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
bf16_4_t c;
|
||||
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.x);
|
||||
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
|
||||
bf16_8_t c;
|
||||
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
|
||||
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
|
||||
c.z = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.z, b.z);
|
||||
c.w = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ bf16_8_t mul(__mt_bfloat16 a, bf16_8_t b) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
bf16_8_t c;
|
||||
c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.x);
|
||||
c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.y);
|
||||
c.z = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.z);
|
||||
c.w = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float mul(__mt_bfloat16 a, __mt_bfloat16 b) {
|
||||
float fa = __bfloat162float(a);
|
||||
float fb = __bfloat162float(b);
|
||||
return fa * fb;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 mul(__mt_bfloat162 a, __mt_bfloat162 b) {
|
||||
float2 fa = bf1622float2(a);
|
||||
float2 fb = bf1622float2(b);
|
||||
return mul<float2, float2, float2>(fa, fb);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 mul(__mt_bfloat16 a, __mt_bfloat162 b) {
|
||||
return mul<float2, __mt_bfloat162, __mt_bfloat162>(bf162bf162(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
|
||||
Float4_ fc;
|
||||
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
|
||||
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float4_ mul(__mt_bfloat16 a, bf16_4_t b) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
Float4_ fc;
|
||||
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.x);
|
||||
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.y);
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
|
||||
Float8_ fc;
|
||||
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
|
||||
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
|
||||
fc.z = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.z, b.z);
|
||||
fc.w = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.w, b.w);
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float8_ mul(__mt_bfloat16 a, bf16_8_t b) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
Float8_ fc;
|
||||
fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.x);
|
||||
fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.y);
|
||||
fc.z = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.z);
|
||||
fc.w = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.w);
|
||||
return fc;
|
||||
}
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ __mt_bfloat162 fma(__mt_bfloat162 a, __mt_bfloat162 b, __mt_bfloat162 c) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hfma2(a, b, c);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ __mt_bfloat162 fma(__mt_bfloat16 a, __mt_bfloat162 b, __mt_bfloat162 c) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __hfma2(bf162bf162(a), b, c);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
|
||||
bf16_4_t d;
|
||||
d.x = fma(a.x, b.x, c.x);
|
||||
d.y = fma(a.y, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t fma(__mt_bfloat16 a, bf16_4_t b, bf16_4_t c) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
bf16_4_t d;
|
||||
d.x = fma(s, b.x, c.x);
|
||||
d.y = fma(s, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
|
||||
bf16_8_t d;
|
||||
d.x = fma(a.x, b.x, c.x);
|
||||
d.y = fma(a.y, b.y, c.y);
|
||||
d.z = fma(a.z, b.z, c.z);
|
||||
d.w = fma(a.w, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ bf16_8_t fma(__mt_bfloat16 a, bf16_8_t b, bf16_8_t c) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
bf16_8_t d;
|
||||
d.x = fma(s, b.x, c.x);
|
||||
d.y = fma(s, b.y, c.y);
|
||||
d.z = fma(s, b.z, c.z);
|
||||
d.w = fma(s, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ float fma(__mt_bfloat16 a, __mt_bfloat16 b, float fc) {
|
||||
return __bfloat162float(a) * __bfloat162float(b) + fc;
|
||||
}
|
||||
|
||||
inline __device__ float2 fma(__mt_bfloat162 a, __mt_bfloat162 b, float2 fc) {
|
||||
float2 fa = bf1622float2(a);
|
||||
float2 fb = bf1622float2(b);
|
||||
return fma(fa, fb, fc);
|
||||
}
|
||||
|
||||
inline __device__ float2 fma(__mt_bfloat16 a, __mt_bfloat162 b, float2 fc) {
|
||||
return fma(bf162bf162(a), b, fc);
|
||||
}
|
||||
|
||||
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
|
||||
Float4_ fd;
|
||||
fd.x = fma(a.x, b.x, fc.x);
|
||||
fd.y = fma(a.y, b.y, fc.y);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline __device__ Float4_ fma(__mt_bfloat16 a, bf16_4_t b, Float4_ fc) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
Float4_ fd;
|
||||
fd.x = fma(s, b.x, fc.x);
|
||||
fd.y = fma(s, b.y, fc.y);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
|
||||
Float8_ fd;
|
||||
fd.x = fma(a.x, b.x, fc.x);
|
||||
fd.y = fma(a.y, b.y, fc.y);
|
||||
fd.z = fma(a.z, b.z, fc.z);
|
||||
fd.w = fma(a.w, b.w, fc.w);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ fma(__mt_bfloat16 a, bf16_8_t b, Float8_ fc) {
|
||||
__mt_bfloat162 s = bf162bf162(a);
|
||||
Float8_ fd;
|
||||
fd.x = fma(s, b.x, fc.x);
|
||||
fd.y = fma(s, b.y, fc.y);
|
||||
fd.z = fma(s, b.z, fc.z);
|
||||
fd.w = fma(s, b.w, fc.w);
|
||||
return fd;
|
||||
}
|
||||
|
||||
// Vector sum.
|
||||
template<>
|
||||
inline __device__ float sum(__mt_bfloat16 v) {
|
||||
return __bfloat162float(v);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(__mt_bfloat162 v) {
|
||||
float2 vf = bf1622float2(v);
|
||||
return vf.x + vf.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(bf16_4_t v) {
|
||||
return sum(v.x) + sum(v.y);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(bf16_8_t v) {
|
||||
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
|
||||
}
|
||||
|
||||
// From float32 to bfloat16.
|
||||
inline __device__ void from_float(__mt_bfloat16& dst, float src) {
|
||||
dst = __float2bfloat16(src);
|
||||
}
|
||||
|
||||
inline __device__ void from_float(__mt_bfloat162& dst, float2 src) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
dst = __float22bfloat162_rn(src);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
dst.x = __float22bfloat162_rn(src.x);
|
||||
dst.y = __float22bfloat162_rn(src.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
dst.x = __float22bfloat162_rn(src.x);
|
||||
dst.y = __float22bfloat162_rn(src.y);
|
||||
dst.z = __float22bfloat162_rn(src.z);
|
||||
dst.w = __float22bfloat162_rn(src.w);
|
||||
#endif
|
||||
}
|
||||
|
||||
// From bfloat16 to float32.
|
||||
inline __device__ float to_float(__mt_bfloat16 u) {
|
||||
return __bfloat162float(u);
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(__mt_bfloat16& dst) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
||||
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
503
csrc_musa/attention/dtype_float16.muh
Normal file
503
csrc_musa/attention/dtype_float16.muh
Normal file
@@ -0,0 +1,503 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.muh"
|
||||
#include "dtype_float32.muh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// FP16 vector types for Q, K, V.
|
||||
template<>
|
||||
struct Vec<uint16_t, 1> {
|
||||
using Type = uint16_t;
|
||||
};
|
||||
template<>
|
||||
struct Vec<uint16_t, 2> {
|
||||
using Type = uint32_t;
|
||||
};
|
||||
template<>
|
||||
struct Vec<uint16_t, 4> {
|
||||
using Type = uint2;
|
||||
};
|
||||
template<>
|
||||
struct Vec<uint16_t, 8> {
|
||||
using Type = uint4;
|
||||
};
|
||||
|
||||
// FP32 accumulator vector types corresponding to Vec.
|
||||
template<>
|
||||
struct FloatVec<uint16_t> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<uint32_t> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<uint2> {
|
||||
using Type = Float4_;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<uint4> {
|
||||
using Type = Float8_;
|
||||
};
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||
#ifndef USE_ROCM
|
||||
uint32_t b;
|
||||
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
||||
return b;
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u16[0] = a;
|
||||
tmp.u16[1] = a;
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ float half_to_float(uint16_t h) {
|
||||
float f;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
||||
#else
|
||||
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
|
||||
#endif
|
||||
return f;
|
||||
}
|
||||
|
||||
inline __device__ float2 half2_to_float2(uint32_t v) {
|
||||
#ifndef USE_ROCM
|
||||
uint16_t lo, hi;
|
||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
||||
return make_float2(half_to_float(lo), half_to_float(hi));
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u32 = v;
|
||||
float2 ret;
|
||||
ret.x = half_to_float(tmp.u16[0]);
|
||||
ret.y = half_to_float(tmp.u16[1]);
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ uint16_t float_to_half(float f) {
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
||||
#else
|
||||
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
|
||||
#endif
|
||||
return tmp.u16[0];
|
||||
}
|
||||
|
||||
inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
#endif
|
||||
#else
|
||||
tmp.u16[0] = float_to_half(f.x);
|
||||
tmp.u16[1] = float_to_half(f.y);
|
||||
#endif
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// Vector addition.
|
||||
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint2 add(uint2 a, uint2 b) {
|
||||
uint2 c;
|
||||
c.x = add(a.x, b.x);
|
||||
c.y = add(a.y, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint4 add(uint4 a, uint4 b) {
|
||||
uint4 c;
|
||||
c.x = add(a.x, b.x);
|
||||
c.y = add(a.y, b.y);
|
||||
c.z = add(a.z, b.z);
|
||||
c.w = add(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ float2 add(uint32_t a, float2 fb) {
|
||||
float2 fa = half2_to_float2(a);
|
||||
return add(fa, fb);
|
||||
}
|
||||
|
||||
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
|
||||
Float4_ fc;
|
||||
fc.x = add(a.x, fb.x);
|
||||
fc.y = add(a.y, fb.y);
|
||||
return fc;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
||||
Float8_ fc;
|
||||
fc.x = add(a.x, fb.x);
|
||||
fc.y = add(a.y, fb.y);
|
||||
fc.z = add(a.z, fb.z);
|
||||
fc.w = add(a.w, fb.w);
|
||||
return fc;
|
||||
}
|
||||
|
||||
// Vector multiplication.
|
||||
template<>
|
||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
|
||||
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint2 mul(uint2 a, uint2 b) {
|
||||
uint2 c;
|
||||
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
|
||||
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint2 mul(uint16_t a, uint2 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
uint2 c;
|
||||
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
|
||||
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint4 mul(uint4 a, uint4 b) {
|
||||
uint4 c;
|
||||
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
|
||||
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
|
||||
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
|
||||
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint4 mul(uint16_t a, uint4 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
uint4 c;
|
||||
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
|
||||
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
|
||||
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
|
||||
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float mul(uint16_t a, uint16_t b) {
|
||||
float fa = half_to_float(a);
|
||||
float fb = half_to_float(b);
|
||||
return fa * fb;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 mul(uint32_t a, uint32_t b) {
|
||||
float2 fa = half2_to_float2(a);
|
||||
float2 fb = half2_to_float2(b);
|
||||
return mul<float2, float2, float2>(fa, fb);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 mul(uint16_t a, uint32_t b) {
|
||||
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float4_ mul(uint2 a, uint2 b) {
|
||||
Float4_ fc;
|
||||
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
|
||||
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
Float4_ fc;
|
||||
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
|
||||
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float8_ mul(uint4 a, uint4 b) {
|
||||
Float8_ fc;
|
||||
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
|
||||
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
|
||||
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
|
||||
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
|
||||
return fc;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
uint32_t s = h0_h0(a);
|
||||
Float8_ fc;
|
||||
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
|
||||
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
|
||||
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
|
||||
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
|
||||
return fc;
|
||||
}
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
|
||||
return fma(h0_h0(a), b, c);
|
||||
}
|
||||
|
||||
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
|
||||
uint2 d;
|
||||
d.x = fma(a.x, b.x, c.x);
|
||||
d.y = fma(a.y, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
|
||||
uint32_t s = h0_h0(a);
|
||||
uint2 d;
|
||||
d.x = fma(s, b.x, c.x);
|
||||
d.y = fma(s, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
|
||||
uint4 d;
|
||||
d.x = fma(a.x, b.x, c.x);
|
||||
d.y = fma(a.y, b.y, c.y);
|
||||
d.z = fma(a.z, b.z, c.z);
|
||||
d.w = fma(a.w, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
|
||||
uint32_t s = h0_h0(a);
|
||||
uint4 d;
|
||||
d.x = fma(s, b.x, c.x);
|
||||
d.y = fma(s, b.y, c.y);
|
||||
d.z = fma(s, b.z, c.z);
|
||||
d.w = fma(s, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
|
||||
float fa = half_to_float(a);
|
||||
float fb = half_to_float(b);
|
||||
return fa * fb + fc;
|
||||
}
|
||||
|
||||
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
|
||||
float2 fa = half2_to_float2(a);
|
||||
float2 fb = half2_to_float2(b);
|
||||
return fma(fa, fb, fc);
|
||||
}
|
||||
|
||||
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
|
||||
return fma(h0_h0(a), b, fc);
|
||||
}
|
||||
|
||||
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
|
||||
Float4_ fd;
|
||||
fd.x = fma(a.x, b.x, fc.x);
|
||||
fd.y = fma(a.y, b.y, fc.y);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
|
||||
uint32_t s = h0_h0(a);
|
||||
Float4_ fd;
|
||||
fd.x = fma(s, b.x, fc.x);
|
||||
fd.y = fma(s, b.y, fc.y);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
|
||||
Float8_ fd;
|
||||
fd.x = fma(a.x, b.x, fc.x);
|
||||
fd.y = fma(a.y, b.y, fc.y);
|
||||
fd.z = fma(a.z, b.z, fc.z);
|
||||
fd.w = fma(a.w, b.w, fc.w);
|
||||
return fd;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
|
||||
uint32_t s = h0_h0(a);
|
||||
Float8_ fd;
|
||||
fd.x = fma(s, b.x, fc.x);
|
||||
fd.y = fma(s, b.y, fc.y);
|
||||
fd.z = fma(s, b.z, fc.z);
|
||||
fd.w = fma(s, b.w, fc.w);
|
||||
return fd;
|
||||
}
|
||||
|
||||
// Vector sum.
|
||||
template<>
|
||||
inline __device__ float sum(uint16_t v) {
|
||||
return half_to_float(v);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(uint32_t v) {
|
||||
float2 tmp = half2_to_float2(v);
|
||||
return tmp.x + tmp.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(uint2 v) {
|
||||
uint32_t c = add(v.x, v.y);
|
||||
return sum(c);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(uint4 v) {
|
||||
uint32_t c = add(v.x, v.y);
|
||||
c = add(c, v.z);
|
||||
c = add(c, v.w);
|
||||
return sum(c);
|
||||
}
|
||||
|
||||
// From float32 to float16.
|
||||
inline __device__ void from_float(uint16_t& dst, float src) {
|
||||
dst = float_to_half(src);
|
||||
}
|
||||
|
||||
inline __device__ void from_float(uint32_t& dst, float2 src) {
|
||||
dst = float2_to_half2(src);
|
||||
}
|
||||
|
||||
inline __device__ void from_float(uint2& dst, Float4_ src) {
|
||||
dst.x = float2_to_half2(src.x);
|
||||
dst.y = float2_to_half2(src.y);
|
||||
}
|
||||
|
||||
inline __device__ void from_float(uint4& dst, Float8_ src) {
|
||||
dst.x = float2_to_half2(src.x);
|
||||
dst.y = float2_to_half2(src.y);
|
||||
dst.z = float2_to_half2(src.z);
|
||||
dst.w = float2_to_half2(src.w);
|
||||
}
|
||||
|
||||
// From float16 to float32.
|
||||
inline __device__ float to_float(uint16_t u) {
|
||||
return half_to_float(u);
|
||||
}
|
||||
|
||||
inline __device__ float2 to_float(uint32_t u) {
|
||||
return half2_to_float2(u);
|
||||
}
|
||||
|
||||
inline __device__ Float4_ to_float(uint2 u) {
|
||||
Float4_ tmp;
|
||||
tmp.x = half2_to_float2(u.x);
|
||||
tmp.y = half2_to_float2(u.y);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ to_float(uint4 u) {
|
||||
Float8_ tmp;
|
||||
tmp.x = half2_to_float2(u.x);
|
||||
tmp.y = half2_to_float2(u.y);
|
||||
tmp.z = half2_to_float2(u.z);
|
||||
tmp.w = half2_to_float2(u.w);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
274
csrc_musa/attention/dtype_float32.muh
Normal file
274
csrc_musa/attention/dtype_float32.muh
Normal file
@@ -0,0 +1,274 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
|
||||
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.muh"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Define custom FP32 vector data types.
|
||||
struct Float4_ {
|
||||
float2 x;
|
||||
float2 y;
|
||||
};
|
||||
|
||||
struct Float8_ {
|
||||
float2 x;
|
||||
float2 y;
|
||||
float2 z;
|
||||
float2 w;
|
||||
};
|
||||
|
||||
// FP32 vector types for Q, K, V.
|
||||
template<>
|
||||
struct Vec<float, 1> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
struct Vec<float, 2> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
struct Vec<float, 4> {
|
||||
using Type = float4;
|
||||
};
|
||||
|
||||
// FP32 accumulator vector types corresponding to Vec.
|
||||
template<>
|
||||
struct FloatVec<float> {
|
||||
using Type = float;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<float2> {
|
||||
using Type = float2;
|
||||
};
|
||||
template<>
|
||||
struct FloatVec<float4> {
|
||||
using Type = float4;
|
||||
};
|
||||
|
||||
// Vector addition.
|
||||
inline __device__ float add(float a, float b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
inline __device__ float2 add(float2 a, float2 b) {
|
||||
float2 c;
|
||||
c.x = add(a.x, b.x);
|
||||
c.y = add(a.y, b.y);
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ float4 add(float4 a, float4 b) {
|
||||
float4 c;
|
||||
c.x = add(a.x, b.x);
|
||||
c.y = add(a.y, b.y);
|
||||
c.z = add(a.z, b.z);
|
||||
c.w = add(a.w, b.w);
|
||||
return c;
|
||||
}
|
||||
|
||||
// Vector multiplication.
|
||||
template<>
|
||||
inline __device__ float mul<float, float>(float a, float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 mul(float2 a, float2 b) {
|
||||
float2 c;
|
||||
c.x = a.x * b.x;
|
||||
c.y = a.y * b.y;
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 mul(float a, float2 b) {
|
||||
float2 c;
|
||||
c.x = a * b.x;
|
||||
c.y = a * b.y;
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float4 mul(float4 a, float4 b) {
|
||||
float4 c;
|
||||
c.x = a.x * b.x;
|
||||
c.y = a.y * b.y;
|
||||
c.z = a.z * b.z;
|
||||
c.w = a.w * b.w;
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float4 mul(float a, float4 b) {
|
||||
float4 c;
|
||||
c.x = a * b.x;
|
||||
c.y = a * b.y;
|
||||
c.z = a * b.z;
|
||||
c.w = a * b.w;
|
||||
return c;
|
||||
}
|
||||
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ float fma(float a, float b, float c) {
|
||||
return a * b + c;
|
||||
}
|
||||
|
||||
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
|
||||
float2 d;
|
||||
d.x = fma(a.x, b.x, c.x);
|
||||
d.y = fma(a.y, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ float2 fma(float a, float2 b, float2 c) {
|
||||
float2 d;
|
||||
d.x = fma(a, b.x, c.x);
|
||||
d.y = fma(a, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
|
||||
float4 d;
|
||||
d.x = fma(a.x, b.x, c.x);
|
||||
d.y = fma(a.y, b.y, c.y);
|
||||
d.z = fma(a.z, b.z, c.z);
|
||||
d.w = fma(a.w, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ float4 fma(float a, float4 b, float4 c) {
|
||||
float4 d;
|
||||
d.x = fma(a, b.x, c.x);
|
||||
d.y = fma(a, b.y, c.y);
|
||||
d.z = fma(a, b.z, c.z);
|
||||
d.w = fma(a, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
|
||||
Float4_ d;
|
||||
d.x = fma(a, b.x, c.x);
|
||||
d.y = fma(a, b.y, c.y);
|
||||
return d;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
|
||||
Float8_ d;
|
||||
d.x = fma(a, b.x, c.x);
|
||||
d.y = fma(a, b.y, c.y);
|
||||
d.z = fma(a, b.z, c.z);
|
||||
d.w = fma(a, b.w, c.w);
|
||||
return d;
|
||||
}
|
||||
|
||||
// Vector sum.
|
||||
template<>
|
||||
inline __device__ float sum(float v) {
|
||||
return v;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(float2 v) {
|
||||
return v.x + v.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(float4 v) {
|
||||
return v.x + v.y + v.z + v.w;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(Float4_ v) {
|
||||
return v.x.x + v.x.y + v.y.x + v.y.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float sum(Float8_ v) {
|
||||
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
|
||||
}
|
||||
|
||||
// Vector dot product.
|
||||
inline __device__ float dot(float a, float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
inline __device__ float dot(float2 a, float2 b) {
|
||||
float2 c = mul<float2, float2, float2>(a, b);
|
||||
return c.x + c.y;
|
||||
}
|
||||
|
||||
inline __device__ float dot(Float4_ a, Float4_ b) {
|
||||
float2 acc = mul<float2, float2, float2>(a.x, b.x);
|
||||
acc = fma(a.y, b.y, acc);
|
||||
return acc.x + acc.y;
|
||||
}
|
||||
|
||||
inline __device__ float dot(Float8_ a, Float8_ b) {
|
||||
float2 acc = mul<float2, float2, float2>(a.x, b.x);
|
||||
acc = fma(a.y, b.y, acc);
|
||||
acc = fma(a.z, b.z, acc);
|
||||
acc = fma(a.w, b.w, acc);
|
||||
return acc.x + acc.y;
|
||||
}
|
||||
|
||||
// From float to float.
|
||||
inline __device__ void from_float(float& dst, float src) {
|
||||
dst = src;
|
||||
}
|
||||
|
||||
inline __device__ void from_float(float2& dst, float2 src) {
|
||||
dst = src;
|
||||
}
|
||||
|
||||
inline __device__ void from_float(float4& dst, float4 src) {
|
||||
dst = src;
|
||||
}
|
||||
|
||||
// From float to float.
|
||||
inline __device__ float to_float(float u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
inline __device__ float2 to_float(float2 u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
inline __device__ float4 to_float(float4 u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
inline __device__ Float4_ to_float(Float4_ u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
inline __device__ Float8_ to_float(Float8_ u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(float& dst) {
|
||||
dst = 0.f;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
35
csrc_musa/attention/dtype_fp8.muh
Normal file
35
csrc_musa/attention/dtype_fp8.muh
Normal file
@@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.muh"
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
|
||||
// fp8 vector types for quantization of kv cache
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 1> {
|
||||
using Type = uint8_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 2> {
|
||||
using Type = uint16_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 4> {
|
||||
using Type = uint32_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 8> {
|
||||
using Type = uint2;
|
||||
};
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
|
||||
} // namespace vllm
|
||||
38
csrc_musa/cache.h
Normal file
38
csrc_musa/cache.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
const std::map<int64_t, int64_t>& block_mapping);
|
||||
|
||||
void copy_blocks(
|
||||
std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale);
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache);
|
||||
419
csrc_musa/cache_kernels.mu
Normal file
419
csrc_musa/cache_kernels.mu
Normal file
@@ -0,0 +1,419 @@
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include "musa_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
#include "quantization/fp8/amd_detail/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __mt_bfloat16;
|
||||
#endif
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
const std::map<int64_t, int64_t>& block_mapping) {
|
||||
torch::Device src_device = src.device();
|
||||
torch::Device dst_device = dst.device();
|
||||
musaMemcpyKind memcpy_type;
|
||||
if (src_device.is_cuda() && dst_device.is_cuda()) {
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
memcpy_type = musaMemcpyDeviceToDevice;
|
||||
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
|
||||
memcpy_type = musaMemcpyDeviceToHost;
|
||||
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
|
||||
memcpy_type = musaMemcpyHostToDevice;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const at::musa::OptionalMUSAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||
for (const auto& pair : block_mapping) {
|
||||
int64_t src_block_number = pair.first;
|
||||
int64_t dst_block_number = pair.second;
|
||||
int64_t src_offset = src_block_number * block_size_in_bytes;
|
||||
int64_t dst_offset = dst_block_number * block_size_in_bytes;
|
||||
musaMemcpyAsync(
|
||||
dst_ptr + dst_offset,
|
||||
src_ptr + src_offset,
|
||||
block_size_in_bytes,
|
||||
memcpy_type,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Grid: (num_layers, num_pairs)
|
||||
template<typename scalar_t>
|
||||
__global__ void copy_blocks_kernel(
|
||||
int64_t* key_cache_ptrs,
|
||||
int64_t* value_cache_ptrs,
|
||||
const int64_t* __restrict__ block_mapping,
|
||||
const int numel_per_block) {
|
||||
const int layer_idx = blockIdx.x;
|
||||
const int pair_idx = blockIdx.y;
|
||||
|
||||
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
||||
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
||||
int64_t src_block_number = block_mapping[2 * pair_idx];
|
||||
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
||||
|
||||
const int64_t src_block_offset = src_block_number * numel_per_block;
|
||||
const int64_t dst_block_offset = dst_block_number * numel_per_block;
|
||||
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
||||
int64_t src_offset = src_block_offset + i;
|
||||
int64_t dst_offset = dst_block_offset + i;
|
||||
key_cache[dst_offset] = key_cache[src_offset];
|
||||
}
|
||||
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
||||
int64_t src_offset = src_block_offset + i;
|
||||
int64_t dst_offset = dst_block_offset + i;
|
||||
value_cache[dst_offset] = value_cache[src_offset];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void copy_blocks(
|
||||
std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
|
||||
int num_layers = key_caches.size();
|
||||
TORCH_CHECK(num_layers == value_caches.size());
|
||||
if (num_layers == 0) {
|
||||
return;
|
||||
}
|
||||
torch::Device cache_device = key_caches[0].device();
|
||||
TORCH_CHECK(cache_device.is_cuda());
|
||||
|
||||
// Create data structures for the kernel.
|
||||
// Create an array of pointers to the key and value caches.
|
||||
int64_t key_cache_ptrs[num_layers];
|
||||
int64_t value_cache_ptrs[num_layers];
|
||||
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
|
||||
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
|
||||
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||
}
|
||||
// Create block mapping array.
|
||||
std::vector<int64_t> block_mapping_vec;
|
||||
for (const auto& pair : block_mapping) {
|
||||
int64_t src_block_number = pair.first;
|
||||
for (int64_t dst_block_number : pair.second) {
|
||||
block_mapping_vec.push_back(src_block_number);
|
||||
block_mapping_vec.push_back(dst_block_number);
|
||||
}
|
||||
}
|
||||
int64_t* block_mapping_array = block_mapping_vec.data();
|
||||
int num_pairs = block_mapping_vec.size() / 2;
|
||||
|
||||
// Move the data structures to the GPU.
|
||||
// NOTE: This synchronizes the CPU and GPU.
|
||||
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
|
||||
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
||||
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
|
||||
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
||||
torch::Tensor block_mapping_tensor = torch::from_blob(
|
||||
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
|
||||
|
||||
// Launch the kernel.
|
||||
const int numel_per_block = key_caches[0][0].numel();
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const at::musa::OptionalMUSAGuard device_guard(cache_device);
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
block_mapping_tensor.data_ptr<int64_t>(),
|
||||
numel_per_block);
|
||||
}));
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x,
|
||||
const float kv_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
// Padding token that should be ignored.
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int x_idx = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
|
||||
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (is_fp8_kv_cache) {
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
|
||||
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
key_cache[tgt_key_idx] = tgt_key;
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride,
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride
|
||||
+ block_offset * num_heads * head_size
|
||||
+ head_idx * head_size
|
||||
+ head_offset;
|
||||
k_cache[tgt_value_idx] = key[src_key_idx];
|
||||
v_cache[tgt_value_idx] = value[src_value_idx];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x, \
|
||||
kv_scale);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype,
|
||||
const float kv_scale)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(key));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__mt_bfloat16, __mt_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__mt_bfloat16, uint8_t, true);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
// FIXME: only support auto datatype, does not support fp8
|
||||
if (kv_cache_dtype != "auto") {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = k_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = k_cache.stride(0);
|
||||
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(key));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_flash",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(),
|
||||
v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(),
|
||||
block_stride,
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#if defined(ENABLE_FP8_E5M2)
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#elif defined(ENABLE_FP8_E4M3)
|
||||
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONVERT_FP8(Tout, Tin) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
|
||||
void convert_fp8(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
torch::Device src_device = src_cache.device();
|
||||
torch::Device dst_device = dst_cache.device();
|
||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
|
||||
TORCH_CHECK(
|
||||
src_device.index() == dst_device.index(),
|
||||
"src and dst must be on the same GPU");
|
||||
at::musa::OptionalMUSAGuard device_guard(src_device);
|
||||
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
int64_t block_stride = src_cache.stride(0);
|
||||
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(std::min(block_stride, int64_t(512)));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(uint8_t, __mt_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8(__mt_bfloat16, uint8_t);
|
||||
}
|
||||
}
|
||||
148
csrc_musa/cpu/activation.cpp
Normal file
148
csrc_musa/cpu/activation.cpp
Normal file
@@ -0,0 +1,148 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
|
||||
bool is_gated>
|
||||
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
|
||||
scalar_t *__restrict__ output) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
|
||||
TORCH_CHECK(d % VEC_ELEM_NUM == 0);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
for (int j = 0; j < d; j += VEC_ELEM_NUM) {
|
||||
int start = i * d;
|
||||
if constexpr (is_gated) {
|
||||
start *= 2;
|
||||
}
|
||||
|
||||
const scalar_vec_t x(input + start + j);
|
||||
const vec_op::FP32Vec8 f32_x(x);
|
||||
vec_op::FP32Vec8 f32_ans = func(f32_x);
|
||||
|
||||
if constexpr (is_gated) {
|
||||
const scalar_vec_t y(input + start + d + j);
|
||||
const vec_op::FP32Vec8 f32_y(y);
|
||||
f32_ans = f32_y * f32_ans;
|
||||
}
|
||||
|
||||
const scalar_vec_t result(f32_ans);
|
||||
result.save(output + i * d + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
|
||||
const vec_op::FP32Vec8 zeros(0.0);
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
return x / (ones + (zeros - x).exp());
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(0.79788456f);
|
||||
const vec_op::FP32Vec8 w2(0.044715f);
|
||||
const vec_op::FP32Vec8 w3(0.5);
|
||||
const vec_op::FP32Vec8 x3 = x * x * x;
|
||||
const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
|
||||
return w3 * x * (ones + t);
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(0.79788456f);
|
||||
const vec_op::FP32Vec8 w2(0.044715f);
|
||||
const vec_op::FP32Vec8 w3(0.5);
|
||||
const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
|
||||
return w3 * x * (ones + t);
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(M_SQRT1_2);
|
||||
const vec_op::FP32Vec8 w2(0.5);
|
||||
return x * w2 * (ones + (x * w1).er());
|
||||
}
|
||||
|
||||
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
|
||||
const vec_op::FP32Vec8 ones(1.0);
|
||||
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
|
||||
const vec_op::FP32Vec8 w2(0.5);
|
||||
const vec_op::FP32Vec8 w3(0.044715);
|
||||
const vec_op::FP32Vec8 x_3 = x * x * x;
|
||||
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
|
||||
return x * w2 * (ones + inner.tanh());
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "silu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
|
||||
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
|
||||
input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_and_mul(torch::Tensor &out, // [..., d]
|
||||
torch::Tensor &input) // [..., 2 * d]
|
||||
{
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "gelu_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
|
||||
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
|
||||
input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
|
||||
torch::Tensor &input) // [..., 2 * d]
|
||||
{
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
|
||||
activation_kernel<scalar_t, gelu_tanh_act, true>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_new(torch::Tensor &out, torch::Tensor &input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_new_impl)
|
||||
activation_kernel<scalar_t, gelu_new_act, false>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_new_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
|
||||
int num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(gelu_fast_impl)
|
||||
activation_kernel<scalar_t, gelu_fast_act, false>(
|
||||
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
|
||||
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
|
||||
});
|
||||
}
|
||||
746
csrc_musa/cpu/attention.cpp
Normal file
746
csrc_musa/cpu/attention.cpp
Normal file
@@ -0,0 +1,746 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t> struct KernelVecType {
|
||||
using q_load_vec_type = void;
|
||||
using q_vec_type = void;
|
||||
using k_load_vec_type = void;
|
||||
using k_vec_type = void;
|
||||
using qk_acc_vec_type = void;
|
||||
using v_load_vec_type = void;
|
||||
};
|
||||
|
||||
template <> struct KernelVecType<float> {
|
||||
using q_load_vec_type = vec_op::FP32Vec4;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::BF16Vec32;
|
||||
using k_load_vec_type = vec_op::BF16Vec32;
|
||||
using k_vec_type = vec_op::BF16Vec32;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#else
|
||||
template <> struct KernelVecType<c10::BFloat16> {
|
||||
using q_load_vec_type = vec_op::BF16Vec8;
|
||||
using q_vec_type = vec_op::FP32Vec16;
|
||||
using k_load_vec_type = vec_op::BF16Vec16;
|
||||
using k_vec_type = vec_op::FP32Vec16;
|
||||
using qk_acc_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
|
||||
const int capacity) {
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
max = max >= data[i] ? max : data[i];
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = std::exp(data[i] - max);
|
||||
sum += data[i];
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < size; ++i) {
|
||||
data[i] /= sum;
|
||||
}
|
||||
|
||||
for (; i < capacity; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
|
||||
return {max, sum};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE std::pair<T, T>
|
||||
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
|
||||
const float alibi_slope, const int start_index,
|
||||
const int seq_len) {
|
||||
data[0] += alibi_slope * (start_index - seq_len + 1);
|
||||
T max = data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
|
||||
data[i] = qk;
|
||||
max = max >= qk ? max : qk;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
data[i] = std::exp(data[i] - max);
|
||||
sum += data[i];
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < size; ++i) {
|
||||
data[i] /= sum;
|
||||
}
|
||||
|
||||
for (; i < capacity; ++i) {
|
||||
data[i] = 0;
|
||||
}
|
||||
|
||||
return {max, sum};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
|
||||
const int size) {
|
||||
T max = max_data[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
max = max >= max_data[i] ? max : max_data[i];
|
||||
}
|
||||
|
||||
T rescaled_sum = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
T rescale_factor = std::exp(max_data[i] - max);
|
||||
rescaled_sum += rescale_factor * sum_data[i];
|
||||
sum_data[i] *= rescale_factor;
|
||||
}
|
||||
for (int i = 0; i < size; ++i) {
|
||||
sum_data[i] /= rescaled_sum + 1e-8;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
|
||||
struct reduceQKBlockKernel {
|
||||
using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
|
||||
using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
|
||||
using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
|
||||
using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
|
||||
using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
|
||||
|
||||
constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
|
||||
constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
|
||||
constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
|
||||
|
||||
static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
|
||||
static_assert(k_load_vec_type::get_elem_num() % x == 0);
|
||||
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
|
||||
|
||||
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
|
||||
const scalar_t *__restrict__ k_block,
|
||||
float *__restrict__ logits, float scale,
|
||||
const int token_num) {
|
||||
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
|
||||
|
||||
qk_acc_vec_type group_accums[MAX_GROUP_NUM];
|
||||
if (token_num == BLOCK_SIZE) {
|
||||
for (int q_offset = 0; q_offset < HEAD_SIZE;
|
||||
q_offset += x, k_block += x * BLOCK_SIZE) {
|
||||
q_load_vec_type q_load_group_vec(q + q_offset);
|
||||
q_vec_type q_group_vec(q_load_group_vec);
|
||||
|
||||
vec_op::unroll_loop<int, MAX_GROUP_NUM>(
|
||||
[k_block, &q_group_vec, &group_accums](int token_group_idx) {
|
||||
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
|
||||
TOKEN_PER_GROUP);
|
||||
k_vec_type k_group_vec(k_load_group_vec);
|
||||
vec_op::fma(group_accums[token_group_idx], q_group_vec,
|
||||
k_group_vec);
|
||||
vec_op::prefetch(k_block + x * BLOCK_SIZE +
|
||||
token_group_idx * x * TOKEN_PER_GROUP);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
for (int q_offset = 0; q_offset < HEAD_SIZE;
|
||||
q_offset += x, k_block += x * BLOCK_SIZE) {
|
||||
q_load_vec_type q_load_group_vec(q + q_offset);
|
||||
q_vec_type q_group_vec(q_load_group_vec);
|
||||
for (int token_group_start = 0; token_group_start < group_num;
|
||||
token_group_start += UNROLL_GROUP_NUM) {
|
||||
vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
|
||||
[token_group_start, k_block, &q_group_vec,
|
||||
&group_accums](int token_group_idx) {
|
||||
token_group_idx += token_group_start;
|
||||
k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
|
||||
TOKEN_PER_GROUP);
|
||||
k_vec_type k_group_vec(k_load_group_vec);
|
||||
vec_op::fma(group_accums[token_group_idx], q_group_vec,
|
||||
k_group_vec);
|
||||
vec_op::prefetch(k_block + x * BLOCK_SIZE +
|
||||
token_group_idx * x * TOKEN_PER_GROUP);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int token_group_idx = 0; token_group_idx < group_num;
|
||||
++token_group_idx) {
|
||||
vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
|
||||
[&group_accums, logits, scale, token_group_idx](int token_idx) {
|
||||
float dot_v =
|
||||
group_accums[token_group_idx]
|
||||
.template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
|
||||
TOKEN_PER_GROUP>(token_idx);
|
||||
logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
|
||||
dot_v * scale;
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
|
||||
int HEAD_PARTITION_SIZE, typename acc_t>
|
||||
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
|
||||
acc_t &&acc) {
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
|
||||
static_assert(BLOCK_SIZE == ELEM_NUM);
|
||||
vec_op::FP32Vec16 prob_vec(prob);
|
||||
|
||||
vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
|
||||
v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
|
||||
vec_op::FP32Vec16 fp32_v_vec(v_vec);
|
||||
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
|
||||
});
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
// Paged attention v1
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
|
||||
struct paged_attention_v1_impl {
|
||||
static void
|
||||
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int
|
||||
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int *__restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float *__restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
|
||||
static_assert(BLOCK_SIZE == 16);
|
||||
|
||||
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
|
||||
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
|
||||
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
|
||||
|
||||
const int parallel_work_item_num = omp_get_max_threads();
|
||||
|
||||
size_t logits_bytes =
|
||||
parallel_work_item_num * max_seq_len_padded * sizeof(float);
|
||||
float *logits = (float *)std::aligned_alloc(
|
||||
64, logits_bytes); // Cacheline alignment for each context token.
|
||||
// [parallel_work_item_num, max_seq_len_padded]
|
||||
|
||||
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
int seq_len = seq_lens[seq_idx];
|
||||
const int *seq_block_table =
|
||||
block_tables + max_num_blocks_per_seq * seq_idx;
|
||||
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t *__restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
const int last_block_token_num =
|
||||
seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
float *__restrict__ thread_block_logits =
|
||||
logits + omp_get_thread_num() * max_seq_len_padded;
|
||||
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t *__restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float *__restrict__ head_block_logits =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute softmax
|
||||
if (alibi_slopes) {
|
||||
reduceSoftmaxAlibi(thread_block_logits, seq_len,
|
||||
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
|
||||
seq_len);
|
||||
} else {
|
||||
reduceSoftmax(thread_block_logits, seq_len,
|
||||
block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
// Compute value
|
||||
constexpr int head_elem_num_per_partition = 16;
|
||||
constexpr int head_partition_num =
|
||||
HEAD_SIZE / head_elem_num_per_partition;
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t *__restrict__ out_ptr =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float *__restrict__ prob_vec_ptr =
|
||||
thread_block_logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t *__restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||
head_elem_num_per_partition>(
|
||||
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t *__restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
if (head_elem_idx % 2 == 0) {
|
||||
vec_op::prefetch(next_v_block_cache_ptr +
|
||||
BLOCK_SIZE * head_elem_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
float value = accums[head_elem_idx].reduce_sum();
|
||||
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
std::free(logits);
|
||||
}
|
||||
};
|
||||
|
||||
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
|
||||
num_heads);
|
||||
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
void paged_attention_v1_impl_launcher(
|
||||
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
|
||||
torch::Tensor &value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables, torch::Tensor &seq_lens,
|
||||
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float *alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
||||
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
||||
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
||||
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
||||
int *block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int *seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 64:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||
seq_lens, max_seq_len, alibi_slopes);
|
||||
|
||||
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V1_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
|
||||
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
||||
int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables,
|
||||
torch::Tensor &seq_lens, int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
|
||||
CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
|
||||
CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
|
||||
});
|
||||
}
|
||||
|
||||
// Paged attention v2
|
||||
namespace {
|
||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
|
||||
struct paged_attention_v2_impl {
|
||||
static void call(
|
||||
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float
|
||||
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size/x, block_size, x]
|
||||
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int
|
||||
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int *__restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float *__restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const int num_seqs, const int num_heads, const int max_num_partitions) {
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
|
||||
static_assert(BLOCK_SIZE == 16);
|
||||
static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
|
||||
static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
|
||||
|
||||
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int partition_idx = 0; partition_idx < max_num_partitions;
|
||||
++partition_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int start_token_idx = partition_idx * PARTITION_SIZE;
|
||||
|
||||
if (start_token_idx >= seq_len)
|
||||
continue;
|
||||
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
const bool no_reduce = (partition_num == 1);
|
||||
const int token_num =
|
||||
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
|
||||
start_token_idx);
|
||||
const int block_num =
|
||||
(token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int last_block_token_num =
|
||||
token_num - (block_num - 1) * BLOCK_SIZE;
|
||||
const int *seq_block_table = block_tables +
|
||||
max_num_blocks_per_seq * seq_idx +
|
||||
start_token_idx / BLOCK_SIZE;
|
||||
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const scalar_t *__restrict__ q_vec_ptr =
|
||||
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
|
||||
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
|
||||
|
||||
// Compute logits
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const scalar_t *__restrict__ k_block_cache_ptr =
|
||||
k_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride;
|
||||
float *__restrict__ head_block_logits =
|
||||
logits + block_idx * BLOCK_SIZE;
|
||||
|
||||
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
|
||||
q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
|
||||
block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
|
||||
}
|
||||
|
||||
std::pair<float, float> max_and_sum;
|
||||
if (alibi_slopes) {
|
||||
max_and_sum = reduceSoftmaxAlibi(
|
||||
logits, token_num, block_num * BLOCK_SIZE,
|
||||
alibi_slopes[head_idx], start_token_idx, seq_len);
|
||||
} else {
|
||||
max_and_sum = reduceSoftmax(logits, token_num,
|
||||
block_num * BLOCK_SIZE);
|
||||
}
|
||||
|
||||
auto &&[max_logit, exp_sum] = max_and_sum;
|
||||
|
||||
scalar_t *__restrict__ output_buffer = nullptr;
|
||||
if (!no_reduce) {
|
||||
auto idx = seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions + partition_idx;
|
||||
max_logits[idx] = max_logit;
|
||||
exp_sums[idx] = exp_sum;
|
||||
output_buffer =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE +
|
||||
partition_idx * HEAD_SIZE;
|
||||
} else {
|
||||
output_buffer =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
}
|
||||
|
||||
// Compute value
|
||||
constexpr int head_elem_num_per_partition = 16;
|
||||
constexpr int head_partition_num =
|
||||
HEAD_SIZE / head_elem_num_per_partition;
|
||||
for (int head_part_idx = 0; head_part_idx < head_partition_num;
|
||||
++head_part_idx) {
|
||||
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
|
||||
scalar_t *__restrict__ out_ptr =
|
||||
output_buffer + head_part_idx * head_elem_num_per_partition;
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int64_t physical_block_idx = seq_block_table[block_idx];
|
||||
const float *__restrict__ prob_vec_ptr =
|
||||
logits + block_idx * BLOCK_SIZE;
|
||||
const scalar_t *__restrict__ v_block_cache_ptr =
|
||||
v_cache + physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
|
||||
head_elem_num_per_partition>(
|
||||
prob_vec_ptr, v_block_cache_ptr, accums);
|
||||
|
||||
if (block_idx != block_num - 1) {
|
||||
const int64_t next_physical_block_idx =
|
||||
seq_block_table[block_idx + 1];
|
||||
const scalar_t *__restrict__ next_v_block_cache_ptr =
|
||||
v_cache + next_physical_block_idx * kv_block_stride +
|
||||
kv_head_idx * kv_head_stride +
|
||||
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
if (head_elem_idx % 2 == 0) {
|
||||
vec_op::prefetch(next_v_block_cache_ptr +
|
||||
BLOCK_SIZE * head_elem_idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
vec_op::unroll_loop<int, head_elem_num_per_partition>(
|
||||
[&](int head_elem_idx) {
|
||||
float value = accums[head_elem_idx].reduce_sum();
|
||||
vec_op::storeFP32(value, out_ptr + head_elem_idx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rescale partition softmax and store the factors to exp_sums
|
||||
#pragma omp parallel for collapse(2) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1)
|
||||
continue;
|
||||
|
||||
reducePartitonSoftmax(
|
||||
max_logits + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions,
|
||||
exp_sums + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions,
|
||||
partition_num);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce values
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
|
||||
constexpr int head_elem_num_per_group =
|
||||
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
|
||||
// didn't align with 64 bytes
|
||||
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
|
||||
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
|
||||
const float *__restrict__ rescale_factors = exp_sums;
|
||||
#pragma omp parallel for collapse(3) schedule(static, 1)
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int partition_num =
|
||||
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
|
||||
|
||||
if (partition_num == 1)
|
||||
continue;
|
||||
|
||||
const float *__restrict__ seq_head_rescale_factors =
|
||||
rescale_factors + seq_idx * num_heads * max_num_partitions +
|
||||
head_idx * max_num_partitions;
|
||||
const scalar_t *__restrict__ seq_head_tmp_out =
|
||||
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
||||
head_idx * max_num_partitions * HEAD_SIZE +
|
||||
group_idx * head_elem_num_per_group;
|
||||
scalar_t *__restrict__ seq_head_output =
|
||||
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
|
||||
group_idx * head_elem_num_per_group;
|
||||
|
||||
vec_op::FP32Vec16 acc;
|
||||
for (int i = 0; i < partition_num; ++i) {
|
||||
vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
|
||||
v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
|
||||
vec_op::FP32Vec16 fp32_value(value);
|
||||
acc = acc + fp32_value * rescale_factor;
|
||||
}
|
||||
v_load_vec_type cast_acc(acc);
|
||||
cast_acc.save(seq_head_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
|
||||
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
|
||||
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
|
||||
max_num_partitions);
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_impl_launcher(
|
||||
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
|
||||
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
|
||||
torch::Tensor &value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor> &alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
int max_num_partitions = exp_sums.size(-1);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float *alibi_slopes_ptr =
|
||||
alibi_slopes
|
||||
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
|
||||
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
|
||||
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
|
||||
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
|
||||
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
|
||||
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
|
||||
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
|
||||
int *block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int *seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
switch (head_size) {
|
||||
case 64:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, block_size, \
|
||||
max_seq_len, alibi_slopes);
|
||||
|
||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_V2_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
|
||||
torch::Tensor &max_logits, torch::Tensor &tmp_out,
|
||||
torch::Tensor &query, torch::Tensor &key_cache,
|
||||
torch::Tensor &value_cache, int num_kv_heads,
|
||||
float scale, torch::Tensor &block_tables,
|
||||
torch::Tensor &seq_lens, int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor> &alibi_slopes,
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
[&] {
|
||||
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
|
||||
CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
|
||||
CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
|
||||
});
|
||||
}
|
||||
141
csrc_musa/cpu/cache.cpp
Normal file
141
csrc_musa/cpu/cache.cpp
Normal file
@@ -0,0 +1,141 @@
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(
|
||||
std::vector<torch::Tensor> &key_caches,
|
||||
std::vector<torch::Tensor> &value_caches,
|
||||
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
|
||||
const int element_num_per_block, const int layer_num) {
|
||||
const size_t pair_num = mapping_pairs.size();
|
||||
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int layer = 0; layer < layer_num; ++layer) {
|
||||
for (size_t pair = 0; pair < pair_num; ++pair) {
|
||||
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
|
||||
int64_t target_offset =
|
||||
element_num_per_block * mapping_pairs[pair].second;
|
||||
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
|
||||
scalar_t *source_ptr = key_cache_ptr + source_offset;
|
||||
scalar_t *target_ptr = key_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
|
||||
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
|
||||
source_ptr = value_cache_ptr + source_offset;
|
||||
target_ptr = value_cache_ptr + target_offset;
|
||||
std::memcpy(target_ptr, source_ptr, block_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void reshape_and_cache_cpu_impl(
|
||||
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
|
||||
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
|
||||
const int64_t *__restrict__ slot_mapping, const int num_tokens,
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x) {
|
||||
const int block_elem_num = num_heads * head_size * block_size;
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx >= 0) {
|
||||
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
|
||||
int src_value_head_idx =
|
||||
token_idx * value_stride + head_idx * head_size;
|
||||
const scalar_t *src_key_head_ptr = key + src_key_head_idx;
|
||||
const scalar_t *src_value_head_ptr = value + src_value_head_idx;
|
||||
const int64_t block_index = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
scalar_t *target_key_head_ptr = key_cache +
|
||||
block_elem_num * block_index +
|
||||
head_idx * block_size * head_size;
|
||||
scalar_t *target_value_head_ptr = value_cache +
|
||||
block_elem_num * block_index +
|
||||
head_idx * block_size * head_size;
|
||||
|
||||
for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
|
||||
const int64_t target_offset =
|
||||
src_key_idx * block_size + block_offset * x;
|
||||
for (int i = 0; i < x; ++i) {
|
||||
target_key_head_ptr[target_offset + i] =
|
||||
src_key_head_ptr[src_key_idx + i];
|
||||
}
|
||||
}
|
||||
|
||||
for (int src_value_idx = 0; src_value_idx < head_size;
|
||||
++src_value_idx) {
|
||||
const int64_t target_offset =
|
||||
src_value_idx * block_size + block_offset;
|
||||
target_value_head_ptr[target_offset] =
|
||||
src_value_head_ptr[src_value_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
void copy_blocks(std::vector<torch::Tensor> &key_caches,
|
||||
std::vector<torch::Tensor> &value_caches,
|
||||
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
|
||||
int num_layers = key_caches.size();
|
||||
TORCH_CHECK(num_layers == value_caches.size());
|
||||
if (num_layers == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
|
||||
mapping_pairs.reserve(block_mapping.size());
|
||||
for (const auto &pair : block_mapping) {
|
||||
for (const auto &dst : pair.second) {
|
||||
mapping_pairs.emplace_back(pair.first, dst);
|
||||
}
|
||||
}
|
||||
|
||||
const int element_num_per_block = key_caches[0][0].numel();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
|
||||
torch::Tensor &key_cache, torch::Tensor &value_cache,
|
||||
torch::Tensor &slot_mapping,
|
||||
const std::string &kv_cache_dtype, float kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
|
||||
reshape_and_cache_cpu_impl<scalar_t>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
|
||||
value_stride, num_heads, head_size, block_size, x);
|
||||
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
|
||||
const std::map<int64_t, int64_t> &block_mapping) {
|
||||
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
||||
}
|
||||
352
csrc_musa/cpu/cpu_types.hpp
Normal file
352
csrc_musa/cpu/cpu_types.hpp
Normal file
@@ -0,0 +1,352 @@
|
||||
|
||||
#ifndef CPU_TYPES_HPP
|
||||
#define CPU_TYPES_HPP
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
struct FP32Vec8;
|
||||
struct FP32Vec16;
|
||||
|
||||
#ifdef __AVX512FP16__
|
||||
struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
|
||||
__m128h reg;
|
||||
|
||||
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
|
||||
|
||||
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
|
||||
|
||||
explicit FP16Vec8(__m128h data) : reg(data) {}
|
||||
|
||||
FP16Vec8 operator*(const FP16Vec8 &b) const {
|
||||
return FP16Vec8(_mm_mul_ph(reg, b.reg));
|
||||
}
|
||||
|
||||
FP16Vec8 operator+(const FP16Vec8 &b) const {
|
||||
return FP16Vec8(_mm_add_ph(reg, b.reg));
|
||||
}
|
||||
|
||||
FP16Vec8 operator-(const FP16Vec8 &b) const {
|
||||
return FP16Vec8(_mm_sub_ph(reg, b.reg));
|
||||
}
|
||||
|
||||
FP16Vec8 operator/(const FP16Vec8 &b) const {
|
||||
return FP16Vec8(_mm_div_ph(reg, b.reg));
|
||||
}
|
||||
|
||||
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
|
||||
};
|
||||
#endif
|
||||
|
||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
: reg((__m512i)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
|
||||
(__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1),
|
||||
(__m128i)vec8_data.reg, 2),
|
||||
(__m128i)vec8_data.reg, 3)) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
constexpr static int VEC_ELEM_NUM = 4;
|
||||
union AliasReg {
|
||||
__m128 reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__m128 reg;
|
||||
|
||||
explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__m128 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
union AliasReg {
|
||||
__m256 reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__m256 reg;
|
||||
|
||||
explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec8(__m256 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
|
||||
|
||||
#ifdef __AVX512FP16__
|
||||
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
|
||||
#endif
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v)
|
||||
: reg(_mm256_castsi256_ps(
|
||||
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
FP32Vec8 exp() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
|
||||
expf(ar.values[5]), expf(ar.values[4]),
|
||||
expf(ar.values[3]), expf(ar.values[2]),
|
||||
expf(ar.values[1]), expf(ar.values[0])));
|
||||
}
|
||||
|
||||
FP32Vec8 tanh() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
|
||||
tanhf(ar.values[5]), tanhf(ar.values[4]),
|
||||
tanhf(ar.values[3]), tanhf(ar.values[2]),
|
||||
tanhf(ar.values[1]), tanhf(ar.values[0])));
|
||||
}
|
||||
|
||||
FP32Vec8 er() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
|
||||
erf(ar.values[5]), erf(ar.values[4]),
|
||||
erf(ar.values[3]), erf(ar.values[2]),
|
||||
erf(ar.values[1]), erf(ar.values[0])));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(_mm256_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(_mm256_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m512 reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__m512 reg;
|
||||
|
||||
explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
: reg((__m512)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1),
|
||||
(__m128i)data.reg, 2),
|
||||
(__m128i)data.reg, 3)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
: reg((__m512)_mm512_inserti32x8(
|
||||
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v)
|
||||
: reg(_mm512_castsi512_ps(
|
||||
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm512_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
|
||||
return _mm512_mask_reduce_add_ps(mask, reg);
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
|
||||
#ifdef __AVX512FP16__
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
|
||||
#endif
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
|
||||
#ifdef __AVX512FP16__
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<_Float16 *>(ptr) = v;
|
||||
}
|
||||
#endif
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
|
||||
}
|
||||
#else
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
: reg(_mm256_cvtepi32_epi16(
|
||||
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
: reg(_mm512_cvtepi32_epi16(
|
||||
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
|
||||
#endif
|
||||
|
||||
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
117
csrc_musa/cpu/layernorm.cpp
Normal file
117
csrc_musa/cpu/layernorm.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void rms_norm_impl(scalar_t *__restrict__ out,
|
||||
const scalar_t *__restrict__ input,
|
||||
const scalar_t *__restrict__ weight, const float epsilon,
|
||||
const int num_tokens, const int hidden_size) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
vec_op::FP32Vec8 variance(0.0);
|
||||
auto input_p = input + i * hidden_size;
|
||||
auto output_p = out + i * hidden_size;
|
||||
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||
scalar_vec_t x(input_p + j);
|
||||
vec_op::FP32Vec8 fp32_x(x);
|
||||
variance = variance + fp32_x * fp32_x;
|
||||
}
|
||||
|
||||
float s_variance =
|
||||
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
|
||||
vec_op::FP32Vec8 fp32_s_variance(s_variance);
|
||||
|
||||
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||
scalar_vec_t x(input_p + j);
|
||||
scalar_vec_t w(weight + j);
|
||||
|
||||
vec_op::FP32Vec8 fp32_x(x);
|
||||
vec_op::FP32Vec8 fp32_w(w);
|
||||
|
||||
vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;
|
||||
|
||||
scalar_vec_t out(fp32_out);
|
||||
out.save(output_p + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
|
||||
scalar_t *__restrict__ residual,
|
||||
const scalar_t *__restrict__ weight,
|
||||
const float epsilon, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
vec_op::FP32Vec8 variance(0.0);
|
||||
auto input_p = input + i * hidden_size;
|
||||
auto residual_p = residual + i * hidden_size;
|
||||
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||
scalar_vec_t x(input_p + j);
|
||||
scalar_vec_t res(residual_p + j);
|
||||
vec_op::FP32Vec8 fp32_x(x);
|
||||
vec_op::FP32Vec8 fp32_res(res);
|
||||
|
||||
fp32_x = fp32_x + fp32_res;
|
||||
variance = variance + fp32_x * fp32_x;
|
||||
scalar_vec_t out(fp32_x);
|
||||
out.save(residual_p + j);
|
||||
}
|
||||
|
||||
float s_variance =
|
||||
1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
|
||||
vec_op::FP32Vec8 fp32_s_variance(s_variance);
|
||||
|
||||
for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
|
||||
scalar_vec_t w(weight + j);
|
||||
scalar_vec_t res(residual_p + j);
|
||||
|
||||
vec_op::FP32Vec8 fp32_w(w);
|
||||
vec_op::FP32Vec8 fp32_res(res);
|
||||
|
||||
vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;
|
||||
|
||||
scalar_vec_t out(fp32_out);
|
||||
out.save(input_p + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void rms_norm(torch::Tensor &out, torch::Tensor &input,
|
||||
torch::Tensor &weight, float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(rms_norm_impl)
|
||||
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
CPU_KERNEL_GUARD_OUT(rms_norm_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
|
||||
torch::Tensor &weight, float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "fused_add_rms_norm_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
|
||||
fused_add_rms_norm_impl(
|
||||
input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
|
||||
CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
|
||||
});
|
||||
}
|
||||
199
csrc_musa/cpu/pos_encoding.cpp
Normal file
199
csrc_musa/cpu/pos_encoding.cpp
Normal file
@@ -0,0 +1,199 @@
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_impl(
|
||||
const int64_t
|
||||
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t
|
||||
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||
/// [num_tokens, num_heads, head_size]
|
||||
scalar_t
|
||||
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||
// [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t
|
||||
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size,
|
||||
const int num_tokens) {
|
||||
using scalar_vec_t = vec_op::vec_t<scalar_t>;
|
||||
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
|
||||
constexpr int ELEM_SIZE = sizeof(scalar_t);
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
for (int i = 0; i < num_heads; ++i) {
|
||||
const int head_idx = i;
|
||||
const int64_t token_head =
|
||||
token_idx * query_stride + head_idx * head_size;
|
||||
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int64_t out_x = token_head + x_index;
|
||||
const int64_t out_y = token_head + y_index;
|
||||
|
||||
const scalar_vec_t cos(cache_ptr + x_index);
|
||||
const scalar_vec_t sin(cache_ptr + y_index);
|
||||
|
||||
const scalar_vec_t q_x(query + out_x);
|
||||
const scalar_vec_t q_y(query + out_y);
|
||||
|
||||
vec_op::FP32Vec8 fp32_cos(cos);
|
||||
vec_op::FP32Vec8 fp32_sin(sin);
|
||||
|
||||
vec_op::FP32Vec8 fp32_q_x(q_x);
|
||||
vec_op::FP32Vec8 fp32_q_y(q_y);
|
||||
|
||||
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
|
||||
scalar_vec_t(out1).save(query + out_x);
|
||||
|
||||
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
|
||||
scalar_vec_t(out2).save(query + out_y);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
const int head_idx = i;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
|
||||
const int64_t out_x = token_head + x_index;
|
||||
const int64_t out_y = token_head + y_index;
|
||||
|
||||
const scalar_vec_t cos(cache_ptr + x_index);
|
||||
const scalar_vec_t sin(cache_ptr + y_index);
|
||||
|
||||
const scalar_vec_t k_x(key + out_x);
|
||||
const scalar_vec_t k_y(key + out_y);
|
||||
|
||||
vec_op::FP32Vec8 fp32_cos(cos);
|
||||
vec_op::FP32Vec8 fp32_sin(sin);
|
||||
|
||||
vec_op::FP32Vec8 fp32_k_x(k_x);
|
||||
vec_op::FP32Vec8 fp32_k_y(k_y);
|
||||
|
||||
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
|
||||
scalar_vec_t(out1).save(key + out_x);
|
||||
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
|
||||
scalar_vec_t(out2).save(key + out_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_gptj_impl(
|
||||
const int64_t
|
||||
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t
|
||||
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
|
||||
/// [num_tokens, num_heads, head_size]
|
||||
scalar_t
|
||||
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
|
||||
// [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t
|
||||
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int num_heads, const int num_kv_heads, const int head_size,
|
||||
const int num_tokens) {
|
||||
const int embed_dim = rot_dim / 2;
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int i = 0; i < num_heads; ++i) {
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t *cos_cache_ptr = cache_ptr;
|
||||
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
|
||||
const int head_idx = i;
|
||||
const int64_t token_head =
|
||||
token_idx * query_stride + head_idx * head_size;
|
||||
scalar_t *head_query = token_head + query;
|
||||
for (int j = 0; j < embed_dim; j += 1) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = 2 * rot_offset;
|
||||
const int y_index = 2 * rot_offset + 1;
|
||||
|
||||
const float cos = cos_cache_ptr[rot_offset];
|
||||
const float sin = sin_cache_ptr[rot_offset];
|
||||
|
||||
const float x = head_query[x_index];
|
||||
const float y = head_query[y_index];
|
||||
|
||||
head_query[x_index] = x * cos - y * sin;
|
||||
head_query[y_index] = y * cos + x * sin;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp parallel for collapse(2)
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
for (int i = 0; i < num_kv_heads; ++i) {
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
const scalar_t *cos_cache_ptr = cache_ptr;
|
||||
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
|
||||
const int head_idx = i;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
scalar_t *head_key = key + token_head;
|
||||
for (int j = 0; j < embed_dim; j += 1) {
|
||||
const int rot_offset = j;
|
||||
const int x_index = 2 * rot_offset;
|
||||
const int y_index = 2 * rot_offset + 1;
|
||||
|
||||
const float cos = cos_cache_ptr[rot_offset];
|
||||
const float sin = sin_cache_ptr[rot_offset];
|
||||
|
||||
const float x = head_key[x_index];
|
||||
const float y = head_key[y_index];
|
||||
|
||||
head_key[x_index] = x * cos - y * sin;
|
||||
head_key[y_index] = y * cos + x * sin;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
|
||||
torch::Tensor &key, int head_size,
|
||||
torch::Tensor &cos_sin_cache, bool is_neox) {
|
||||
int num_tokens = query.numel() / query.size(-1);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t key_stride = key.stride(-2);
|
||||
int64_t query_stride = query.stride(-2);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(), "rotary_embedding_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
|
||||
if (is_neox) {
|
||||
rotary_embedding_impl(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||
head_size, num_tokens);
|
||||
} else {
|
||||
rotary_embedding_gptj_impl(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
|
||||
head_size, num_tokens);
|
||||
}
|
||||
|
||||
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
|
||||
});
|
||||
}
|
||||
73
csrc_musa/cpu/pybind.cpp
Normal file
73
csrc_musa/cpu/pybind.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
#include "cache.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// vLLM custom ops
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
cache_ops.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
}
|
||||
148
csrc_musa/custom_all_reduce.mu
Normal file
148
csrc_musa/custom_all_reduce.mu
Normal file
@@ -0,0 +1,148 @@
|
||||
#include "torch_musa/csrc/core/MUSAException.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
#include "torch_musa/csrc/core/MUSAStream.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "custom_all_reduce.muh"
|
||||
|
||||
// fake pointer type
|
||||
using fptr_t = uint64_t;
|
||||
static_assert(sizeof(void *) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
musaIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(musaIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor &t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink) {
|
||||
auto inp_size = inp.numel() * inp.element_size();
|
||||
// custom allreduce requires input byte size to be multiples of 16
|
||||
if (inp_size % 16 != 0) return false;
|
||||
if (!_is_weak_contiguous(inp)) return false;
|
||||
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
|
||||
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
|
||||
// performance improvement over NCCL.
|
||||
return false;
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
musaStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
|
||||
reinterpret_cast<float *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
|
||||
reinterpret_cast<half *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__MUSA_ARCH__ >= 800 || !defined(__MUSA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<mt_bfloat16>(
|
||||
stream, reinterpret_cast<mt_bfloat16 *>(inp.data_ptr()),
|
||||
reinterpret_cast<mt_bfloat16 *>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::musa::getCurrentMUSAStream().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out) {
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::musa::getCurrentMUSAStream().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
C10_MUSA_CHECK(musaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, musaMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
return fa->get_graph_buffer_ipc_meta();
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
||||
485
csrc_musa/custom_all_reduce.muh
Normal file
485
csrc_musa/custom_all_reduce.muh
Normal file
@@ -0,0 +1,485 @@
|
||||
#pragma once
|
||||
|
||||
#include <musa.h>
|
||||
#include <musa_bf16.h>
|
||||
#include <musa_fp16.h>
|
||||
#include <musa_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
musaError_t e = cmd; \
|
||||
if (e != musaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
musaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
constexpr int kMaxBlocks = 64;
|
||||
// note: we don't want to use atomics for signals because peer atomics are no
|
||||
// supported on PCIe links
|
||||
struct Signal {
|
||||
alignas(128) uint32_t start[kMaxBlocks][8];
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; RankData& operator=(const RankData& ){return *this;} };
|
||||
|
||||
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half &assign_add(half &a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float &assign_add(float &a, float b) { return a += b; }
|
||||
|
||||
#if (__MUSA_ARCH__ >= 800 || !defined(__MUSA_ARCH__))
|
||||
DINLINE float upcast_s(mt_bfloat16 val) { return __bfloat162float(val); }
|
||||
template <>
|
||||
DINLINE mt_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE mt_bfloat16 &assign_add(mt_bfloat16 &a, mt_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
int rank) {
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
int rank) {
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
|
||||
volatile Signal *self_sg, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P *)result)[idx] =
|
||||
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P *get_tmp_buf(volatile Signal *sg) {
|
||||
return (P *)(((Signal *)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
|
||||
volatile Signal *self_sg, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P *ptrs[ngpus];
|
||||
P *tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P *)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P *)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(musaIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(musaIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(musaIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void *, RankData *> buffers_;
|
||||
Signal *self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void *> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char *> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Signal) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
|
||||
const musaIpcMemHandle_t *handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal *rank_sg;
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal *)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
sg_.signals[i] = rank_sg;
|
||||
}
|
||||
}
|
||||
|
||||
char *open_ipc_handle(const void *ipc_handle) {
|
||||
auto [it, new_handle] =
|
||||
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char *ipc_ptr;
|
||||
CUDACHECK(musaIpcOpenMemHandle((void **)&ipc_ptr,
|
||||
*((const musaIpcMemHandle_t *)ipc_handle),
|
||||
musaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(musaIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void *base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (muPointerGetAttribute(&base_ptr,
|
||||
MU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
(MUdeviceptr)ptr) != MUSA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(musaIpcGetMemHandle(
|
||||
(musaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " +
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, void *self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(
|
||||
musaMemcpy(d_data, &data, sizeof(RankData), musaMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto &rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char *handle =
|
||||
open_ipc_handle(&handles[j][i * sizeof(musaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(musaMemcpy(d_rank_data_base_, rank_data.data(),
|
||||
sizeof(RankData) * num_buffers,
|
||||
musaMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(musaStream_t stream, T *input, T *output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error("max supported block limit is " +
|
||||
std::to_string(kMaxBlocks) + ". Got " +
|
||||
std::to_string(block_limit));
|
||||
|
||||
RankData *ptrs;
|
||||
musaStreamCaptureStatus status;
|
||||
CUDACHECK(at::musa::musaStreamIsCapturing(stream, &status));
|
||||
if (status == musaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " +
|
||||
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(musaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void vllm::CustomAllreduce::allreduce<half>(musaStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
||||
316
csrc_musa/custom_all_reduce_test.mu
Normal file
316
csrc_musa/custom_all_reduce_test.mu
Normal file
@@ -0,0 +1,316 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=XXX
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
*
|
||||
* To run:
|
||||
* mpirun -np 8 ./custom_all_reduce_test
|
||||
*/
|
||||
#include <musa.h>
|
||||
#include <murand_kernel.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "musa_profiler_api.h"
|
||||
#include "custom_all_reduce.muh"
|
||||
#include "mpi.h"
|
||||
#include "nccl.h"
|
||||
|
||||
#define MPICHECK(cmd) \
|
||||
do { \
|
||||
int e = cmd; \
|
||||
if (e != MPI_SUCCESS) { \
|
||||
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
__global__ void dummy_kernel() {
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T *data, int size, int myRank) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
data[idx] = myRank * 0.11f;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||
double *fdata2, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
fdata1[idx] = data1[idx];
|
||||
fdata2[idx] = data2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||
int myRank, int nRanks, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
double sum = 0.0;
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
|
||||
T hval = val; // downcast first
|
||||
sum += static_cast<double>(hval);
|
||||
if (i == myRank) data[idx] = hval;
|
||||
}
|
||||
ground_truth[idx] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
int data_size, bool performance_test) {
|
||||
T *result;
|
||||
musaStream_t stream;
|
||||
CUDACHECK(musaStreamCreateWithFlags(&stream, musaStreamNonBlocking));
|
||||
CUDACHECK(musaMalloc(&result, data_size * sizeof(T)));
|
||||
CUDACHECK(musaMemset(result, 0, data_size * sizeof(T)));
|
||||
|
||||
musaIpcMemHandle_t self_data_handle;
|
||||
musaIpcMemHandle_t data_handles[8];
|
||||
vllm::Signal *buffer;
|
||||
T *self_data_copy;
|
||||
/**
|
||||
* Allocate IPC buffer
|
||||
*
|
||||
* The first section is a temporary buffer for storing intermediate allreduce
|
||||
* results, if a particular algorithm requires it. The second section is for
|
||||
* the input to the allreduce. The actual API takes the input pointer as an
|
||||
* argument (that is, they can and usually should be allocated separately).
|
||||
* But since the input pointers and the temporary buffer all require IPC
|
||||
* registration, they are allocated and registered together in the test for
|
||||
* convenience.
|
||||
*/
|
||||
CUDACHECK(
|
||||
musaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
CUDACHECK(
|
||||
musaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
CUDACHECK(musaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||
CUDACHECK(musaIpcGetMemHandle(&self_data_handle, buffer));
|
||||
|
||||
MPICHECK(MPI_Allgather(&self_data_handle, sizeof(musaIpcMemHandle_t),
|
||||
MPI_BYTE, data_handles, sizeof(musaIpcMemHandle_t),
|
||||
MPI_BYTE, MPI_COMM_WORLD));
|
||||
|
||||
void *rank_data;
|
||||
size_t rank_data_sz = 16 * 1024 * 1024;
|
||||
CUDACHECK(musaMalloc(&rank_data, rank_data_sz));
|
||||
std::vector<int64_t> offsets(nRanks, 0);
|
||||
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
|
||||
offsets, myRank);
|
||||
auto *self_data =
|
||||
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
|
||||
sizeof(vllm::Signal) + data_size * sizeof(T));
|
||||
// hack buffer registration
|
||||
{
|
||||
std::vector<std::string> handles;
|
||||
handles.reserve(nRanks);
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
char *begin = (char *)&data_handles[i];
|
||||
char *end = (char *)&data_handles[i + 1];
|
||||
handles.emplace_back(begin, end);
|
||||
}
|
||||
std::vector<int64_t> offsets(nRanks,
|
||||
sizeof(vllm::Signal) + data_size * sizeof(T));
|
||||
fa.register_buffer(handles, offsets, self_data);
|
||||
}
|
||||
|
||||
double *ground_truth;
|
||||
CUDACHECK(musaMallocHost(&ground_truth, data_size * sizeof(double)));
|
||||
curandState_t *states;
|
||||
CUDACHECK(musaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
|
||||
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
|
||||
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
|
||||
nRanks, data_size);
|
||||
CUDACHECK(musaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
|
||||
musaMemcpyDeviceToDevice, stream));
|
||||
musaEvent_t start, stop;
|
||||
CUDACHECK(musaEventCreate(&start));
|
||||
CUDACHECK(musaEventCreate(&stop));
|
||||
|
||||
ncclDataType_t ncclDtype;
|
||||
if (std::is_same<T, half>::value) {
|
||||
ncclDtype = ncclFloat16;
|
||||
} else if (std::is_same<T, mt_bfloat16>::value) {
|
||||
ncclDtype = ncclBfloat16;
|
||||
} else {
|
||||
ncclDtype = ncclFloat;
|
||||
}
|
||||
double *nccl_result, *my_result;
|
||||
CUDACHECK(musaMallocHost(&nccl_result, data_size * sizeof(double)));
|
||||
CUDACHECK(musaMallocHost(&my_result, data_size * sizeof(double)));
|
||||
if (performance_test) {
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
constexpr int warmup_iters = 5;
|
||||
constexpr int num_iters = 100;
|
||||
// warmup
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
|
||||
comm, stream));
|
||||
}
|
||||
CUDACHECK(musaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
|
||||
comm, stream));
|
||||
}
|
||||
CUDACHECK(musaEventRecord(stop, stream));
|
||||
CUDACHECK(musaStreamSynchronize(stream));
|
||||
float allreduce_ms = 0;
|
||||
musaEventElapsedTime(&allreduce_ms, start, stop);
|
||||
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
// warm up
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads,
|
||||
block_limit);
|
||||
}
|
||||
CUDACHECK(musaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads,
|
||||
block_limit);
|
||||
}
|
||||
CUDACHECK(musaEventRecord(stop, stream));
|
||||
CUDACHECK(musaStreamSynchronize(stream));
|
||||
|
||||
float duration_ms = 0;
|
||||
musaEventElapsedTime(&duration_ms, start, stop);
|
||||
if (myRank == 0)
|
||||
printf(
|
||||
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
|
||||
"time:%.2fus\n",
|
||||
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
|
||||
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
|
||||
|
||||
// And wait for all the queued up work to complete
|
||||
CUDACHECK(musaStreamSynchronize(stream));
|
||||
|
||||
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
|
||||
ncclSum, comm, stream));
|
||||
|
||||
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
|
||||
my_result, data_size);
|
||||
CUDACHECK(musaStreamSynchronize(stream));
|
||||
|
||||
for (unsigned long j = 0; j < data_size; j++) {
|
||||
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||
if (diff >= 4e-2) {
|
||||
printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
|
||||
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
long double nccl_diffs = 0.0;
|
||||
long double my_diffs = 0.0;
|
||||
for (int j = 0; j < data_size; j++) {
|
||||
nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||
my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||
}
|
||||
if (myRank == 0)
|
||||
std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||
<< " me: " << my_diffs / data_size << std::endl;
|
||||
} else {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads,
|
||||
block_limit);
|
||||
CUDACHECK(musaStreamSynchronize(stream));
|
||||
NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
|
||||
ncclSum, comm, stream));
|
||||
convert_data<T><<<108, 1024, 0, stream>>>(
|
||||
self_data_copy, result, nccl_result, my_result, data_size);
|
||||
CUDACHECK(musaStreamSynchronize(stream));
|
||||
|
||||
for (unsigned long j = 0; j < data_size; j++) {
|
||||
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||
if (diff >= 4e-2) {
|
||||
printf(
|
||||
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
|
||||
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (myRank == 0)
|
||||
printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks,
|
||||
data_size * sizeof(T) / 1024, threads, block_limit);
|
||||
// long double nccl_diffs = 0.0;
|
||||
// long double my_diffs = 0.0;
|
||||
// for (int j = 0; j < data_size; j++) {
|
||||
// nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||
// my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||
// }
|
||||
// if (myRank == 0)
|
||||
// std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||
// << " me: " << my_diffs / data_size << std::endl;
|
||||
}
|
||||
|
||||
CUDACHECK(musaFree(result));
|
||||
CUDACHECK(musaFree(self_data_copy));
|
||||
CUDACHECK(musaFree(rank_data));
|
||||
CUDACHECK(musaFree(buffer));
|
||||
CUDACHECK(musaFree(states));
|
||||
CUDACHECK(musaFreeHost(ground_truth));
|
||||
CUDACHECK(musaFreeHost(nccl_result));
|
||||
CUDACHECK(musaFreeHost(my_result));
|
||||
CUDACHECK(musaStreamDestroy(stream));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int nRanks, myRank;
|
||||
MPICHECK(MPI_Init(&argc, &argv));
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
|
||||
CUDACHECK(musaSetDevice(myRank));
|
||||
ncclUniqueId id;
|
||||
ncclComm_t comm;
|
||||
if (myRank == 0) ncclGetUniqueId(&id);
|
||||
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
|
||||
MPI_COMM_WORLD));
|
||||
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
|
||||
|
||||
bool performance_test = true;
|
||||
cudaProfilerStart();
|
||||
// for (int threads : {256, 512}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
|
||||
// }
|
||||
// }
|
||||
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
||||
}
|
||||
|
||||
cudaProfilerStop();
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
37
csrc_musa/dispatch_utils.h
Normal file
37
csrc_musa/dispatch_utils.h
Normal file
@@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
352
csrc_musa/layernorm_kernels.mu
Normal file
352
csrc_musa/layernorm_kernels.mu
Normal file
@@ -0,0 +1,352 @@
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.muh"
|
||||
#ifndef USE_ROCM
|
||||
#include <musa_bf16.h>
|
||||
#include <musa_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
using __mt_bfloat16 = __hip_bfloat16;
|
||||
using __mt_bfloat162 = __hip_bfloat162;
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template<typename scalar_t>
|
||||
__global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
const float x = (float) input[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float) input[blockIdx.x * hidden_size + idx];
|
||||
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Converter structs for the conversion from torch types to HIP/CUDA types,
|
||||
and the associated type conversions within HIP/CUDA. These helpers need
|
||||
to be implemented for now because the relevant type conversion
|
||||
operators/constructors are not consistently implemented by HIP/CUDA, so
|
||||
a generic conversion via type casts cannot be implemented.
|
||||
|
||||
Each struct should have the member static constexpr bool `exists`:
|
||||
If false, the optimized kernel is not used for the corresponding torch type.
|
||||
If true, the struct should be fully defined as shown in the examples below.
|
||||
*/
|
||||
template<typename torch_type>
|
||||
struct _typeConvert { static constexpr bool exists = false; };
|
||||
|
||||
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
// CUDA < 12.0 runs into issues with packed type conversion
|
||||
template<>
|
||||
struct _typeConvert<c10::Half> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __half;
|
||||
using packed_hip_type = __half2;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __half2float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
|
||||
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
|
||||
};
|
||||
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
|
||||
// CUDA_ARCH < 800 does not have BF16 support
|
||||
// TODO: Add in ROCm support once public headers handle bf16 maturely
|
||||
template<>
|
||||
struct _typeConvert<c10::BFloat16> {
|
||||
static constexpr bool exists = true;
|
||||
using hip_type = __mt_bfloat16;
|
||||
using packed_hip_type = __mt_bfloat162;
|
||||
|
||||
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
|
||||
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
|
||||
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
|
||||
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
|
||||
};
|
||||
#endif // defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
|
||||
|
||||
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
|
||||
for appropriate specializations of fused_add_rms_norm_kernel.
|
||||
Only functions that are necessary in that kernel are implemented.
|
||||
Alignment to 16 bytes is required to use 128-bit global memory ops.
|
||||
*/
|
||||
template<typename scalar_t, int width>
|
||||
struct alignas(16) _f16Vec {
|
||||
/* Not theoretically necessary that width is a power of 2 but should
|
||||
almost always be the case for optimization purposes */
|
||||
static_assert(width > 0 && (width & (width - 1)) == 0,
|
||||
"Width is not a positive power of 2!");
|
||||
using Converter = _typeConvert<scalar_t>;
|
||||
using T1 = typename Converter::hip_type;
|
||||
using T2 = typename Converter::packed_hip_type;
|
||||
T1 data[width];
|
||||
|
||||
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
T2 temp{data[i], data[i+1]};
|
||||
temp += T2{other.data[i], other.data[i+1]};
|
||||
data[i] = temp.x;
|
||||
data[i+1] = temp.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i)
|
||||
data[i] += other.data[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
T2 temp{data[i], data[i+1]};
|
||||
temp *= T2{other.data[i], other.data[i+1]};
|
||||
data[i] = temp.x;
|
||||
data[i+1] = temp.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i)
|
||||
data[i] *= other.data[i];
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ _f16Vec& operator*=(const float scale) {
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
|
||||
temp_f.x *= scale;
|
||||
temp_f.y *= scale;
|
||||
T2 temp = Converter::convert(temp_f);
|
||||
data[i] = temp.x;
|
||||
data[i+1] = temp.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
float temp = Converter::convert(data[i]) * scale;
|
||||
data[i] = Converter::convert(temp);
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
__device__ float sum_squares() const {
|
||||
float result = 0.0f;
|
||||
if constexpr (width % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i+1]});
|
||||
result += z.x * z.x + z.y * z.y;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
float x = Converter::convert(data[i]);
|
||||
result += x * x;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/* Function specialization in the case of FP16/BF16 tensors.
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck. */
|
||||
template<typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<
|
||||
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
not aliased in practice. Argument pointers should not be dereferenced
|
||||
in this kernel as that would be undefined behavior */
|
||||
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
|
||||
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
|
||||
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[id];
|
||||
temp += residual_v[id];
|
||||
variance += temp.sum_squares();
|
||||
residual_v[id] = temp;
|
||||
}
|
||||
/* Keep the following if-else block in sync with the
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else variance = blockReduceSum<float, 256>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = residual_v[id];
|
||||
temp *= s_variance;
|
||||
temp *= weight_v[idx];
|
||||
input_v[id] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Generic fused_add_rms_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template<typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<
|
||||
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||
z += residual[blockIdx.x * hidden_size + idx];
|
||||
float x = (float) z;
|
||||
variance += x * x;
|
||||
residual[blockIdx.x * hidden_size + idx] = z;
|
||||
}
|
||||
/* Keep the following if-else block in sync with the
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else variance = blockReduceSum<float, 256>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float) residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"rms_norm_kernel",
|
||||
[&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
epsilon,
|
||||
num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"fused_add_rms_norm_kernel", \
|
||||
[&] { \
|
||||
vllm::fused_add_rms_norm_kernel \
|
||||
<scalar_t, width><<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), \
|
||||
epsilon, \
|
||||
num_tokens, \
|
||||
hidden_size); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
/* This kernel is memory-latency bound in many scenarios.
|
||||
When num_tokens is large, a smaller block size allows
|
||||
for increased block occupancy on CUs and better latency
|
||||
hiding on global mem ops. */
|
||||
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||
dim3 block(std::min(hidden_size, max_block_size));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
/*If the tensor types are FP16/BF16, try to use the optimized kernel
|
||||
with packed + vectorized ops.
|
||||
Max optimization is achieved with a width-8 vector of FP16/BF16s
|
||||
since we can load at most 128 bits at once in a global memory op.
|
||||
However, this requires each tensor's data to be aligned to 16
|
||||
bytes.
|
||||
*/
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
|
||||
&& wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
}
|
||||
}
|
||||
7
csrc_musa/moe/moe_ops.cpp
Normal file
7
csrc_musa/moe/moe_ops.cpp
Normal file
@@ -0,0 +1,7 @@
|
||||
#include "moe_ops.h"
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
|
||||
}
|
||||
9
csrc_musa/moe/moe_ops.h
Normal file
9
csrc_musa/moe/moe_ops.h
Normal file
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
||||
500
csrc_musa/moe/topk_softmax_kernels.mu
Normal file
500
csrc_musa/moe/topk_softmax_kernels.mu
Normal file
@@ -0,0 +1,500 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2024, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
/// Number of elements in the array
|
||||
int N,
|
||||
/// Alignment requirement in bytes
|
||||
int Alignment = sizeof(T) * N
|
||||
>
|
||||
class alignas(Alignment) AlignedArray {
|
||||
float data[N];
|
||||
};
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing the output
|
||||
// in the softmax kernel when we extend this module to support expert-choice routing.
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
|
||||
{
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
if ((finished != nullptr) && finished[blockIdx.x])
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
|
||||
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int num_rows = gridDim.x;
|
||||
const int block_row = blockIdx.x;
|
||||
|
||||
const bool row_is_active = finished ? !finished[block_row] : true;
|
||||
const int thread_read_offset = blockIdx.x * num_experts;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
|
||||
{
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
|
||||
{
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert)
|
||||
{
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
// Ignore experts the node isn't responsible for with expert parallelism
|
||||
const int expert = result_kvp.key;
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = result_kvp.value;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
||||
assert(indices[idx] >= 0);
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== TopK softmax things ===============================
|
||||
|
||||
/*
|
||||
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
|
||||
are a small power of 2. This allows us to cleanly share the rows among the threads in
|
||||
a single warp and eliminate communication between warps (so no need to use shared mem).
|
||||
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is intended for when the number of experts is a small power of 2.
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
// Restrictions for previous section.
|
||||
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
|
||||
|
||||
// ===================== From this point, we finally start computing run-time variables. ========================
|
||||
|
||||
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
|
||||
// This, each block processes a chunk of rows. We start by computing the start row for each block.
|
||||
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||
|
||||
// Now, using the base row per thread block, we compute the base row per warp.
|
||||
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||
|
||||
// The threads in a warp are split into sub-groups that will work on a row.
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows)
|
||||
{
|
||||
return;
|
||||
}
|
||||
const bool row_is_active = finished ? !finished[thread_row] : true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
||||
// row it will read.
|
||||
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
|
||||
// this can support all powers of 2 up to 16.
|
||||
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
|
||||
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
|
||||
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
float row_chunk[VPT];
|
||||
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
|
||||
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
|
||||
{
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
|
||||
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
||||
// convert to float afterwards for the exp + sum reduction.
|
||||
float thread_max = row_chunk[0];
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < VPT; ++ii)
|
||||
{
|
||||
thread_max = max(thread_max, row_chunk[ii]);
|
||||
}
|
||||
|
||||
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread max in all the threads have the max within the row.
|
||||
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
|
||||
float row_sum = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
|
||||
row_sum += row_chunk[ii];
|
||||
}
|
||||
|
||||
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
||||
}
|
||||
|
||||
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
||||
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
|
||||
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
|
||||
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
|
||||
// argmax after computing the softmax.
|
||||
const float reciprocal_row_sum = 1.f / row_sum;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||
}
|
||||
|
||||
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
|
||||
// with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
|
||||
{
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index are processed first and only
|
||||
// updated if > (not >=)
|
||||
if (val > max_val)
|
||||
{
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
|
||||
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
|
||||
// then blank out their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this way
|
||||
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||
{
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0)
|
||||
{
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
|
||||
// single) thread per row of the input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
||||
if (k_idx + 1 < k)
|
||||
{
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group)
|
||||
{
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, musaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
||||
gating_output, nullptr, topk_weights, topk_indicies, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||
stream);
|
||||
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
const float* gating_output,
|
||||
float* topk_weights,
|
||||
int* topk_indicies,
|
||||
int* token_expert_indices,
|
||||
float* softmax_workspace,
|
||||
const int num_tokens,
|
||||
const int num_experts,
|
||||
const int topk,
|
||||
musaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
||||
break;
|
||||
default: {
|
||||
TORCH_CHECK(softmax_workspace != nullptr,
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
||||
static constexpr int TPB = 256;
|
||||
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
gating_output, nullptr, softmax_workspace, num_experts);
|
||||
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
|
||||
num_experts, topk, 0, num_experts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights, // [num_tokens, topk]
|
||||
torch::Tensor& topk_indices, // [num_tokens, topk]
|
||||
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
||||
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
|
||||
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(gating_output));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
}
|
||||
125
csrc_musa/moe_align_block_size_kernels.mu
Normal file
125
csrc_musa/moe_align_block_size_kernels.mu
Normal file
@@ -0,0 +1,125 @@
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCAtomics.muh>
|
||||
|
||||
#include "musa_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
|
||||
namespace {
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
|
||||
// don't worry about overflow because num_experts is relatively small
|
||||
return row * total_col + col;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
int32_t *sorted_token_ids,
|
||||
int32_t *expert_ids,
|
||||
int32_t *total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are assigned
|
||||
* to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// We accumulate the token counts of all experts in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding blocks
|
||||
* and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token after
|
||||
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
|
||||
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
|
||||
* where * represents a padding value(preset in python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
|
||||
* stores the indices of the tokens processed by the expert with expert_id within
|
||||
* the current thread's token shard.
|
||||
*/
|
||||
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
|
||||
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
|
||||
AT_MUSA_CHECK(
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
|
||||
kernel<<<1, num_experts, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
}
|
||||
38
csrc_musa/musa_compat.h
Normal file
38
csrc_musa/musa_compat.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_LDG(arg) __ldg(arg)
|
||||
#else
|
||||
#define VLLM_LDG(arg) *(arg)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
||||
#else
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
||||
#else
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
musaFuncSetAttribute(FUNC, musaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#else
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
||||
10
csrc_musa/musa_utils.h
Normal file
10
csrc_musa/musa_utils.h
Normal file
@@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id);
|
||||
35
csrc_musa/musa_utils_kernels.mu
Normal file
35
csrc_musa/musa_utils_kernels.mu
Normal file
@@ -0,0 +1,35 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#endif
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
{
|
||||
int device, value;
|
||||
if (device_id < 0) {
|
||||
musaGetDevice(&device);
|
||||
}
|
||||
else {
|
||||
device = device_id;
|
||||
}
|
||||
musaDeviceGetAttribute(&value, static_cast<musaDeviceAttr>(attribute), device);
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id)
|
||||
{
|
||||
int attribute;
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||
|
||||
#ifdef USE_ROCM
|
||||
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
|
||||
#else
|
||||
attribute = musaDevAttrMaxSharedMemoryPerBlockOptin;
|
||||
#endif
|
||||
|
||||
return get_device_attribute(attribute, device_id);
|
||||
}
|
||||
206
csrc_musa/ops.h
Normal file
206
csrc_musa/ops.h
Normal file
@@ -0,0 +1,206 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& seq_lens,
|
||||
int block_size,
|
||||
int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype,
|
||||
float kv_scale);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
);
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes
|
||||
);
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy);
|
||||
|
||||
torch::Tensor marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor &a,
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &g_idx,
|
||||
torch::Tensor &perm,
|
||||
torch::Tensor &workspace,
|
||||
int64_t num_bits,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &perm,
|
||||
int64_t size_k,
|
||||
int64_t size_n,
|
||||
int64_t num_bits);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama,
|
||||
int bit);
|
||||
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
int bit);
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = uint64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink);
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out);
|
||||
void dispose(fptr_t _fa);
|
||||
int meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets);
|
||||
#endif
|
||||
226
csrc_musa/pos_encoding_kernels.mu
Normal file
226
csrc_musa/pos_encoding_kernels.mu
Normal file
@@ -0,0 +1,226 @@
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include "musa_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_token_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim)
|
||||
{
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = VLLM_LDG(cos_ptr + x_index);
|
||||
sin = VLLM_LDG(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
const scalar_t y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* cache_ptr,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int rot_dim,
|
||||
const int token_idx,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride)
|
||||
{
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void batched_rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.numel() / query.size(-1);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
Batched version of rotary embedding, pack multiple LoRAs together
|
||||
and process in batched manner.
|
||||
*/
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
||||
) {
|
||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(query));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
if (is_neox) {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
217
csrc_musa/punica/.LICENSE
Normal file
217
csrc_musa/punica/.LICENSE
Normal file
@@ -0,0 +1,217 @@
|
||||
Contains code from https://github.com/punica-ai/punica
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
|
||||
|
||||
Apache-2.0
|
||||
* third_party/nvbench (with LLVM exception)
|
||||
* third_party/flashinfer
|
||||
|
||||
BSD-3-Clause:
|
||||
* third_party/cutlass
|
||||
5
csrc_musa/punica/bgmv/bgmv_bf16_bf16_bf16.mu
Normal file
5
csrc_musa/punica/bgmv/bgmv_bf16_bf16_bf16.mu
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, mt_bfloat16, mt_bfloat16, mt_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, mt_bfloat16, mt_bfloat16, mt_bfloat16)
|
||||
5
csrc_musa/punica/bgmv/bgmv_bf16_fp32_bf16.mu
Normal file
5
csrc_musa/punica/bgmv/bgmv_bf16_fp32_bf16.mu
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, mt_bfloat16, float, mt_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, mt_bfloat16, float, mt_bfloat16)
|
||||
162
csrc_musa/punica/bgmv/bgmv_config.h
Normal file
162
csrc_musa/punica/bgmv/bgmv_config.h
Normal file
@@ -0,0 +1,162 @@
|
||||
#pragma once
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale);
|
||||
|
||||
// clang-format off
|
||||
|
||||
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 640) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
f(in_T, out_T, W_T, narrow, 1280) \
|
||||
f(in_T, out_T, W_T, narrow, 1536) \
|
||||
f(in_T, out_T, W_T, narrow, 1728) \
|
||||
f(in_T, out_T, W_T, narrow, 1792) \
|
||||
f(in_T, out_T, W_T, narrow, 2048) \
|
||||
f(in_T, out_T, W_T, narrow, 2304) \
|
||||
f(in_T, out_T, W_T, narrow, 2560) \
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 2816) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
f(in_T, out_T, W_T, narrow, 4608) \
|
||||
f(in_T, out_T, W_T, narrow, 5120) \
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 6144) \
|
||||
f(in_T, out_T, W_T, narrow, 6848) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
f(in_T, out_T, W_T, narrow, 8192) \
|
||||
f(in_T, out_T, W_T, narrow, 9216) \
|
||||
f(in_T, out_T, W_T, narrow, 10240) \
|
||||
f(in_T, out_T, W_T, narrow, 11008) \
|
||||
f(in_T, out_T, W_T, narrow, 12288) \
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 15360) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
f(in_T, out_T, W_T, narrow, 24576) \
|
||||
f(in_T, out_T, W_T, narrow, 27392) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
f(in_T, out_T, W_T, narrow, 32512) \
|
||||
f(in_T, out_T, W_T, narrow, 32768) \
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 43264) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
f(in_T, out_T, W_T, narrow, 64000) \
|
||||
f(in_T, out_T, W_T, narrow, 64256) \
|
||||
f(in_T, out_T, W_T, narrow, 64512) \
|
||||
f(in_T, out_T, W_T, narrow, 102400) \
|
||||
f(in_T, out_T, W_T, narrow, 102656) \
|
||||
f(in_T, out_T, W_T, narrow, 102912) \
|
||||
f(in_T, out_T, W_T, narrow, 128000) \
|
||||
f(in_T, out_T, W_T, narrow, 128256) \
|
||||
f(in_T, out_T, W_T, narrow, 128512) \
|
||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||
// and vllm/tests/lora/test_punica.py
|
||||
|
||||
// Used for defining kernels going from the variety of
|
||||
// dim in to the narrow dim out
|
||||
// Using it for the fully sharded column
|
||||
// parallel LoRA A which splits the rank dim
|
||||
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, 128, narrow) \
|
||||
f(in_T, out_T, W_T, 256, narrow) \
|
||||
f(in_T, out_T, W_T, 512, narrow) \
|
||||
f(in_T, out_T, W_T, 640, narrow) \
|
||||
f(in_T, out_T, W_T, 768, narrow) \
|
||||
f(in_T, out_T, W_T, 1024, narrow) \
|
||||
f(in_T, out_T, W_T, 1152, narrow) \
|
||||
f(in_T, out_T, W_T, 1280, narrow) \
|
||||
f(in_T, out_T, W_T, 1536, narrow) \
|
||||
f(in_T, out_T, W_T, 1728, narrow) \
|
||||
f(in_T, out_T, W_T, 1792, narrow) \
|
||||
f(in_T, out_T, W_T, 2048, narrow) \
|
||||
f(in_T, out_T, W_T, 2304, narrow) \
|
||||
f(in_T, out_T, W_T, 2560, narrow) \
|
||||
f(in_T, out_T, W_T, 2752, narrow) \
|
||||
f(in_T, out_T, W_T, 2816, narrow) \
|
||||
f(in_T, out_T, W_T, 3072, narrow) \
|
||||
f(in_T, out_T, W_T, 3456, narrow) \
|
||||
f(in_T, out_T, W_T, 3584, narrow) \
|
||||
f(in_T, out_T, W_T, 4096, narrow) \
|
||||
f(in_T, out_T, W_T, 4608, narrow) \
|
||||
f(in_T, out_T, W_T, 5120, narrow) \
|
||||
f(in_T, out_T, W_T, 5504, narrow) \
|
||||
f(in_T, out_T, W_T, 5632, narrow) \
|
||||
f(in_T, out_T, W_T, 6144, narrow) \
|
||||
f(in_T, out_T, W_T, 6848, narrow) \
|
||||
f(in_T, out_T, W_T, 6912, narrow) \
|
||||
f(in_T, out_T, W_T, 7168, narrow) \
|
||||
f(in_T, out_T, W_T, 8192, narrow) \
|
||||
f(in_T, out_T, W_T, 9216, narrow) \
|
||||
f(in_T, out_T, W_T, 10240, narrow) \
|
||||
f(in_T, out_T, W_T, 11008, narrow) \
|
||||
f(in_T, out_T, W_T, 12288, narrow) \
|
||||
f(in_T, out_T, W_T, 13696, narrow) \
|
||||
f(in_T, out_T, W_T, 13824, narrow) \
|
||||
f(in_T, out_T, W_T, 14336, narrow) \
|
||||
f(in_T, out_T, W_T, 15360, narrow) \
|
||||
f(in_T, out_T, W_T, 16384, narrow) \
|
||||
f(in_T, out_T, W_T, 20480, narrow) \
|
||||
f(in_T, out_T, W_T, 22016, narrow) \
|
||||
f(in_T, out_T, W_T, 24576, narrow) \
|
||||
f(in_T, out_T, W_T, 27392, narrow) \
|
||||
f(in_T, out_T, W_T, 28672, narrow) \
|
||||
f(in_T, out_T, W_T, 32000, narrow) \
|
||||
f(in_T, out_T, W_T, 32256, narrow) \
|
||||
f(in_T, out_T, W_T, 32512, narrow) \
|
||||
f(in_T, out_T, W_T, 32768, narrow) \
|
||||
f(in_T, out_T, W_T, 33024, narrow) \
|
||||
f(in_T, out_T, W_T, 36864, narrow) \
|
||||
f(in_T, out_T, W_T, 43264, narrow) \
|
||||
f(in_T, out_T, W_T, 49152, narrow) \
|
||||
f(in_T, out_T, W_T, 64000, narrow) \
|
||||
f(in_T, out_T, W_T, 64256, narrow) \
|
||||
f(in_T, out_T, W_T, 64512, narrow) \
|
||||
f(in_T, out_T, W_T, 102400, narrow) \
|
||||
f(in_T, out_T, W_T, 102656, narrow) \
|
||||
f(in_T, out_T, W_T, 102912, narrow) \
|
||||
f(in_T, out_T, W_T, 128000, narrow) \
|
||||
f(in_T, out_T, W_T, 128256, narrow) \
|
||||
f(in_T, out_T, W_T, 128512, narrow) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||
|
||||
|
||||
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
|
||||
f(in_T, out_T, W_T, 8, 64) \
|
||||
f(in_T, out_T, W_T, 16, 64) \
|
||||
f(in_T, out_T, W_T, 32, 64) \
|
||||
f(in_T, out_T, W_T, 64, 64)
|
||||
|
||||
// clang-format on
|
||||
5
csrc_musa/punica/bgmv/bgmv_fp16_fp16_fp16.mu
Normal file
5
csrc_musa/punica/bgmv/bgmv_fp16_fp16_fp16.mu
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
|
||||
5
csrc_musa/punica/bgmv/bgmv_fp16_fp32_fp16.mu
Normal file
5
csrc_musa/punica/bgmv/bgmv_fp16_fp32_fp16.mu
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
|
||||
5
csrc_musa/punica/bgmv/bgmv_fp32_bf16_bf16.mu
Normal file
5
csrc_musa/punica/bgmv/bgmv_fp32_bf16_bf16.mu
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, mt_bfloat16, mt_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, mt_bfloat16, mt_bfloat16)
|
||||
5
csrc_musa/punica/bgmv/bgmv_fp32_fp16_fp16.mu
Normal file
5
csrc_musa/punica/bgmv/bgmv_fp32_fp16_fp16.mu
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
|
||||
297
csrc_musa/punica/bgmv/bgmv_impl.muh
Normal file
297
csrc_musa/punica/bgmv/bgmv_impl.muh
Normal file
@@ -0,0 +1,297 @@
|
||||
#pragma once
|
||||
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include <cooperative_groups.h>
|
||||
#include <cuda/pipeline>
|
||||
#include <musa_runtime.h>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "vec_dtypes.cuh"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// nthrs = (32, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t num_pipeline_stages = 2;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
auto pipe = cuda::make_pipeline();
|
||||
|
||||
// pipeline load W/X and compute WX;
|
||||
pipe.producer_acquire();
|
||||
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
pipe.producer_commit();
|
||||
size_t copy_idx, compute_idx;
|
||||
float y = 0.f;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
|
||||
++tile_idx) {
|
||||
copy_idx = tile_idx % num_pipeline_stages;
|
||||
// pipeline stage: async copy W fragment
|
||||
pipe.producer_acquire();
|
||||
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
|
||||
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) + tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
}
|
||||
pipe.producer_commit();
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// pipeline stage: compute WX
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] = sum;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
}
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// final pipeline stage
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] =
|
||||
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
|
||||
? sum
|
||||
: 0.f;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
|
||||
// write Y;
|
||||
if (block.thread_rank() == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
|
||||
}
|
||||
}
|
||||
|
||||
// nthrs = (2, 16, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||
typename in_T, typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t tile_idx = blockIdx.x;
|
||||
|
||||
// load X;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
|
||||
|
||||
// load W;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
|
||||
block.thread_rank() * vec_size);
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
|
||||
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += g.shfl_down(sum, offset);
|
||||
}
|
||||
sum = g.shfl(sum, 0);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||
}
|
||||
}
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
constexpr size_t vec_size = 8;
|
||||
constexpr int tz = 4;
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
|
||||
if constexpr (feat_in <= feat_out) {
|
||||
static_assert(feat_in % vec_size == 0);
|
||||
constexpr int tx = feat_in / vec_size;
|
||||
|
||||
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
|
||||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
|
||||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
|
||||
|
||||
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
|
||||
constexpr int ty = 32 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
|
||||
constexpr int ty = 16 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else {
|
||||
constexpr int ty = 8 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
} else {
|
||||
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||
feat_in % (vec_size * 16) == 0 ||
|
||||
feat_in % (vec_size * 8) == 0);
|
||||
|
||||
if constexpr (feat_in % (vec_size * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
|
||||
vec_size * sizeof(W_T), tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
|
||||
constexpr int tx = 16;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
|
||||
template void bgmv_kernel<feat_in, feat_out>( \
|
||||
out_T * __restrict__ Y, const in_T *__restrict__ X, \
|
||||
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
|
||||
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
|
||||
int64_t num_layers, int64_t layer_idx, float scale);
|
||||
|
||||
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
|
||||
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
|
||||
|
||||
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
|
||||
INST_BGMV(wide, narrow, in_T, out_T, W_T)
|
||||
48
csrc_musa/punica/bgmv/generator.py
Normal file
48
csrc_musa/punica/bgmv/generator.py
Normal file
@@ -0,0 +1,48 @@
|
||||
DTYPES = ["fp16", "bf16", "fp32"]
|
||||
DTYPE_MAP = {
|
||||
"fp16": "nv_half",
|
||||
"bf16": "mt_bfloat16",
|
||||
"fp32": "float",
|
||||
}
|
||||
|
||||
TEMPLATE = """
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
for input_dtype in DTYPES:
|
||||
for output_dtype in DTYPES:
|
||||
for weight_dtype in DTYPES:
|
||||
if weight_dtype == "fp32":
|
||||
# FP32 weights are not supported.
|
||||
continue
|
||||
if output_dtype == "fp32":
|
||||
# LoRA A matrix.
|
||||
if input_dtype != weight_dtype:
|
||||
# NOTE(woosuk): While Punica supports the case where the
|
||||
# input and weight dtypes are different, we only generate
|
||||
# the kernels the same dtypes to reduce the binary size.
|
||||
continue
|
||||
elif input_dtype == "fp32":
|
||||
# LoRA B matrix.
|
||||
if output_dtype != weight_dtype:
|
||||
# NOTE(woosuk): While Punica supports the case where the
|
||||
# output and weight dtypes are different, we only generate
|
||||
# the kernels the same dtypes to reduce the binary size.
|
||||
continue
|
||||
elif not (input_dtype == output_dtype == weight_dtype):
|
||||
# NOTE(woosuk): While Punica supports mixed data types for
|
||||
# input, output, and weight, we only generate the kernels with
|
||||
# the same data types to reduce the binary size.
|
||||
continue
|
||||
|
||||
kernel_definition = TEMPLATE.format(
|
||||
input_dtype=DTYPE_MAP[input_dtype],
|
||||
output_dtype=DTYPE_MAP[output_dtype],
|
||||
weight_dtype=DTYPE_MAP[weight_dtype])
|
||||
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
|
||||
with open(filename, "w") as f:
|
||||
f.write(kernel_definition)
|
||||
1324
csrc_musa/punica/bgmv/vec_dtypes.muh
Normal file
1324
csrc_musa/punica/bgmv/vec_dtypes.muh
Normal file
File diff suppressed because it is too large
Load Diff
582
csrc_musa/punica/punica_ops.cc
Normal file
582
csrc_musa/punica/punica_ops.cc
Normal file
@@ -0,0 +1,582 @@
|
||||
#include <musa_bf16.h>
|
||||
#include <musa_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
#include <cstdint>
|
||||
|
||||
#include "bgmv/bgmv_config.h"
|
||||
|
||||
namespace {
|
||||
|
||||
//====== utils ======
|
||||
|
||||
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
const char *a_name, const char *b_name) {
|
||||
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
|
||||
a.dim(), " vs ", b.dim());
|
||||
for (int i = 0; i < a.dim(); ++i) {
|
||||
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
|
||||
".size(", i, ")");
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
|
||||
return (uint64_t(a) << 32) | uint64_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) \
|
||||
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) \
|
||||
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
//====== bgmv ======
|
||||
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint32_t in_features, uint32_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
// NOTE(woosuk): While Punica supports various combinations of input/output
|
||||
// data types, we limit the supported data types to reduce the binary size.
|
||||
constexpr bool is_input_float = std::is_same<in_T, float>::value;
|
||||
constexpr bool is_output_float = std::is_same<out_T, float>::value;
|
||||
if (is_input_float) {
|
||||
if (!std::is_same<out_T, W_T>::value) {
|
||||
return false;
|
||||
}
|
||||
} else if (is_output_float) {
|
||||
if (!std::is_same<in_T, W_T>::value) {
|
||||
return false;
|
||||
}
|
||||
} else if (!(std::is_same<in_T, W_T>::value &&
|
||||
std::is_same<out_T, W_T>::value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (pack_u32(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u32(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
break;
|
||||
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
|
||||
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
|
||||
#undef CASE
|
||||
#undef CASE_ONESIDE
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, float scale) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t h_in = x.size(1);
|
||||
int64_t h_out = y.size(1);
|
||||
int64_t num_layers = w.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
float scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t num_layers = w.size(1);
|
||||
int64_t full_y_size = y.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<mt_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<mt_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//====== pybind ======
|
||||
|
||||
#define DEFINE_pybind(name) m.def(#name, &name, #name);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
|
||||
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
|
||||
"dispatch_bgmv_low_level");
|
||||
}
|
||||
136
csrc_musa/pybind.cpp
Normal file
136
csrc_musa/pybind.cpp
Normal file
@@ -0,0 +1,136 @@
|
||||
#include "cache.h"
|
||||
#include "musa_utils.h"
|
||||
#include "ops.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// vLLM custom ops
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
ops.def(
|
||||
"batched_rotary_embedding",
|
||||
&batched_rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
|
||||
// ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
|
||||
// ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
// ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
||||
// ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
// ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
|
||||
// ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
#endif
|
||||
|
||||
// ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
// ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
// ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
// ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
|
||||
// ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
|
||||
// ops.def(
|
||||
// "moe_align_block_size",
|
||||
// &moe_align_block_size,
|
||||
// "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
cache_ops.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache_flash",
|
||||
&reshape_and_cache_flash,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"convert_fp8",
|
||||
&convert_fp8,
|
||||
"Convert the key and value cache to fp8 data type");
|
||||
|
||||
// Cuda utils
|
||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Custom all-reduce kernels
|
||||
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
||||
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
||||
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
||||
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
||||
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
||||
custom_ar.def("dispose", &dispose, "dispose");
|
||||
custom_ar.def("meta_size", &meta_size, "meta_size");
|
||||
custom_ar.def("register_buffer", ®ister_buffer, "register_buffer");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
||||
"get_graph_buffer_ipc_meta");
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
||||
712
csrc_musa/quantization/aqlm/gemm_kernels.mu
Normal file
712
csrc_musa/quantization/aqlm/gemm_kernels.mu
Normal file
@@ -0,0 +1,712 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Adapted from https://github.com/Vahe1994/AQLM
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <musa.h>
|
||||
#include <musa_fp16.h>
|
||||
#include <musa_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/core/MUSAStream.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
|
||||
namespace vllm {
|
||||
namespace aqlm {
|
||||
|
||||
__global__ void Code1x16MatVec(
|
||||
const int4* __restrict__ A,
|
||||
const int4* __restrict__ B,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
const int prob_m,
|
||||
const int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
int b_gl_rd = 0;
|
||||
int c_gl_wr = a_gl_rd;
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
|
||||
__shared__ int4 sh_b[32 * 9];
|
||||
float res = 0;
|
||||
|
||||
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
|
||||
while (iters--) {
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8)
|
||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
|
||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint32_t dec[4];
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||
// actually help us; this brings > 2x speedup.
|
||||
asm volatile (
|
||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*) &codebook[enc[i]])
|
||||
);
|
||||
half2* a = reinterpret_cast<half2*>(&dec);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
res2 = __hfma2(a[j], b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
b_sh_rd++;
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2)
|
||||
res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void Code2x8MatVec(
|
||||
const int4* __restrict__ A,
|
||||
const int4* __restrict__ B,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
int b_gl_rd = 0;
|
||||
int c_gl_wr = a_gl_rd;
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
int lane = threadIdx.x % 8;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
int4* sh_b = sh;
|
||||
int4* sh_code = sh_b + 32 * 9;
|
||||
int4* sh_code0 = sh_code;
|
||||
int4* sh_code1 = sh_code + 256 * 8;
|
||||
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++)
|
||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float res = 0;
|
||||
|
||||
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
|
||||
while (iters--) {
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8)
|
||||
sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
|
||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
b_sh_rd++;
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2)
|
||||
res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void Code1x16Dequant(
|
||||
const int4* __restrict__ A,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
|
||||
int c_gl_stride = prob_k / 8;
|
||||
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
|
||||
|
||||
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
|
||||
// actually help us; this brings > 2x speedup.
|
||||
asm volatile (
|
||||
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*) &codebook[enc[i]])
|
||||
);
|
||||
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
}
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void Code2x8Dequant(
|
||||
const int4* __restrict__ A,
|
||||
int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred)
|
||||
{
|
||||
// advance to the correct codebook, this easy because we only multiply one column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size)
|
||||
{
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
int lane = threadIdx.x % 8;
|
||||
|
||||
int c_gl_stride = prob_k / 8;
|
||||
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
int4* sh_code = sh;
|
||||
int4* sh_code0 = sh_code;
|
||||
int4* sh_code1 = sh_code + 256 * 8;
|
||||
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++)
|
||||
sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float res = 0;
|
||||
|
||||
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
}
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
inline int ceildiv(int a, int b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
const int THREAD_M = 16;
|
||||
|
||||
void code1x16_matvec_cuda(
|
||||
const void* __restrict__ A,
|
||||
const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes,
|
||||
const int codebook_stride
|
||||
) {
|
||||
int sms;
|
||||
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
|
||||
Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
|
||||
(const int4*) A,
|
||||
(const int4*) B,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
}
|
||||
|
||||
void code2x8_matvec_cuda(
|
||||
const void* __restrict__ A,
|
||||
const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes,
|
||||
const int codebook_stride
|
||||
) {
|
||||
int sms;
|
||||
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
musaFuncSetAttribute(
|
||||
Code2x8MatVec, musaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||
);
|
||||
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
|
||||
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*) A,
|
||||
(const int4*) B,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
}
|
||||
|
||||
void code1x16_dequant_cuda(
|
||||
const void* __restrict__ A,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int sms;
|
||||
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
|
||||
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
||||
(const int4*) A,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
codebook_stride // as int4.
|
||||
);
|
||||
}
|
||||
|
||||
// Dequantizes the code and codebook into weights.
|
||||
void code2x8_dequant_cuda(
|
||||
const void* __restrict__ A,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook,
|
||||
int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int sms;
|
||||
musaDeviceGetAttribute(&sms, musaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
musaStream_t stream = at::musa::getCurrentMUSAStream().stream();
|
||||
|
||||
musaFuncSetAttribute(
|
||||
Code2x8Dequant, musaFuncAttributeMaxDynamicSharedMemorySize, shared
|
||||
);
|
||||
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*) A,
|
||||
(int4*) C,
|
||||
(const int4*) codebook,
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride
|
||||
);
|
||||
}
|
||||
|
||||
int codebook_stride(const torch::Tensor& codebooks)
|
||||
{
|
||||
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
||||
}
|
||||
|
||||
void code1x16_matvec(
|
||||
const torch::Tensor& A,
|
||||
const torch::Tensor& B,
|
||||
torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each codebook, at most 3 long.
|
||||
) {
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
|
||||
code1x16_matvec_cuda(
|
||||
A.data_ptr(),
|
||||
B.data_ptr(),
|
||||
C.data_ptr(),
|
||||
codebook.data_ptr(),
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
codebook_stride(codebook)
|
||||
);
|
||||
}
|
||||
|
||||
torch::Tensor code1x16_matmat(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(input.dtype())
|
||||
.device(input.device())
|
||||
);
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code1x16_matvec(
|
||||
codes.squeeze(2),
|
||||
input_vec,
|
||||
output_vec,
|
||||
codebooks,
|
||||
codebook_a_sizes
|
||||
);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
|
||||
if (bias.has_value()) {
|
||||
flat_output += bias->unsqueeze(0);
|
||||
}
|
||||
|
||||
auto output_sizes = input_sizes.vec();
|
||||
output_sizes.pop_back();
|
||||
output_sizes.push_back(-1);
|
||||
auto output = flat_output.reshape(output_sizes);
|
||||
return output;
|
||||
}
|
||||
|
||||
void code2x8_matvec(
|
||||
const torch::Tensor& A,
|
||||
const torch::Tensor& B,
|
||||
torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes
|
||||
) {
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
code2x8_matvec_cuda(
|
||||
A.data_ptr(),
|
||||
B.data_ptr(),
|
||||
C.data_ptr(),
|
||||
codebook.data_ptr(),
|
||||
prob_m,
|
||||
prob_k,
|
||||
codebook_a_sizes,
|
||||
2 * codebook_stride(codebook)
|
||||
);
|
||||
}
|
||||
|
||||
torch::Tensor code2x8_matmat(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty({flat_input.size(0), out_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(input.dtype())
|
||||
.device(input.device())
|
||||
);
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code2x8_matvec(
|
||||
codes.squeeze(2),
|
||||
input_vec,
|
||||
output_vec,
|
||||
codebooks,
|
||||
codebook_a_sizes
|
||||
);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
if (bias.has_value()) {
|
||||
flat_output += bias->unsqueeze(0);
|
||||
}
|
||||
|
||||
auto output_sizes = input_sizes.vec();
|
||||
output_sizes.pop_back();
|
||||
output_sizes.push_back(-1);
|
||||
auto output = flat_output.reshape(output_sizes);
|
||||
return output;
|
||||
}
|
||||
|
||||
// Accumulate the partition sizes.
|
||||
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
|
||||
{
|
||||
int4 cumulative_sizes;
|
||||
auto cumulative_size = &cumulative_sizes.x;
|
||||
int i = 0;
|
||||
int last = 0;
|
||||
assert(codebook_partition_sizes.size(0) <= 4);
|
||||
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size)
|
||||
{
|
||||
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
|
||||
last = *cumulative_size;
|
||||
}
|
||||
// fill in the rest with unreachable.
|
||||
for (; i < 4; ++i, ++cumulative_size)
|
||||
{
|
||||
*cumulative_size = last*10;
|
||||
}
|
||||
return cumulative_sizes;
|
||||
}
|
||||
|
||||
} // namespace aqlm
|
||||
} // namespace vllm
|
||||
|
||||
|
||||
torch::Tensor aqlm_gemm(
|
||||
const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias
|
||||
)
|
||||
{
|
||||
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
||||
int const entries = codebooks.size(1);
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16))
|
||||
{
|
||||
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||
}
|
||||
if (nbooks == 2 && entries == (1 << 8))
|
||||
{
|
||||
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
|
||||
return {};
|
||||
}
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes
|
||||
)
|
||||
{
|
||||
int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
|
||||
int const entries = codebooks.size(1);
|
||||
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(codes));
|
||||
int rows = codes.size(1);
|
||||
int cols = codes.size(0);
|
||||
|
||||
auto in_features = codes.size(1) * 8;
|
||||
auto out_features = codes.size(0);
|
||||
|
||||
assert(out_features = codebook_partition_sizes.sum().item<int>());
|
||||
|
||||
auto weights = torch::empty({out_features, in_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(codebooks.dtype())
|
||||
.device(codebooks.device())
|
||||
);
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16))
|
||||
{
|
||||
vllm::aqlm::code1x16_dequant_cuda(
|
||||
codes.data_ptr(),
|
||||
weights.data_ptr(),
|
||||
codebooks.data_ptr(),
|
||||
out_features,
|
||||
in_features,
|
||||
cumulative_sizes,
|
||||
vllm::aqlm::codebook_stride(codebooks));
|
||||
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
|
||||
// weights *= scales.index({"...", 0, 0});
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
if (nbooks == 2 && entries == (1 << 8))
|
||||
{
|
||||
vllm::aqlm::code2x8_dequant_cuda(
|
||||
codes.data_ptr(),
|
||||
weights.data_ptr(),
|
||||
codebooks.data_ptr(),
|
||||
out_features,
|
||||
in_features,
|
||||
cumulative_sizes,
|
||||
vllm::aqlm::codebook_stride(codebooks));
|
||||
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
|
||||
// weights *= scales.index({"...", 0, 0});
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.")
|
||||
return {};
|
||||
}
|
||||
87
csrc_musa/quantization/awq/dequantize.muh
Normal file
87
csrc_musa/quantization/awq/dequantize.muh
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||
{
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
uint4 result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||
|
||||
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||
// immediately before required.
|
||||
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||
// This is the half2 {-72, -72} represented as an integer.
|
||||
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||
// Haotian: Let's use {-64, -64}.
|
||||
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
446
csrc_musa/quantization/awq/gemm_kernels.mu
Normal file
446
csrc_musa/quantization/awq/gemm_kernels.mu
Normal file
@@ -0,0 +1,446 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include "dequantize.cuh"
|
||||
|
||||
#include <musa_fp16.h>
|
||||
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned
|
||||
__pack_half2(const half x, const half y) {
|
||||
unsigned v0 = *((unsigned short *)&x);
|
||||
unsigned v1 = *((unsigned short *)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
template<int N>
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
int G,
|
||||
int split_k_iters,
|
||||
half* __restrict__ A,
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
int M,
|
||||
int IC,
|
||||
int OC,
|
||||
half* __restrict__ C)
|
||||
{
|
||||
// Only support matrix n = 64 or 128
|
||||
assert(N == 64 || N == 128);
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (N + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[N];
|
||||
__shared__ half zeros_shared[N];
|
||||
|
||||
int j_factors1 = ((OC + N - 1) / N);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[N / 4];
|
||||
for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ ((int)threadIdx.x) % (N / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ ((int)threadIdx.y) * (N / 2)
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) dequantize_weights(
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
half* __restrict__ C,
|
||||
int G
|
||||
)
|
||||
{
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
half B_shared[32 * (128 + 8)];
|
||||
|
||||
half* B_shared_ptr2 = B_shared;
|
||||
|
||||
half B_shared_warp[32];
|
||||
int OC = 512;
|
||||
|
||||
int N = blockDim.x * gridDim.x; // 2
|
||||
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int index1 = 8 * col + 8 * row * N;
|
||||
half* C_ptr2 = C + index1;
|
||||
|
||||
int index2 = col + row * N;
|
||||
int* B_ptr2 = B + index2;
|
||||
|
||||
int index3 = col + (int)(row / G) * N;
|
||||
int* zeros_ptr2 = zeros + index3;
|
||||
int index4 = 8 * col + (int)(row / G) * N * 8;
|
||||
half* scaling_factors_ptr2 = scaling_factors + index4;
|
||||
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
|
||||
|
||||
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
|
||||
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
*(C_ptr2 + i) = B_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy)
|
||||
{
|
||||
int in_c = _kernel.size(0);
|
||||
int qout_c = _kernel.size(1);
|
||||
int out_c = qout_c * 8;
|
||||
int G = in_c / _scaling_factors.size(0);
|
||||
|
||||
int x_thread = thx;
|
||||
int y_thread = thy;
|
||||
|
||||
int x_blocks = 1;
|
||||
int y_blocks = 1;
|
||||
if (thx==0) {
|
||||
x_thread = qout_c;
|
||||
}
|
||||
if (thy==0) {
|
||||
y_thread = in_c;
|
||||
}
|
||||
if (thx==0 && thy==0) {
|
||||
x_thread = 8;
|
||||
y_thread = 8;
|
||||
x_blocks = (int)(qout_c / 8);
|
||||
y_blocks = (int)(in_c / 8);
|
||||
}
|
||||
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(_scaling_factors));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
|
||||
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
|
||||
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
|
||||
dim3 num_blocks(x_blocks, y_blocks);
|
||||
dim3 threads_per_block(x_thread, y_thread);
|
||||
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
kernel, scaling_factors, zeros, de_kernel, G);
|
||||
|
||||
return _de_kernel;
|
||||
}
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||
// assume that batch_size < 16 for now
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters)
|
||||
{
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
int num_out_feats = _out_feats.size(-2);
|
||||
int num_out_channels = _out_feats.size(-1);
|
||||
|
||||
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||
|
||||
if (num_out_channels % 64 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||
if (num_out_channels % 8 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||
if (group_size % 32 != 0)
|
||||
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||
if (num_out_channels % group_size != 0)
|
||||
throw std::invalid_argument("OC is not multiple of Group size");
|
||||
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
if (num_out_channels % 128 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 128 / 1;
|
||||
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 64 / 1;
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
|
||||
167
csrc_musa/quantization/fp8/amd_detail/hip_float8.h
Normal file
167
csrc_musa/quantization/fp8/amd_detail/hip_float8.h
Normal file
@@ -0,0 +1,167 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_runtime.h>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8
|
||||
{
|
||||
struct from_bits_t
|
||||
{
|
||||
};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); }
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v)
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v))
|
||||
{
|
||||
}
|
||||
|
||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
|
||||
// Host only implementation using s/w simulation
|
||||
explicit HIP_FP8_HOST
|
||||
#else // __HIP__MI300__
|
||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||
explicit HIP_FP8_HOST_DEVICE
|
||||
#endif // __HIP__MI300__
|
||||
hip_fp8(float v)
|
||||
{
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v);
|
||||
}
|
||||
|
||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||
: hip_fp8(static_cast<float>(v))
|
||||
{
|
||||
}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// upcast using device specific intrinsic
|
||||
explicit inline HIP_FP8_DEVICE operator float() const
|
||||
{
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(data);
|
||||
|
||||
// upcast
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
|
||||
|
||||
return fval;
|
||||
}
|
||||
|
||||
explicit inline HIP_FP8_HOST operator float() const
|
||||
#else // __HIP__MI300__
|
||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||
#endif // __HIP__MI300__
|
||||
{
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std
|
||||
{
|
||||
inline hip_fp8 sin(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(sinf(float(a)));
|
||||
}
|
||||
inline hip_fp8 cos(hip_fp8 a)
|
||||
{
|
||||
return hip_fp8(cosf(float(a)));
|
||||
}
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a)
|
||||
{
|
||||
return a;
|
||||
}
|
||||
} // namespace std
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8)
|
||||
{
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
// all + operator overloading with mixed types
|
||||
// mixed types, always converts to f32, does computation in f32, and returns float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b)
|
||||
{
|
||||
return (fa + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb)
|
||||
{
|
||||
return (float(a) + fb);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b)
|
||||
{
|
||||
return a = hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
// overloading multiplication, always returns float,
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b)
|
||||
{
|
||||
return (a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b)
|
||||
{
|
||||
return (float(a) * b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b)
|
||||
{
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b)
|
||||
{
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
// overloading for compare
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return (a.data == b.data);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return (a.data != b.data);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return static_cast<float>(a) >= static_cast<float>(b);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b)
|
||||
{
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
}
|
||||
316
csrc_musa/quantization/fp8/amd_detail/hip_float8_impl.h
Normal file
316
csrc_musa/quantization/fp8/amd_detail/hip_float8_impl.h
Normal file
@@ -0,0 +1,316 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||
#define HIP_FP8_HOST __host__
|
||||
#define HIP_FP8_DEVICE __device__
|
||||
#else
|
||||
#define HIP_FP8_HOST_DEVICE
|
||||
#define HIP_FP8_HOST
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl
|
||||
{
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
|
||||
{
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
|
||||
return i8data;
|
||||
}
|
||||
#endif // __HIP__MI300__
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x)
|
||||
{
|
||||
return __builtin_clz(x);
|
||||
}
|
||||
#if defined(__HIPCC__) || defined(__MUSA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x)
|
||||
{
|
||||
return __clz(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0)
|
||||
{
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(wm + we == 7, "wm+we==7");
|
||||
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
uint32_t x;
|
||||
if (sizeof(T) == 4) {
|
||||
x = reinterpret_cast<uint32_t&>(_x);
|
||||
} else {
|
||||
x = reinterpret_cast<uint16_t&>(_x);
|
||||
}
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
if (sizeof(T) == 4) {
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
} else {
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if (negative_zero_nan) {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return 0x80;
|
||||
}
|
||||
} else {
|
||||
// if(__hisinf(x) || __hisnan(x))
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return 0x80;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
} else {
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||
// need to check whether there is carry and adjust exponent and mantissa again
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
|
||||
// for this case,
|
||||
// act_exponent could be larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part
|
||||
and make something not midpoint look like midpoint. For example, the fp16
|
||||
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
||||
shift right by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if (exponent_diff > 0) {
|
||||
mantissa >>= exponent_diff;
|
||||
} else if (exponent_diff == -1) {
|
||||
mantissa <<= -exponent_diff;
|
||||
}
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that
|
||||
// is not truncated is 1
|
||||
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
if ((1 << mfmt) & mantissa) {
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
} else {
|
||||
if ((1 << (mfmt + 1)) & mantissa) {
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||
if (f8_exponent > max_exp) {
|
||||
if (clip) {
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
} else {
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if (f8_exponent == 0 && mantissa == 0) {
|
||||
return negative_zero_nan ? 0 : (sign << 7);
|
||||
}
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
|
||||
{
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
|
||||
#ifdef __HIPCC__
|
||||
if (is_half) {
|
||||
const uint16_t ihInf = 0x7C00;
|
||||
const uint16_t ihNegInf = 0xFC00;
|
||||
const uint16_t ihNaN = 0x7C01;
|
||||
const uint16_t ihNeg0 = 0x8000;
|
||||
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||
} else
|
||||
#endif
|
||||
if (is_float) {
|
||||
const uint32_t ifInf = 0x7F800000;
|
||||
const uint32_t ifNegInf = 0xFF800000;
|
||||
const uint32_t ifNaN = 0x7F800001;
|
||||
const uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = reinterpret_cast<const float&>(ifInf);
|
||||
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||
}
|
||||
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t sign = x >> 7;
|
||||
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if (negative_zero_nan) {
|
||||
if (x == 0x80) {
|
||||
return fNaN;
|
||||
}
|
||||
} else {
|
||||
if (x == 0x80) {
|
||||
return fNeg0;
|
||||
}
|
||||
if (exponent == ((1 << we) - 1)) {
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
if (we == 5 && is_half && !negative_zero_nan) {
|
||||
retval = x << 8;
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2) {
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
} else {
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
}
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
} // namespace hip_fp8_impl
|
||||
517
csrc_musa/quantization/fp8/amd_detail/quant_utils.muh
Normal file
517
csrc_musa/quantization/fp8/amd_detail/quant_utils.muh
Normal file
@@ -0,0 +1,517 @@
|
||||
#pragma once
|
||||
#include "hip_float8.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
namespace vllm
|
||||
{
|
||||
namespace fp8_e4m3 {
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0];
|
||||
tmp.h2r.y.data = f2[1];
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __mt_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __mt_bfloat16 vec_conversion<__mt_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
}
|
||||
|
||||
using __mt_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__mt_bfloat162 res;
|
||||
res.x = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
res.y = f2[1];
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __mt_bfloat16>(const __mt_bfloat16& a)
|
||||
{
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
// Float4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
// float2 -> bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, float2>(const float2& a)
|
||||
{
|
||||
__mt_bfloat162 b = __float22bfloat162_rn(a);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float4 -> bfloat162x2
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
|
||||
{
|
||||
bf16_4_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
return b;
|
||||
}
|
||||
|
||||
// Float8 -> bfloat162x4
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
|
||||
{
|
||||
bf16_8_t b;
|
||||
b.x = __float22bfloat162_rn(a.x);
|
||||
b.y = __float22bfloat162_rn(a.y);
|
||||
b.z = __float22bfloat162_rn(a.z);
|
||||
b.w = __float22bfloat162_rn(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
|
||||
/* Scaled and vectorized conversions, for data exchange between high and low precision domains
|
||||
|
||||
Convention of the scale in API, e.g: FP8_data = Quantization( High_Precision_data / scale )
|
||||
s.t.
|
||||
Quantize(HP / scale) => FP8
|
||||
Dequant(FP8) * scale => HP
|
||||
|
||||
*/
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8) * scale;
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0] * scale;
|
||||
tmp.h2r.y.data = f2[1] * scale;
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __mt_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __mt_bfloat16 scaled_vec_conversion<__mt_bfloat16, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f * scale);
|
||||
}
|
||||
|
||||
using __mt_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __mt_bfloat162 scaled_vec_conversion<__mt_bfloat162, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
__mt_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
|
||||
{
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0] * scale;
|
||||
res.y = f2[1] * scale;
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
/* Quantize(HP / scale) => FP8 */
|
||||
|
||||
// TODO(Hai): vectorized to add
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)/scale};
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __mt_bfloat16>(const __mt_bfloat16& a, const float scale)
|
||||
{
|
||||
hip_fp8 res{__bfloat162float(a)/scale};
|
||||
return res.data;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
|
||||
{
|
||||
hip_fp8 f8(a/scale);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
|
||||
{
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace vllm
|
||||
126
csrc_musa/quantization/fp8/fp8_cuda_kernels.mu
Normal file
126
csrc_musa/quantization/fp8/fp8_cuda_kernels.mu
Normal file
@@ -0,0 +1,126 @@
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include <torch/extension.h>
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "musa_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
float old;
|
||||
old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
|
||||
__uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
|
||||
|
||||
return old;
|
||||
}
|
||||
|
||||
// Compute the absolute maximum m of the input tensor and store
|
||||
// m / float8_e4m3::max() in *scale. Each thread block performs a
|
||||
// reduction tree and the memory in scale is atomically updated.
|
||||
// So to get the right answer, *scale needs to be initialized to
|
||||
// a value <= 0.0 and we need to wait for all thread blocks to
|
||||
// finish before consuming *scale.
|
||||
template<typename scalar_t>
|
||||
__global__ void segmented_max_reduction(
|
||||
float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
__shared__ float cache[1024];
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// First store maximum for all values processes by
|
||||
// the current thread in cache[threadIdx.x]
|
||||
scalar_t tmp = 0.0;
|
||||
while (i < num_elems) {
|
||||
float x = static_cast<float>(input[i]);
|
||||
tmp = max(tmp, fabs(x));
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
cache[threadIdx.x] = tmp;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now perform parallel reduction within the thread block
|
||||
int ib = blockDim.x / 2;
|
||||
while (ib != 0) {
|
||||
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
|
||||
cache[threadIdx.x] = cache[threadIdx.x + ib];
|
||||
}
|
||||
__syncthreads();
|
||||
ib /= 2;
|
||||
}
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(
|
||||
c10::Float8_e4m3fn* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
while (i < num_elems) {
|
||||
out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
|
||||
i += blockDim.x * gridDim.x;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"scaled_fp8_quant_kernel",
|
||||
[&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(),
|
||||
num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(input));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"scaled_fp8_quant_kernel",
|
||||
[&] {
|
||||
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(),
|
||||
num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
277
csrc_musa/quantization/fp8_e5m2_kvcache/quant_utils.muh
Normal file
277
csrc_musa/quantization/fp8_e5m2_kvcache/quant_utils.muh
Normal file
@@ -0,0 +1,277 @@
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
#include "../../attention/attention_dtypes.h"
|
||||
#include "../../attention/dtype_float32.cuh"
|
||||
#include "../../attention/dtype_float16.cuh"
|
||||
#include "../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
namespace fp8_e5m2_unscaled {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template<>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template<>
|
||||
__inline__ __device__ __mt_bfloat16 vec_conversion<__mt_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template<>
|
||||
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__mt_bfloat162 res;
|
||||
res.x = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__mt_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__mt_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template<>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template<>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template<>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// half -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __mt_bfloat16>(const __mt_bfloat16& a)
|
||||
{
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__mt_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ __mt_bfloat162 vec_conversion<__mt_bfloat162, float2>(const float2 &a) {
|
||||
__mt_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
} // namespace fp8_e5m2_unscaled
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
} // namespace vllm
|
||||
64
csrc_musa/quantization/gptq/compat.muh
Normal file
64
csrc_musa/quantization/gptq/compat.muh
Normal file
@@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__MUSA_ARCH__) || defined(USE_ROCM)
|
||||
#if __MUSA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __MUSA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
||||
274
csrc_musa/quantization/gptq/matrix_view.muh
Normal file
274
csrc_musa/quantization/gptq/matrix_view.muh
Normal file
@@ -0,0 +1,274 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <musa_runtime.h>
|
||||
#include <musa_fp16.h>
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
class MatrixView_q2_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q2_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x0f) * 2;
|
||||
return (data[row * width / 16 + column / 16] >> shift) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x0f) * 2;
|
||||
uint32_t d = data[row * width / 16 + column / 16] >> shift;
|
||||
items[0] = d & 0x03;
|
||||
items[1] = (d >> 2) & 0x03;
|
||||
items[2] = (d >> 4) & 0x03;
|
||||
items[3] = (d >> 6) & 0x03;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q3_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q3_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int z_w = column * 3 / 32;
|
||||
int z_mod = column & 0x1f;
|
||||
|
||||
if (z_mod == 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 30) | ((data[row * width * 3 / 32 + (z_w + 1)] << 2) & 0x4);
|
||||
} else if (z_mod == 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> 31) | ((data[row * width * 3 / 32 + (z_w + 1)] << 1) & 0x6);
|
||||
} else if (z_mod < 10) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3)) & 0x07;
|
||||
} else if (z_mod < 21) {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 32)) & 0x07;
|
||||
} else {
|
||||
return (data[row * width * 3 / 32 + z_w] >> (z_mod * 3 - 64)) & 0x07;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x1f);
|
||||
uint32_t d;
|
||||
if (shift <= 4) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3);
|
||||
} else if (shift == 8) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 24) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0x0f) << 8);
|
||||
} else if (shift <= 16) {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 32);
|
||||
} else if (shift == 20) {
|
||||
d = (data[row * width / 32 * 3 + column * 3 / 32] >> 28) | ((data[row * width / 32 * 3 + column * 3 / 32 + 1] & 0xff) << 4);
|
||||
} else {
|
||||
d = data[row * width / 32 * 3 + column * 3 / 32] >> (shift * 3 - 64);
|
||||
}
|
||||
items[0] = d & 0x07;
|
||||
items[1] = (d >> 3) & 0x07;
|
||||
items[2] = (d >> 6) & 0x07;
|
||||
items[3] = (d >> 9) & 0x07;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q8_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q8_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x03) * 8;
|
||||
return (data[row * width / 4 + column / 4] >> shift) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x03) * 8;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x03) * 2;
|
||||
uint32_t d = data[row * width / 4 + column / 4] >> shift;
|
||||
items[0] = d & 0xff;
|
||||
items[1] = (d >> 8) & 0xff;
|
||||
items[2] = (d >> 16) & 0xff;
|
||||
items[3] = (d >> 24) & 0xff;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
||||
2075
csrc_musa/quantization/gptq/q_gemm.mu
Normal file
2075
csrc_musa/quantization/gptq/q_gemm.mu
Normal file
File diff suppressed because it is too large
Load Diff
87
csrc_musa/quantization/gptq/qdq_2.muh
Normal file
87
csrc_musa/quantization/gptq/qdq_2.muh
Normal file
@@ -0,0 +1,87 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_2_cuh
|
||||
#define _qdq_2_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// ffddbb99 77553311 eeccaa88 66442200
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x03;
|
||||
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||
qa >>= 4;
|
||||
qb |= (qa1 << (i * 2 + 16));
|
||||
qb |= (qa0 << (i * 2));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[8],
|
||||
int stride,
|
||||
const uint32_t zero
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y4_ = __float2half_rn(1.0f / 4.0f);
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y4 = __halves2half2(y4_, y4_);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z4_ = __hsub(__int2half_rn(-256), __int2half_rn(zero));
|
||||
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||
const half2 z1 = __half2half2(z1_.as_half);
|
||||
const half2 z4 = __half2half2(z4_);
|
||||
const half2 z16 = __half2half2(z16_);
|
||||
const half2 z64 = __half2half2(z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
|
||||
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
|
||||
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
|
||||
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
|
||||
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
|
||||
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y4, z4);
|
||||
dq[2] = __hfma2(q2.as_half2, y16, z16);
|
||||
dq[3] = __hfma2(q3.as_half2, y64, z64);
|
||||
dq[4] = __hadd2(q4.as_half2, z1);
|
||||
dq[5] = __hfma2(q5.as_half2, y4, z4);
|
||||
dq[6] = __hfma2(q6.as_half2, y16, z16);
|
||||
dq[7] = __hfma2(q7.as_half2, y64, z64);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#endif
|
||||
141
csrc_musa/quantization/gptq/qdq_3.muh
Normal file
141
csrc_musa/quantization/gptq/qdq_3.muh
Normal file
@@ -0,0 +1,141 @@
|
||||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
|
||||
// qa: aa999888 77766655 54443332 22111000
|
||||
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||
|
||||
uint32_t qd = qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: ..999888 77766655 54443332 22111000
|
||||
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
za |= ((qd & 0x01) >> 0) << 15;
|
||||
zb |= ((qd & 0x02) >> 1) << 15;
|
||||
zc |= ((qd & 0x04) >> 2) << 15;
|
||||
za |= ((qd & 0x08) >> 3) << 31;
|
||||
zb |= ((qd & 0x10) >> 4) << 31;
|
||||
zc |= ((qd & 0x20) >> 5) << 31;
|
||||
|
||||
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride,
|
||||
const uint32_t zero
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y8 = __halves2half2(y8_, y8_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z8_ = __hsub(__int2half_rn(-128), __int2half_rn(zero));
|
||||
const half z64_ = __hsub(__int2half_rn(-16), __int2half_rn(zero));
|
||||
const half2 z1 = __halves2half2(z1_.as_half, z1_.as_half);
|
||||
const half2 z8 = __halves2half2(z8_, z8_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
|
||||
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||
qa >>= 6;
|
||||
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||
qa >>= 9;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||
qb >>= 6;
|
||||
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||
qb >>= 8;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||
qc >>= 6;
|
||||
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||
qc >>= 7;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q15((qa | qb | qc) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
|
||||
dq[ 7] = __hadd2( q7.as_half2, z1);
|
||||
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
|
||||
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
|
||||
dq[10] = __hadd2(q10.as_half2, z1);
|
||||
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#endif
|
||||
147
csrc_musa/quantization/gptq/qdq_4.muh
Normal file
147
csrc_musa/quantization/gptq/qdq_4.muh
Normal file
@@ -0,0 +1,147 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride,
|
||||
const uint32_t zero
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half_uint16 z1_(0xe400 | zero); // half(-1024.0f - zero);
|
||||
const half z16_ = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
const half2 z1 = __half2half2(z1_.as_half);
|
||||
const half2 z16 = __half2half2(z16_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1z16)[2],
|
||||
half2(&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#endif
|
||||
40
csrc_musa/quantization/gptq/qdq_8.muh
Normal file
40
csrc_musa/quantization/gptq/qdq_8.muh
Normal file
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_8_cuh
|
||||
#define _qdq_8_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
__forceinline__ __device__ void shuffle_8bit_4
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_8bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
half2 (&dq)[4],
|
||||
int stride,
|
||||
const uint32_t zero
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), zero);
|
||||
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), zero);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#endif
|
||||
60
csrc_musa/quantization/gptq/qdq_util.muh
Normal file
60
csrc_musa/quantization/gptq/qdq_util.muh
Normal file
@@ -0,0 +1,60 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
union half2_uint32
|
||||
{
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
};
|
||||
|
||||
union half_uint16
|
||||
{
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
__device__ half_uint16(half val) : as_half(val) {}
|
||||
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||
{
|
||||
int qs_i = qs + 1;
|
||||
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||
{
|
||||
return __hmul(__int2half_rn(q - qzero), scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||
{
|
||||
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||
return __int2half_rn(q - qzero);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||
{
|
||||
return (int)((q >> shift) & mask);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||
{
|
||||
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
||||
1722
csrc_musa/quantization/gptq_marlin/gptq_marlin.mu
Normal file
1722
csrc_musa/quantization/gptq_marlin/gptq_marlin.mu
Normal file
File diff suppressed because it is too large
Load Diff
70
csrc_musa/quantization/gptq_marlin/gptq_marlin.muh
Normal file
70
csrc_musa/quantization/gptq_marlin/gptq_marlin.muh
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "torch_musa/csrc/aten/musa/MUSAContext.h"
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
#include <musa.h>
|
||||
#include <musa_fp16.h>
|
||||
#include <musa_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
|
||||
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have
|
||||
// many registers per warp and small tiles.
|
||||
static constexpr int default_threads = 256;
|
||||
|
||||
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
|
||||
using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace gptq_marlin
|
||||
352
csrc_musa/quantization/gptq_marlin/gptq_marlin_repack.mu
Normal file
352
csrc_musa/quantization/gptq_marlin/gptq_marlin_repack.mu
Normal file
@@ -0,0 +1,352 @@
|
||||
#include "gptq_marlin.cuh"
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
static constexpr int repack_stages = 8;
|
||||
|
||||
static constexpr int repack_threads = 256;
|
||||
|
||||
static constexpr int tile_k_size = tile_size;
|
||||
static constexpr int tile_n_size = tile_k_size * 4;
|
||||
|
||||
#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void
|
||||
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
uint32_t const *__restrict__ perm_ptr,
|
||||
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void
|
||||
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
uint32_t const *__restrict__ perm_ptr,
|
||||
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int perm_size = tile_k_size / 4;
|
||||
|
||||
int4 *sh_perm_ptr = sh;
|
||||
int4 *sh_pipe_ptr = sh_perm_ptr;
|
||||
if constexpr (has_perm) {
|
||||
sh_pipe_ptr += perm_size;
|
||||
}
|
||||
|
||||
constexpr int tile_ints = tile_k_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_size / 4;
|
||||
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto load_perm_to_shared = [&](int k_tile_id) {
|
||||
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
||||
|
||||
int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr);
|
||||
|
||||
if (threadIdx.x < perm_size) {
|
||||
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
|
||||
int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const *sh_perm_int_ptr =
|
||||
reinterpret_cast<uint32_t const *>(sh_perm_ptr);
|
||||
|
||||
int src_k = sh_perm_int_ptr[k_id];
|
||||
int src_k_packed = src_k / pack_factor;
|
||||
|
||||
cp_async4(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const *>(&(
|
||||
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor;
|
||||
|
||||
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const *>(
|
||||
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
|
||||
first_n + (n_id * 4)])));
|
||||
}
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
|
||||
constexpr int sh_stride = 64;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
|
||||
|
||||
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
|
||||
|
||||
uint32_t vals[8];
|
||||
|
||||
if constexpr (has_perm) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int k_idx = tc_row + tc_offsets[i];
|
||||
|
||||
uint32_t src_k = sh_perm_int_ptr[k_idx];
|
||||
uint32_t src_k_pos = src_k % pack_factor;
|
||||
|
||||
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
|
||||
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
|
||||
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
|
||||
|
||||
vals[i] = b1_cur_val;
|
||||
vals[4 + i] = b2_cur_val;
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
uint32_t b1_vals[tile_ints];
|
||||
uint32_t b2_vals[tile_ints];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_ints; i++) {
|
||||
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
|
||||
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
int cur_int = cur_elem / pack_factor;
|
||||
int cur_pos = cur_elem % pack_factor;
|
||||
|
||||
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else {
|
||||
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
||||
}
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
load_perm_to_shared(k_tile_id);
|
||||
}
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
musaFuncSetAttribute( \
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
|
||||
NUM_BITS, HAS_PERM>, \
|
||||
musaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
|
||||
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
|
||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||
", size_k = ", size_k, ", pack_factor = ", pack_factor);
|
||||
TORCH_CHECK(b_q_weight.size(1) == size_n,
|
||||
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
" is not size_n = ", size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
||||
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
|
||||
|
||||
// Alloc buffers
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(b_q_weight.dtype())
|
||||
.device(b_q_weight.device());
|
||||
torch::Tensor out =
|
||||
torch::empty({size_k / gptq_marlin::tile_size,
|
||||
size_n * gptq_marlin::tile_size / pack_factor},
|
||||
options);
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.size(0) != 0;
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const *b_q_weight_ptr =
|
||||
reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
|
||||
uint32_t const *perm_ptr =
|
||||
reinterpret_cast<uint32_t const *>(perm.data_ptr());
|
||||
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
|
||||
|
||||
// Get dev info
|
||||
int dev = b_q_weight.get_device();
|
||||
musaStream_t stream = at::cuda::getCurrentMUSAStream(dev);
|
||||
int blocks;
|
||||
musaDeviceGetAttribute(&blocks, musaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
musaDeviceGetAttribute(&max_shared_mem,
|
||||
musaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4, false)
|
||||
CALL_IF(4, true)
|
||||
CALL_IF(8, false)
|
||||
CALL_IF(8, true)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
|
||||
", has_perm = ", has_perm);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
#endif
|
||||
209
csrc_musa/quantization/marlin/.LICENSE
Normal file
209
csrc_musa/quantization/marlin/.LICENSE
Normal file
@@ -0,0 +1,209 @@
|
||||
Contains code from https://github.com/IST-DASLab/marlin
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
1138
csrc_musa/quantization/marlin/marlin_cuda_kernel.mu
Normal file
1138
csrc_musa/quantization/marlin/marlin_cuda_kernel.mu
Normal file
File diff suppressed because it is too large
Load Diff
225
csrc_musa/quantization/squeezellm/quant_cuda_kernel.mu
Normal file
225
csrc_musa/quantization/squeezellm/quant_cuda_kernel.mu
Normal file
@@ -0,0 +1,225 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/python.h>
|
||||
#include <musa.h>
|
||||
#include <musa_runtime.h>
|
||||
#include <musa_fp16.h>
|
||||
|
||||
// half-tensor
|
||||
#include "torch_musa/csrc/core/MUSAStream.h"
|
||||
#include <ATen/musa/MUSA_PORT_TensorMethods.muh>
|
||||
#include "torch_musa/csrc/core/MUSAGuard.h"
|
||||
|
||||
#define BLOCKWIDTH 128
|
||||
#define BLOCKHEIGHT4 16
|
||||
|
||||
namespace vllm {
|
||||
namespace squeezellm {
|
||||
|
||||
__device__ inline unsigned int as_unsigned(int i) {
|
||||
return *reinterpret_cast<unsigned int*>(&i);
|
||||
}
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
const half2* __restrict__ vec,
|
||||
#else
|
||||
const __half2* __restrict__ vec,
|
||||
#endif
|
||||
const int* __restrict__ mat,
|
||||
#ifndef USE_ROCM
|
||||
half2* __restrict__ mul,
|
||||
#else
|
||||
float2* __restrict__ mul,
|
||||
#endif
|
||||
const __half* __restrict__ lookup_table,
|
||||
int height,
|
||||
int width,
|
||||
int batch,
|
||||
int vec_height
|
||||
) {
|
||||
|
||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
#else
|
||||
__shared__ __half2 blockvec[blockwidth2];
|
||||
#endif
|
||||
|
||||
__shared__ __half deq2[16][BLOCKWIDTH];
|
||||
int off = threadIdx.x;
|
||||
int column_offset = col * 16;
|
||||
for (int val = 0; val < 16; val += 1) {
|
||||
int lut_index = column_offset + val;
|
||||
deq2[val][off] = lookup_table[lut_index];
|
||||
}
|
||||
|
||||
__half res;
|
||||
#ifndef USE_ROCM
|
||||
half2 res2;
|
||||
half2 tmp2;
|
||||
#else
|
||||
__half2 res2;
|
||||
__half2 tmp2;
|
||||
#endif
|
||||
|
||||
int i;
|
||||
int k;
|
||||
|
||||
unsigned int tmp1;
|
||||
unsigned int lut_index1, lut_index2;
|
||||
|
||||
for (int b = 0; b < batch; ++b){
|
||||
i = width * row + col;
|
||||
res = __int2half_rd(0);
|
||||
k = 0;
|
||||
|
||||
__syncthreads();
|
||||
if (threadIdx.x < blockwidth2)
|
||||
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
while (k < blockwidth2) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
tmp2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
tmp2.x = __half_as_ushort(__float2half(0));
|
||||
tmp2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
|
||||
lut_index1 = tmp1 & 0xF;
|
||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
#else
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
||||
#endif
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
// col%2 -> only set one of the two values
|
||||
#ifndef USE_ROCM
|
||||
half2 res3 = {};
|
||||
if (col % 2 == 0) {
|
||||
res3.x = res;
|
||||
} else {
|
||||
res3.y = res;
|
||||
}
|
||||
#else
|
||||
__half2 res3;
|
||||
res3.x = __half_as_ushort(__float2half(0));
|
||||
res3.y = __half_as_ushort(__float2half(0));
|
||||
if (col % 2 == 0) {
|
||||
res3.x = __half_as_ushort(res);
|
||||
} else {
|
||||
res3.y = __half_as_ushort(res);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||
#else
|
||||
int tmp_addr = b * width / 2 + col / 2;
|
||||
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
||||
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace squeezellm
|
||||
} // namespace vllm
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table
|
||||
) {
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
const at::musa::OptionalMUSAGuard device_guard(device_of(vec));
|
||||
const musaStream_t stream = at::musa::getCurrentMUSAStream();
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
||||
#ifndef USE_ROCM
|
||||
(half2*) vec.data<at::Half>(),
|
||||
#else
|
||||
(__half2*) vec.data_ptr<at::Half>(),
|
||||
#endif
|
||||
mat.data_ptr<int>(),
|
||||
#ifndef USE_ROCM
|
||||
(half2*) mul.data<at::Half>(),
|
||||
(__half*) lookup_table.data<at::Half>(),
|
||||
#else
|
||||
(float2*) mul.data_ptr<float>(),
|
||||
(__half*) lookup_table.data_ptr<at::Half>(),
|
||||
#endif
|
||||
height, width, batch, vec_height
|
||||
);
|
||||
}
|
||||
|
||||
#undef BLOCKWIDTH
|
||||
#undef BLOCKHEIGHT4
|
||||
66
csrc_musa/reduction_utils.muh
Normal file
66
csrc_musa/reduction_utils.muh
Normal file
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
|
||||
* Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
||||
* Copyright (c) 2023, The vLLM team.
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "musa_compat.h"
|
||||
|
||||
namespace vllm {
|
||||
template<typename T, int numLanes = WARP_SIZE>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
|
||||
"numLanes is not a positive power of 2!");
|
||||
static_assert(numLanes <= WARP_SIZE);
|
||||
#pragma unroll
|
||||
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
|
||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||
return val;
|
||||
}
|
||||
|
||||
// Helper function to return the next largest power of 2
|
||||
static constexpr int _nextPow2(unsigned int num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
/* Calculate the sum of all elements in a block */
|
||||
template<typename T, int maxBlockSize = 1024>
|
||||
__inline__ __device__ T blockReduceSum(T val) {
|
||||
static_assert(maxBlockSize <= 1024);
|
||||
if constexpr (maxBlockSize > WARP_SIZE) {
|
||||
val = warpReduceSum<T>(val);
|
||||
// Calculates max number of lanes that need to participate in the last warpReduce
|
||||
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
|
||||
static __shared__ T shared[maxActiveLanes];
|
||||
int lane = threadIdx.x % WARP_SIZE;
|
||||
int wid = threadIdx.x / WARP_SIZE;
|
||||
if (lane == 0)
|
||||
shared[wid] = val;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane] : (T)(0.0f);
|
||||
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
|
||||
} else {
|
||||
// A single warpReduce is equal to blockReduce
|
||||
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
Reference in New Issue
Block a user