Sync from v0.13

commit 5aef6c175a
parent b2ef04d792
Date: 2026-01-19 10:38:50 +08:00
3714 changed files with 854317 additions and 89342 deletions


@@ -0,0 +1,147 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>
// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
int64_t group_size_eff, int64_t in_features,
int64_t out_features) {
#if defined(__aarch64__)
return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
in_features, out_features);
#else
TORCH_CHECK(false,
"dynamic 4-bit int MoE path requires AArch64 (ARM64); "
"_dyn_quant_matmul_4bit is unavailable on this architecture");
return {};
#endif
}
enum ActivationKind : int64_t {
  SwiGLU_Gu = 0,  // act = SiLU(g) * u
  SwiGLUOAI = 1,  // act = (u + 1) * g * sigmoid(alpha * g), clamped (GPT-OSS)
  SiLU = 2        // act = SiLU(x)
};
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
int64_t activation_kind) {
TORCH_CHECK(x.dim() == 2, "x must be 2D");
TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
"topk tensors must be [T, K]");
TORCH_CHECK(
w13_packed.size(0) == w2_packed.size(0),
"w13_packed and w2_packed must have same number of experts in dim 0");
TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");
const int64_t T = x.size(0);
const int64_t K = topk_ids.size(1);
const int64_t E = w13_packed.size(0);
const int64_t N = T * K;
auto x_c = x.contiguous();
auto ids_c = topk_ids.contiguous();
auto gates_c = topk_weights.to(at::kFloat).contiguous();
  // Bucket tokens by expert: a counting sort over the flattened top-k ids.
  c10::SmallVector<int64_t, 64> counts(
      E, 0);  // SmallVector keeps up to 64 entries on the stack
{
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
for (int64_t i = 0; i < N; ++i) {
const int64_t e_id = ids_ptr[i];
TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
counts[e_id]++;
}
}
  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);  // exclusive prefix sums
for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];
auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
auto expert_gates = at::empty({offsets[E]}, gates_c.options());
{
c10::SmallVector<int64_t, 64> cursor(E, 0);
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
const auto* gts_ptr = gates_c.data_ptr<float>();
auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
auto* gate_ptr = expert_gates.data_ptr<float>();
for (int64_t t = 0; t < T; ++t) {
const int64_t base = t * K;
for (int64_t k = 0; k < K; ++k) {
const int64_t idx = base + k;
const int64_t e = ids_ptr[idx];
const int64_t p = offsets[e] + (cursor[e]++);
tok_ptr[p] = t;
gate_ptr[p] = gts_ptr[idx];
}
}
}
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
if (apply_router_weight_on_input) {
X_all = X_all.mul(expert_gates.unsqueeze(1));
}
auto Y_all = at::empty({offsets[E], H}, x_c.options());
at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) {
c10::InferenceMode guard;
for (int64_t e = 0; e < E; ++e) {
int64_t start = std::max(offsets[e], idx_begin);
int64_t end = std::min(offsets[e + 1], idx_end);
int64_t te = end - start;
if (te <= 0) {
continue;
}
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto w13_e = w13_packed.select(/*dim=*/0, e);
auto w2_e = w2_packed.select(/*dim=*/0, e);
// W13
auto y13 =
mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
torch::Tensor act;
      if (activation_kind == ActivationKind::SwiGLUOAI) {
constexpr double kAlpha = 1.702; // GPT-OSS default
constexpr double kLimit = 7.0; // GPT-OSS default
auto gate_c = at::clamp_max(g_part, kLimit);
auto up_c = at::clamp(u_part, -kLimit, kLimit);
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
act = up_c.add(1.0).mul(glu);
      } else {  // SiLU / SwiGLU_Gu: vLLM maps "silu" to SiluAndMul()
act = at::silu(g_part).mul(u_part);
}
// W2
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
// Store per-expert result
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
}
});
if (!apply_router_weight_on_input) {
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
}
auto out = at::zeros({T, H}, x.options());
out =
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
return out;
}
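Two of the steps above are easy to get wrong on a first read, so here is a hedged, host-only sketch of their semantics: the token-to-expert bucketing is a counting sort over the flattened top-k ids, and the SwiGLUOAI branch is the clamped GPT-OSS activation. The function names and std::vector layout below are illustrative, not part of this file.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Counting-sort bucketing: per expert, the flat (token, k) positions routed
// to it -- the same grouping that counts/offsets/cursor build above.
std::vector<std::vector<int64_t>> bucket_by_expert(
    const std::vector<int64_t>& topk_ids, int64_t E) {
  std::vector<std::vector<int64_t>> buckets(E);
  for (int64_t i = 0; i < static_cast<int64_t>(topk_ids.size()); ++i) {
    buckets[topk_ids[i]].push_back(i);
  }
  return buckets;
}

// Scalar form of the SwiGLUOAI branch: act = (clamp(u) + 1) * glu, where
// glu = clamp_max(g) * sigmoid(alpha * clamp_max(g)).
float swiglu_oai(float g, float u, float alpha = 1.702f, float limit = 7.0f) {
  g = std::min(g, limit);
  u = std::max(-limit, std::min(u, limit));
  const float glu = g / (1.0f + std::exp(-alpha * g));
  return (u + 1.0f) * glu;
}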


@@ -0,0 +1,891 @@
/*
* Adapted from
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
* Copyright (c) 2025, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda/std/limits>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;
namespace vllm {
namespace moe {
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
namespace warp_topk {
template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
if (len == 0) {
return 0;
}
return ((len - 1) / size + 1) * size;
}
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
return (v && !(v & (v - 1)));
}
template <bool greater, typename T>
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
return (val > baseline && greater) || (val < baseline && !greater);
}
template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
idxT baseline_index) {
bool res = (val > baseline && greater) || (val < baseline && !greater);
  if (val == baseline) {
    // On ties the smaller index wins in either direction (stable ordering).
    res = index < baseline_index;
  }
return res;
}
template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
return max(cache_topk,
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
template <int size, bool ascending, bool reverse, typename T, typename idxT,
bool is_stable>
struct BitonicMerge {
  // The input must be a bitonic sequence; merge() sorts it into a monotonic
  // sequence.
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
static_assert(isPowerOf2(size));
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
constexpr int stride = arr_len / 2;
for (int i = 0; i < stride; ++i) {
int const other_i = i + stride;
T& val = val_arr[i];
T& other_val = val_arr[other_i];
bool is_better;
if constexpr (is_stable) {
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
idx_arr[other_i]);
} else {
is_better = is_better_than<ascending>(val, other_val);
}
if (is_better) {
T tmp = val;
val = other_val;
other_val = tmp;
idxT tmp2 = idx_arr[i];
idx_arr[i] = idx_arr[other_i];
idx_arr[other_i] = tmp2;
}
}
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
val_arr, idx_arr);
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
val_arr + arr_len / 2, idx_arr + arr_len / 2);
}
};
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
static_assert(isPowerOf2(size));
static_assert(size >= 2 * WARP_SIZE);
constexpr int arr_len = size / WARP_SIZE;
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
val_arr + arr_len / 2, idx_arr + arr_len / 2);
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
val_arr, idx_arr);
}
};
template <bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
__device__ static void sort(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
// ascending doesn't matter before merging since all we need is a bitonic
// sequence
for (int stage = 0; stage < 4; ++stage) {
for (int stride = (1 << stage); stride > 0; stride /= 2) {
bool reverse = (lane >> stage) & 2;
bool is_second = lane & stride;
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
bool is_better;
if constexpr (is_stable) {
if constexpr (ascending) {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr < other_idx))) !=
(reverse != is_second);
} else {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr > other_idx))) !=
(reverse != is_second);
}
} else {
is_better = (*val_arr != other &&
(*val_arr > other) != (reverse != is_second));
}
if (is_better) {
*val_arr = other;
*idx_arr = other_idx;
}
}
}
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
idx_arr);
}
};
template <bool ascending, bool reverse, typename T, typename idxT,
bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
__device__ static void merge(T* __restrict__ val_arr,
idxT* __restrict__ idx_arr) {
int const lane = threadIdx.x % WARP_SIZE;
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
bool is_second = lane & stride;
T& val = *val_arr;
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
idxT& idx = *idx_arr;
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
bool is_better;
if constexpr (is_stable) {
if constexpr (ascending) {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr < other_idx))) ==
(reverse != is_second); // for min
} else {
is_better = ((*val_arr > other) ||
((*val_arr == other) && (*idx_arr > other_idx))) ==
(reverse != is_second); // for max
}
} else {
is_better =
(val != other && ((val > other) == (ascending != is_second)));
}
if (is_better) {
val = other;
idx = other_idx;
}
}
}
};
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
public:
__device__ WarpSort(idxT k, T dummy)
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
for (int i = 0; i < max_arr_len_; ++i) {
val_arr_[i] = dummy_;
idx_arr_[i] = 0;
}
}
// load and merge k sorted values
__device__ void load_sorted(T const* __restrict__ in,
idxT const* __restrict__ in_idx, idxT start) {
idxT idx = start + WARP_SIZE - 1 - lane_;
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
if (idx < start + k_) {
T t = in[idx];
bool is_better;
if constexpr (is_stable) {
is_better =
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
} else {
is_better = is_better_than<greater>(t, val_arr_[i]);
}
if (is_better) {
val_arr_[i] = t;
idx_arr_[i] = in_idx[idx];
}
}
}
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
val_arr_, idx_arr_);
}
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
for (int i = 0; i < max_arr_len_; ++i) {
idxT out_i = i * WARP_SIZE + lane_;
if (out_i < k_) {
out[out_i] = val_arr_[i];
out_idx[out_i] = idx_arr_[i];
}
}
}
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
for (int i = 0; i < max_arr_len_; ++i) {
idxT out_i = i * WARP_SIZE + lane_;
if (out_i < k_) {
out_idx[out_i] = idx_arr_[i];
}
}
}
protected:
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
T val_arr_[max_arr_len_];
idxT idx_arr_[max_arr_len_];
int const lane_;
idxT const k_;
T const dummy_;
}; // end class WarpSort
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
public:
__device__ WarpSelect(idxT k, T dummy)
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
k_th_(dummy),
k_th_lane_((k - 1) % WARP_SIZE) {
    extern __shared__ char smem_buf[];  // raw bytes, carved into T/idxT below
int const num_of_warp = blockDim.x / WARP_SIZE;
int const warp_id = threadIdx.x / WARP_SIZE;
val_smem_ = reinterpret_cast<T*>(smem_buf);
val_smem_ += warp_id * WARP_SIZE;
idx_smem_ = reinterpret_cast<idxT*>(
smem_buf +
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
idx_smem_ += warp_id * WARP_SIZE;
}
__device__ void add(T const* in, idxT start, idxT end) {
idxT const end_for_fullwarp =
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
T val = (i < end) ? in[i] : dummy_;
add(val, i);
}
}
__device__ void add(T val, idxT idx) {
bool do_add;
if constexpr (is_stable) {
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
} else {
do_add = is_better_than<greater>(val, k_th_);
}
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
if (mask == 0) {
return;
}
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
if (do_add && pos < WARP_SIZE) {
val_smem_[pos] = val;
idx_smem_[pos] = idx;
do_add = false;
}
smem_buf_len_ += __popc(mask);
if (smem_buf_len_ >= WARP_SIZE) {
__syncwarp();
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
smem_buf_len_ -= WARP_SIZE;
}
if (do_add) {
pos -= WARP_SIZE;
val_smem_[pos] = val;
idx_smem_[pos] = idx;
}
__syncwarp();
}
__device__ void done() {
if (smem_buf_len_) {
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
merge_buf_(val, idx);
}
// after done(), smem is used for merging results among warps
__syncthreads();
}
private:
__device__ void set_k_th_() {
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
if constexpr (is_stable) {
k_th_idx_ =
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
}
}
__device__ void merge_buf_(T val, idxT idx) {
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
T& old = val_arr_[max_arr_len_ - 1];
bool is_better;
if constexpr (is_stable) {
is_better =
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
} else {
is_better = is_better_than<greater>(val, old);
}
if (is_better) {
old = val;
idx_arr_[max_arr_len_ - 1] = idx;
}
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
val_arr_, idx_arr_);
set_k_th_();
}
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
T* val_smem_;
idxT* idx_smem_;
int smem_buf_len_ = 0;
T k_th_;
idxT k_th_idx_;
int const k_th_lane_;
}; // end class WarpSelect
} // namespace warp_topk
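The warp_topk primitives above implement a stable top-k selection: order by value (descending when greater == true), breaking ties toward the smaller index. A sequential reference with the same semantics, offered only as an illustration (assumes k <= vals.size()):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Stable top-k reference: largest values first, smaller index wins on ties.
std::vector<std::pair<float, int32_t>> topk_reference(
    const std::vector<float>& vals, int64_t k) {
  std::vector<std::pair<float, int32_t>> items;
  items.reserve(vals.size());
  for (int32_t i = 0; i < static_cast<int32_t>(vals.size()); ++i) {
    items.emplace_back(vals[i], i);
  }
  std::partial_sort(items.begin(), items.begin() + k, items.end(),
                    [](const auto& a, const auto& b) {
                      if (a.first != b.first) return a.first > b.first;
                      return a.second < b.second;  // stable tie-break
                    });
  items.resize(k);
  return items;
}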
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
return val;
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
return __bfloat162float(val);
}
template <typename T>
__device__ inline T neg_inf() {
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
// so we need to cast from fp32
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
}
template <typename T>
__device__ inline bool is_finite(const T val) {
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
return cuda::std::isfinite(val);
#else
return isfinite(cuda_cast<float, T>(val));
#endif
}
// Scoring function enums
enum ScoringFunc {
SCORING_NONE = 0, // no activation function
SCORING_SIGMOID = 1 // apply sigmoid
};
// Numerically robust sigmoid from TensorRT-LLM, via the exact identity
// sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5
__device__ inline float sigmoid_accurate(float x) {
return 0.5f * tanhf(0.5f * x) + 0.5f;
}
template <typename T>
__device__ inline T apply_sigmoid(T val) {
float f = cuda_cast<float, T>(val);
return cuda_cast<T, float>(sigmoid_accurate(f));
}
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
return val;
}
}
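sigmoid_accurate relies on the identity sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5, which is exact in real arithmetic; a quick host-side check of it in float arithmetic (illustrative only):

#include <cmath>

// The tanh form and the direct form agree up to float rounding.
bool sigmoid_identity_holds(float x, float tol = 1e-6f) {
  const float via_tanh = 0.5f * std::tanh(0.5f * x) + 0.5f;
  const float direct = 1.0f / (1.0f + std::exp(-x));
  return std::fabs(via_tanh - direct) <= tol;
}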
template <typename T, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group) {
// Get the top2 per thread
T largest = neg_inf<T>();
T second_largest = neg_inf<T>();
if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = apply_scoring<SF>(input[i]);
value = value + bias[i];
if (value > largest) {
second_largest = largest;
largest = value;
} else if (value > second_largest) {
second_largest = value;
}
}
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = apply_scoring<SF>(input[i]);
value = value + bias[i];
largest = value;
}
}
// Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>());
T max2 = max1;
bool equal_to_max1 = (max1 == largest);
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
if (count_max1 == 1) {
largest = (largest == max1) ? second_largest : largest;
max2 = cg::reduce(tile, largest, cg::greater<T>());
}
if (lane_id == 0) {
*output = max1 + max2;
}
}
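Per expert group, topk_with_k2 reduces to the sum of the two largest scored-and-biased values. The same computation in scalar form, as a non-authoritative sketch:

#include <limits>
#include <vector>

// Group score = max1 + max2 over (scoring(input[i]) + bias[i]).
float group_top2_sum(const std::vector<float>& scored_with_bias) {
  float max1 = -std::numeric_limits<float>::infinity();
  float max2 = max1;
  for (float v : scored_with_bias) {
    if (v > max1) {
      max2 = max1;
      max1 = v;
    } else if (v > max2) {
      max2 = v;
    }
  }
  return max1 + max2;
}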
template <typename T, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
if (case_id < num_cases) {
input += case_id * num_experts_per_group;
// bias is per expert group, offset to current group
int32_t group_id = case_id % n_group;
T const* group_bias = bias + group_id * num_experts_per_group;
output += case_id;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
num_experts_per_group);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
scores += case_id * num_experts;
group_scores += case_id * n_group;
topk_values += case_id * topk;
topk_indices += case_id * topk;
constexpr bool kUseStaticNGroup = (NGroup > 0);
// use int32 to avoid implicit conversion
int32_t const n_group_i32 =
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
// store the target topk idx
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
T* s_topk_value =
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
warp_id * topk;
s_topk_idx += warp_id * topk;
T value = neg_inf<T>();
T topk_group_value = neg_inf<T>();
int32_t num_equalto_topkth_group;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
// acqbulk because it's ptr arithmetic
#endif
if (case_id < num_tokens) {
// calculate group_idx
int32_t target_num_min =
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
// The check is necessary to avoid abnormal input
if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id];
}
int count_equal_to_top_value = WARP_SIZE - n_group_i32;
int pre_count_equal_to_top_value = 0;
    // Loop to find the largest topk_group value
while (count_equal_to_top_value < target_num_min) {
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = neg_inf<T>();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value =
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();
  warp_topk::WarpSelect</*capacity*/ WARP_SIZE, /*greater*/ true, T, int32_t,
/* is_stable */ true>
queue((int32_t)topk, neg_inf<T>());
int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
if (case_id < num_tokens && if_proceed_next_topk) {
auto process_group = [&](int i_group) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) {
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates = neg_inf<T>();
if (i < num_experts_per_group) {
// apply scoring function (if any) and add bias
T input = scores[offset + i];
if (is_finite(input)) {
T score = apply_scoring<SF>(input);
candidates = score + bias[offset + i];
}
}
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
count_equalto_topkth_group++;
}
}
};
if constexpr (kUseStaticNGroup) {
#pragma unroll
for (int i_group = 0; i_group < NGroup; ++i_group) {
process_group(i_group);
}
} else {
for (int i_group = 0; i_group < n_group_i32; ++i_group) {
process_group(i_group);
}
}
queue.done();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
}
  // Load the valid score values and accumulate their sum for renormalization.
  float topk_sum = 1e-20f;
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value = cuda_cast<T, float>(0.0f);
if (i < topk) {
// Load the score value (without bias) for normalization
T input = scores[s_topk_idx[i]];
value = apply_scoring<SF>(input);
s_topk_value[i] = value;
}
if (renormalize) {
topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
}
__syncthreads();
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
: (base * routed_scaling_factor);
topk_indices[i] = s_topk_idx[i];
topk_values[i] = value;
}
} else {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
topk_indices[i] = i;
topk_values[i] = 1.0f / topk;
}
}
    // Note: when if_proceed_next_topk == false, fall back to the first `topk`
    // experts with uniform weights as the default result.
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
float* topk_values, IdxT* topk_indices, T const* bias,
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool const renormalize,
double const routed_scaling_factor) {
auto launch = [&](auto* kernel_instance2) {
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, bias, num_tokens, n_group,
topk_group, topk, num_experts, num_experts_per_group,
renormalize, routed_scaling_factor);
};
switch (n_group) {
case 4: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
break;
}
case 8: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
break;
}
case 16: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
break;
}
case 32: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
break;
}
default: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
break;
}
}
}
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
int64_t const num_experts, int64_t const n_group,
int64_t const topk_group, int64_t const topk,
bool const renormalize, double const routed_scaling_factor,
int const scoring_func, bool enable_pdl = false,
cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
auto const sf = static_cast<ScoringFunc>(scoring_func);
int64_t const num_experts_per_group = num_experts / n_group;
auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts_per_group);
};
switch (sf) {
case SCORING_NONE: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
launch_topk_with_k2(kernel_instance1);
break;
}
case SCORING_SIGMOID: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
launch_topk_with_k2(kernel_instance1);
break;
}
default:
// should be guarded by higher level checks.
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = dynamic_smem_in_bytes;
config.stream = stream;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
switch (sf) {
case SCORING_NONE: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
case SCORING_SIGMOID: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
default:
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>( \
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
bool const renormalize, double const routed_scaling_factor, \
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
} // end namespace moe
} // namespace vllm
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func = 0) {
auto data_type = scores.scalar_type();
  auto input_size = scores.sizes();
  TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
  int64_t num_tokens = input_size[0];
  int64_t num_experts = input_size[1];
TORCH_CHECK(num_experts % n_group == 0,
"num_experts should be divisible by n_group");
TORCH_CHECK(n_group <= 32,
"n_group should be smaller than or equal to 32 for now");
TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
scoring_func == vllm::moe::SCORING_SIGMOID,
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
torch::Tensor group_scores = torch::empty(
{num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
// Always output float32 for topk_values (eliminates Python-side conversion)
torch::Tensor topk_values = torch::empty(
{num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor topk_indices = torch::empty(
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
switch (data_type) {
case torch::kFloat16:
// Handle Float16
vllm::moe::invokeNoAuxTc<half, int32_t>(
reinterpret_cast<half*>(scores.mutable_data_ptr()),
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
case torch::kFloat32:
// Handle Float32
vllm::moe::invokeNoAuxTc<float, int32_t>(
reinterpret_cast<float*>(scores.mutable_data_ptr()),
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
case torch::kBFloat16:
// Handle BFloat16
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
num_experts, n_group, topk_group, topk, renormalize,
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
break;
default:
// Handle other data types
throw std::invalid_argument(
"Invalid dtype, only supports float16, float32, and bfloat16");
break;
}
return {topk_values, topk_indices};
}
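A hedged usage sketch for this entry point; the shapes and routing parameters below are illustrative examples, not values taken from this commit:

#include <torch/all.h>

void grouped_topk_example() {
  // Route 4 tokens over 64 experts in 8 groups: keep the top 4 groups, then
  // the top 8 experts overall, with sigmoid scoring and renormalization.
  torch::Tensor scores = torch::randn(
      {4, 64}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
  torch::Tensor bias = torch::zeros(
      {64}, torch::dtype(torch::kBFloat16).device(torch::kCUDA));
  auto [topk_values, topk_indices] = grouped_topk(
      scores, /*n_group=*/8, /*topk_group=*/4, /*topk=*/8,
      /*renormalize=*/true, /*routed_scaling_factor=*/2.5, bias,
      /*scoring_func=*/1);  // 1 == SCORING_SIGMOID
}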

csrc/moe/marlin_moe_wna16/.gitignore (vendored, new file)

@@ -0,0 +1,2 @@
sm*_kernel_*.cu
kernel_selector.h


@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import itertools
import os
import subprocess
import sys
import jinja2
ARCHS = []
SUPPORT_FP8 = False
for arch in sys.argv[1].split(","):
arch = arch[: arch.index(".") + 2].replace(".", "")
arch = int(arch)
    # Only SM89 and SM120 fully support
    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
    # SM90 and SM100 accept this PTX, but it is emulated with FP16 MMA,
    # so it brings no speedup there.
if arch in [89, 120]:
SUPPORT_FP8 = True
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()
FILE_HEAD = (
FILE_HEAD_COMMENT
+ """
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
"""
)
TEMPLATE = (
"template __global__ void Marlin<"
"{{a_type_id}}, "
"{{b_type_id}}, "
"{{c_type_id}}, "
"{{s_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{m_block_size_8}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{is_zp_float}}>"
"( MARLIN_KERNEL_PARAMS );"
)
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
QUANT_CONFIGS = [
# AWQ-INT4
{
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4
{
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# AWQ-INT8
{
"b_type": "kU8B128",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 0, 2, 4, 8],
},
# FP8
{
"b_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [-1, 8],
},
# NVFP4
{
"b_type": "kFE2M1f",
"s_type": "kFE4M3fn",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [1],
},
# MXFP4
{
"a_type": ["kBFloat16"],
"b_type": "kFE2M1f",
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": THREAD_M_BLOCKS,
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with INT8 activation
{
"a_type": ["kS8"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# GPTQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4B8",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# AWQ-INT4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kU4",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [-1, 2, 4, 8],
},
# MXFP4 with FP8 activation
{
"a_type": ["kFE4M3fn"],
"b_type": "kFE2M1f",
"c_type": ["kBFloat16"],
"s_type": "kFE8M0fnu",
"thread_configs": THREAD_CONFIGS,
"thread_m_blocks": [1, 2, 3, 4],
"group_blocks": [2],
},
]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
filename = os.path.dirname(__file__) + "/kernel_selector.h"
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
result_dict = {}
for quant_config in QUANT_CONFIGS:
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
b_type = quant_config["b_type"]
all_group_blocks = quant_config["group_blocks"]
all_m_blocks = quant_config["thread_m_blocks"]
all_thread_configs = quant_config["thread_configs"]
for a_type, c_type in itertools.product(a_types, c_types):
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
continue
if "16" in a_type and "16" in c_type and a_type != c_type:
continue
s_type = quant_config.get("s_type", c_type)
if (a_type, b_type, c_type) not in result_dict:
result_dict[(a_type, b_type, c_type)] = []
for group_blocks, m_blocks, thread_configs in itertools.product(
all_group_blocks, all_m_blocks, all_thread_configs
):
thread_k, thread_n, threads = thread_configs
if threads == 256:
# for small batch (m_blocks == 1),
# we only need (128, 128, 256)
# for large batch (m_blocks > 1),
# we only need (64, 256, 256)
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
continue
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
continue
config = {
"threads": threads,
"s_type": s_type,
"thread_m_blocks": max(m_blocks, 1),
"thread_k_blocks": thread_k // 16,
"thread_n_blocks": thread_n // 16,
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
"stages": "pipe_stages",
"group_blocks": group_blocks,
"is_zp_float": "false",
}
result_dict[(a_type, b_type, c_type)].append(config)
kernel_selector_str = FILE_HEAD_COMMENT
for (a_type, b_type, c_type), config_list in result_dict.items():
all_template_str_list = []
for config in config_list:
s_type = config["s_type"]
template_str = jinja2.Template(TEMPLATE).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
all_template_str_list.append(template_str)
conditions = [
f"a_type == vllm::{a_type}",
f"b_type == vllm::{b_type}",
f"c_type == vllm::{c_type}",
f"s_type == vllm::{s_type}",
f"threads == {config['threads']}",
f"thread_m_blocks == {config['thread_m_blocks']}",
f"thread_n_blocks == {config['thread_n_blocks']}",
f"thread_k_blocks == {config['thread_k_blocks']}",
f"m_block_size_8 == {config['m_block_size_8']}",
f"group_blocks == {config['group_blocks']}",
f"is_zp_float == {config['is_zp_float']}",
]
conditions = " && ".join(conditions)
if kernel_selector_str == FILE_HEAD_COMMENT:
kernel_selector_str += f"if ({conditions})\n kernel = "
else:
kernel_selector_str += f"else if ({conditions})\n kernel = "
kernel_template2 = (
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
"{{is_zp_float}}>;"
)
kernel_selector_str += (
jinja2.Template(kernel_template2).render(
a_type_id=f"vllm::{a_type}.id()",
b_type_id=f"vllm::{b_type}.id()",
c_type_id=f"vllm::{c_type}.id()",
s_type_id=f"vllm::{s_type}.id()",
**config,
)
+ "\n"
)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
if a_type == "kFE4M3fn":
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
else:
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
filename = filename.lower()
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
kernel_selector_str += (
"else if (a_type == vllm::kFE4M3fn)\n"
" TORCH_CHECK(false, "
'"marlin kernel with fp8 activation is not built.");'
)
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
f.write(kernel_selector_str)
if __name__ == "__main__":
remove_old_kernels()
generate_new_kernels()
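For orientation, one instantiation line this script renders from TEMPLATE looks like the following; the particular type/config combination is a hypothetical example drawn from the GPTQ-INT4 entry with the (128, 128, 256) thread config:

// auto generated by generate_kernels.py
// clang-format off
template __global__ void Marlin<vllm::kFloat16.id(), vllm::kU4B8.id(),
    vllm::kFloat16.id(), vllm::kFloat16.id(), 256, 1, 8, 8, false,
    pipe_stages, -1, false>( MARLIN_KERNEL_PARAMS );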


@@ -0,0 +1,47 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, \
const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, \
const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
const vllm::ScalarTypeId b_type_id, // B ScalarType id
const vllm::ScalarTypeId c_type_id, // C ScalarType id
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}

(File diff suppressed because it is too large.)


@@ -0,0 +1,863 @@
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "kernel.h"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace MARLIN_NAMESPACE_NAME {
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS) {}
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
// For a given "a" of size [M, K], permutes the K columns according to the
// given "perm" indices.
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
int size_k, int top_k) {
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
int32_t block_sorted_ids[moe_block_size];
int block_num_valid_tokens = 0;
int64_t old_expert_id = 0;
int64_t expert_id = 0;
int row_stride = size_k * sizeof(half) / 16;
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
for (int i = 0; i < moe_block_size / 4; i++) {
tmp_block_sorted_ids[i] =
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
}
for (int i = 0; i < moe_block_size; i++) {
if (block_sorted_ids[i] >= size_m * top_k) {
block_num_valid_tokens = i;
break;
      }
}
};
auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;
int in_offset = (row / top_k) * row_stride;
int out_offset = row * row_stride;
half const* a_row_half =
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += default_threads;
}
if (rest) {
if (threadIdx.x < rest) {
auto cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
old_expert_id = expert_id;
int tmp_expert_id = expert_ids_ptr[index];
if (tmp_expert_id == -1) continue;
expert_id = tmp_expert_id;
perm_int_ptr += (expert_id - old_expert_id) * size_k;
read_moe_block_data(index);
for (int i = 0; i < block_num_valid_tokens; i++)
permute_row(block_sorted_ids[i]);
}
}
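In scalar terms, permute_cols_kernel gathers out[row][k] = a[row / top_k][perm[k]] for every valid sorted token row. A sequential, host-side sketch of those semantics (illustrative only):

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics: each output row reads source row `row / top_k`
// with its K columns shuffled by `perm`.
void permute_cols_reference(const std::vector<std::vector<float>>& a,
                            const std::vector<int>& perm,
                            const std::vector<int32_t>& rows, int top_k,
                            std::vector<std::vector<float>>& out) {
  for (int32_t row : rows) {
    const std::vector<float>& src = a[row / top_k];
    for (size_t k = 0; k < perm.size(); ++k) {
      out[row][k] = src[perm[k]];
    }
  }
}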
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256},
{64, 128, 128}};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256},
{64, 128, 128}};
typedef struct {
int blocks_per_sm;
thread_config_t tb_cfg;
} exec_config_t;
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = div_ceil(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
}
}
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float, bool is_a_8bit) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
  // shm size for block_sorted_ids / rd_block_sorted_ids / block_topk_weights;
  // each requires tb_m * 4 bytes (tb_m int32 or tb_m float32 values)
int sh_block_meta_size = tb_m * 16;
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_bias_size = tb_n * 2;
int tmp_size =
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
sh_zp_size = sh_s_size;
else if (num_bits == 4)
sh_zp_size = sh_s_size / 4;
else if (num_bits == 8)
sh_zp_size = sh_s_size / 2;
}
int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
return total_size;
}
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem, bool is_a_8bit) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Check that pipeline fits into cache
int cache_size =
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
return cache_size <= max_shared_mem;
}
MarlinFuncPtr get_marlin_kernel(
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
int threads, bool is_zp_float) {
int num_bits = b_type.size_bits();
auto kernel = MarlinDefault;
#include "kernel_selector.h"
return kernel;
}
exec_config_t determine_exec_config(
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
bool is_a_8bit) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1
? large_batch_thread_configs
: small_batch_thread_configs;
int thread_configs_size =
thread_m_blocks > 1
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
int count = 0;
constexpr int device_max_reg_size = 255 * 1024;
for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
is_a_8bit)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
is_a_8bit);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16);
}
auto kernel =
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
th_config.thread_n / 16, th_config.thread_k / 16,
m_block_size_8, has_act_order, has_zp, group_blocks,
th_config.num_threads, is_zp_float);
if (kernel == MarlinDefault) continue;
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size,
max_shared_mem / (cache_size + 1536));
if (thread_m_blocks == 1)
allow_count = max(min(allow_count, 4), 1);
else
allow_count = max(min(allow_count, 2), 1);
if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
allow_count =
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
}
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
    }
}
return exec_cfg;
}
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
void* perm, void* a_tmp, void* sorted_token_ids,
void* expert_ids, void* num_tokens_past_padded,
void* topk_weights, int moe_block_size, int num_experts,
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int blocks_per_sm,
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
bool is_a_8bit = a_type.size_bits() == 8;
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
int num_bits = b_type.size_bits();
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* bias_ptr = (const int4*)b_bias;
const float* a_s_ptr = (const float*)a_s;
const int4* b_s_ptr = (const int4*)b_s;
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
const int32_t* num_tokens_past_padded_ptr =
(const int32_t*)num_tokens_past_padded;
const float* topk_weights_ptr = (const float*)topk_weights;
int* locks = (int*)workspace;
if (has_act_order) {
// Permute A columns
auto kernel = permute_cols_kernel<8>;
    if (moe_block_size == 8) {
      // keep the default <8> instantiation chosen above
} else if (moe_block_size == 16)
kernel = permute_cols_kernel<16>;
else if (moe_block_size == 32)
kernel = permute_cols_kernel<32>;
else if (moe_block_size == 48)
kernel = permute_cols_kernel<48>;
else if (moe_block_size == 64)
kernel = permute_cols_kernel<64>;
else
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<sms, default_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
// clang-format on
A_ptr = a_tmp_ptr;
prob_m = prob_m * top_k;
top_k = 1;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) has_act_order = false;
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
dev);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
dev);
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
"marlin kernel only support Ampere or newer GPUs.");
if (a_type == vllm::kFE4M3fn) {
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
"FP8 only support Ada Lovelace or newer GPUs.");
TORCH_CHECK(
major_capability * 10 + minor_capability == 89 ||
major_capability * 10 + minor_capability == 120,
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
"Marlin W4A16 on other devices).");
}
// Set thread config
exec_config_t exec_cfg;
thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
if (blocks_per_sm == -1) blocks_per_sm = 1;
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
} else {
// Auto config
exec_cfg = determine_exec_config(
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
is_a_8bit);
thread_tfg = exec_cfg.tb_cfg;
}
int num_threads = thread_tfg.num_threads;
thread_k = thread_tfg.thread_k;
thread_n = thread_tfg.thread_n;
int blocks = sms * exec_cfg.blocks_per_sm;
if (exec_cfg.blocks_per_sm > 1)
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float,
max_shared_mem, is_a_8bit),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
int sh_cache_size =
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, is_a_8bit);
auto kernel = get_marlin_kernel(
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
num_threads, is_zp_float);
if (kernel == MarlinDefault) {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
", num_groups = ", num_groups, ", group_size = ", group_size,
", thread_m_blocks = ", thread_m_blocks,
", thread_n_blocks = ", thread_n_blocks,
", thread_k_blocks = ", thread_k_blocks,
", num_bits = ", num_bits);
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
// clang-format on
}
} // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
std::optional<torch::Tensor> const& a_scales_or_none,
std::optional<torch::Tensor> const& global_scale_or_none,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
bool is_zp_float, int64_t thread_k, int64_t thread_n,
int64_t blocks_per_sm) {
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
auto c_dtype = a.dtype();
if (a.scalar_type() == at::ScalarType::Half) {
a_type_id = vllm::kFloat16.id();
c_type_id = vllm::kFloat16.id();
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
a_type_id = vllm::kBFloat16.id();
c_type_id = vllm::kBFloat16.id();
} else {
c_dtype = b_scales.dtype();
if (b_scales.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
c_type_id = vllm::kBFloat16.id();
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
torch::Tensor c = c_or_none.value();
c_dtype = c.dtype();
if (c.scalar_type() == at::ScalarType::Half) {
c_type_id = vllm::kFloat16.id();
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
c_type_id = vllm::kBFloat16.id();
} else {
TORCH_CHECK(false, "unsupported c dtype");
}
}
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
a_type_id = vllm::kFE4M3fn.id();
} else if (a.scalar_type() == at::ScalarType::Char) {
a_type_id = vllm::kS8.id();
} else {
TORCH_CHECK(false, "unsupported `a` scalar_type");
}
}
s_type_id = c_type_id;
if (b_type_id == vllm::kFE2M1f.id()) {
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
s_type_id = vllm::kFE4M3fn.id();
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
s_type_id = vllm::kFE8M0fnu.id();
} else {
TORCH_CHECK(false,
"When b_type = float4_e2m1f, b_scale scalar type must be",
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
}
}
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
int pack_factor = 32 / b_type.size_bits();
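  // e.g. 4-bit weights give pack_factor = 32 / 4 = 8 values per int32.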
int num_experts = b_q_weight.size(0);
if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0,
"unsupported moe_block_size=", moe_block_size);
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
"unsupported moe_block_size=", moe_block_size);
}
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
", size_k = ", size_k);
// Verify B
TORCH_CHECK(
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_k = ", size_k,
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK(
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
"b_q_weight.size(2) = ", b_q_weight.size(2),
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
int actual_size_n =
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
torch::Tensor a_scales;
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (a_scales_or_none.has_value()) {
a_scales = a_scales_or_none.value();
TORCH_CHECK(a_type.size_bits() == 8,
"a_scales can only be used for 8bit activation.");
} else {
a_scales = torch::empty({0}, options_fp32);
TORCH_CHECK(a_type.size_bits() != 8,
"the a_scales parameter must be passed for 8bit activation.");
}
// sms: number of SMs to use for the kernel
int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
TORCH_CHECK(c.size(0) == size_m * top_k,
"Shape mismatch: c.size(0) = ", c.size(0),
", size_m * topk = ", size_m * top_k);
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
", size_n = ", size_n);
} else {
c = torch::empty({size_m * top_k, size_n}, options);
}
// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
(long)size_n * sorted_token_ids.size(0),
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
if (moe_block_size == 8) max_c_tmp_size *= 2;
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
} else {
c_tmp = torch::empty({0}, options_fp32);
}
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value();
perm = perm_or_none.value();
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Verify g_idx and perm
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
" and perm.size(-1) = ", perm.size(-1),
", where size_k = ", size_k);
} else {
g_idx = torch::empty({0}, options);
perm = torch::empty({0}, options);
a_tmp = torch::empty({0}, options);
}
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
if (has_act_order) {
a_tmp = torch::empty({size_m * top_k, size_k}, options);
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
a_tmp = torch::empty({0}, options);
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
torch::Tensor global_scale;
if (global_scale_or_none.has_value()) {
global_scale = global_scale_or_none.value();
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
"global_scale can only be used for nvfp4 format.");
} else {
global_scale = torch::empty({0}, options);
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
"the global_scale parameter must be passed for nvfp4 format.");
}
bool has_bias = b_bias_or_none.has_value();
torch::Tensor b_bias;
if (has_bias) {
b_bias = b_bias_or_none.value();
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
    TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(1) != size_n");
TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
} else {
b_bias = torch::empty({0}, options);
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
} else {
b_zeros = torch::empty({0}, options);
}
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(
b_type == vllm::kU4 || b_type == vllm::kU8,
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
} else {
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
"b_type must be uint4b8, uint8b128, int4, int8, "
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
b_type.str());
}
if (has_zp && is_zp_float) {
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
"Computation type must be float16 (half) when using float zero "
"points.");
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
if (is_zp_float) {
TORCH_CHECK(b_zeros.size(2) == size_n,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n = ", size_n);
TORCH_CHECK(num_groups == b_zeros.size(1),
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
} else {
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
}
// Verify workspace size
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
MARLIN_NAMESPACE_NAME::min_thread_n);
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
int min_workspace_size = min(
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
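  // The workspace backs the per-tile lock counters used by the global
  // reduce: one slot per (m-block, n-tile) pair, capped at sms * 4, the
  // maximum number of concurrently resident threadblocks.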
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
"scalar type of a_scales must be float");
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
"scalar type of global_scale must be the same with c");
if (a_type.size_bits() == 16) {
TORCH_CHECK(
a.scalar_type() == c.scalar_type(),
"scalar type of a must be the same with c for 16 bit activation");
}
MARLIN_NAMESPACE_NAME::marlin_mm(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
is_zp_float);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}

@@ -0,0 +1,759 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
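// Illustrative sanity check of the rounding-up behavior: CEILDIV(10, 4) == 3
// and CEILDIV(8, 4) == 2.
static_assert(CEILDIV(10, 4) == 3 && CEILDIV(8, 4) == 2,
              "CEILDIV performs ceiling division");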
namespace vllm {
namespace moe {
namespace batched_moe_align_block_size {
// Note: num_threads must be 1024 for the BlockScan reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;
__global__ void batched_moe_align_block_size_kernel(
int32_t const num_batches, int32_t const max_tokens_per_batch,
int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
int32_t* __restrict__ num_tokens_post_pad) {
// TODO(varun): This is a naive implementation. Could be optimized.
size_t const batch_id = threadIdx.x;
size_t const stride = blockDim.x * gridDim.x;
int32_t const num_blocks_per_batch =
CEILDIV(max_tokens_per_batch, block_size);
int32_t const sorted_ids_size =
num_blocks_per_batch * num_batches * block_size;
int32_t const block_ids_size = sorted_ids_size / block_size;
int32_t const SENTINEL =
num_batches * max_tokens_per_batch; // To denote invalid entries.
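  // sorted_ids holds num_batches groups of num_blocks_per_batch * block_size
  // slots; slots that receive no token below keep the SENTINEL value.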
  // Initialize sorted_ids with SENTINEL
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
sorted_ids[i] = SENTINEL;
}
  // Initialize block_ids with -1
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
block_ids[i] = -1;
}
int32_t b_num_tokens = 0;
if (batch_id < num_batches) {
b_num_tokens = batch_num_tokens[batch_id];
}
int32_t const ceil_b_num_tokens =
CEILDIV(b_num_tokens, block_size) * block_size;
  // Compute exclusive prefix sum over the padded token counts per batch
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
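  // e.g. padded per-batch counts {4, 8, 4} with block_size = 4 give
  // cumsum_val = 0, 4, 12 for threads 0..2; the last batch then publishes
  // 12 + 4 = 16 tokens post pad.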
__syncthreads();
bool const is_last_batch = batch_id == (num_batches - 1);
if (is_last_batch) {
*num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
}
if (batch_id < num_batches) {
int32_t const batch_offset = batch_id * max_tokens_per_batch;
for (size_t i = 0; i < b_num_tokens; ++i) {
sorted_ids[cumsum_val + i] = batch_offset + i;
}
int32_t const block_start = cumsum_val / block_size;
int32_t const num_blocks = ceil_b_num_tokens / block_size;
for (size_t i = 0; i < num_blocks; ++i) {
block_ids[block_start + i] = batch_id;
}
}
}
} // namespace batched_moe_align_block_size
template <typename scalar_t>
__device__ void _moe_align_block_size(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id,
int32_t topk_num, int32_t* token_mask, bool has_expert_map) {
extern __shared__ int32_t shared_counts[];
// Compute input buffer offsets. Typically these will all be 0, except when
// using Multi LoRA.
int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
int expert_ids_offset = max_num_m_blocks * model_offset;
int cumsum_offset = (num_experts + 1) * model_offset;
// Use separate threadblocks to fill sorted_token_ids.
// This is safe since the current kernel does not use sorted_token_ids.
if (blockIdx.x % 2) {
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded;
it += blockDim.x) {
sorted_token_ids[sorted_token_ids_offset + it] = numel;
}
return;
}
const int warp_id = threadIdx.x / WARP_SIZE;
const int my_expert_start = warp_id * experts_per_warp;
for (int i = 0; i < experts_per_warp; ++i) {
if (my_expert_start + i < padded_num_experts) {
shared_counts[warp_id * experts_per_warp + i] = 0;
}
}
__syncthreads();
const size_t tid = threadIdx.x;
const size_t stride = blockDim.x;
for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid experts
if (expert_id == -1) continue;
}
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset],
mask);
}
__syncthreads();
// Compute prefix sum over token counts per expert
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
int expert_count = 0;
int expert_id = threadIdx.x;
if (expert_id < num_experts) {
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
expert_count = CEILDIV(expert_count, block_size) * block_size;
}
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
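  // e.g. with block_size = 16 and raw counts {3, 17, 5}, the padded counts
  // are {16, 32, 16}, so cumsum_val = 0, 16, 48 for experts 0..2.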
if (expert_id <= num_experts) {
cumsum[cumsum_offset + expert_id] = cumsum_val;
}
if (expert_id == num_experts) {
total_tokens_post_pad[model_offset] = cumsum_val;
}
__syncthreads();
if (threadIdx.x < num_experts) {
for (int i = cumsum[cumsum_offset + threadIdx.x];
i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) {
expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
}
}
  // Fill remaining expert_ids with inactive_expert_id
const size_t fill_start_idx =
cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
expert_ids[expert_ids_offset + i] = inactive_expert_id;
}
}
template <typename scalar_t, int32_t fill_threads>
__device__ void _moe_align_block_size_small_batch_expert(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks,
int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num,
int32_t* token_mask, bool has_expert_map) {
// Compute input buffer offsets. Typically these will all be 0, except when
// using Multi LoRA.
int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
int expert_ids_offset = max_num_m_blocks * model_offset;
// Use an additional group of threads to fill sorted_token_ids.
// Since the current kernel will use sorted_token_ids afterward,
// we fill sorted_token_ids within the same threadblock to make
// synchronization easier.
if (threadIdx.x < fill_threads) {
// Initialize sorted_token_ids with numel
for (size_t it = threadIdx.x; it < max_num_tokens_padded;
it += fill_threads) {
sorted_token_ids[sorted_token_ids_offset + it] = numel;
}
    // Three __syncthreads() matching the barriers executed by the other
    // (non-fill) threads below, so every barrier is reached by all threads
__syncthreads();
__syncthreads();
__syncthreads();
return;
}
const size_t tid = threadIdx.x - fill_threads;
const size_t stride = blockDim.x - fill_threads;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[(tid + 1) * num_experts + i] = 0;
}
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid expert
if (expert_id == -1) continue;
}
int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
tokens_cnts[(tid + 1) * num_experts + expert_id] += mask;
}
__syncthreads();
if (tid < num_experts) {
tokens_cnts[tid] = 0;
for (int i = 1; i <= stride; ++i) {
tokens_cnts[i * num_experts + tid] +=
tokens_cnts[(i - 1) * num_experts + tid];
}
}
__syncthreads();
if (tid == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] =
cumsum[i - 1] +
CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) *
block_size;
}
total_tokens_post_pad[model_offset] =
static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
if (tid < num_experts) {
for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
expert_ids[expert_ids_offset + i / block_size] = tid;
}
}
  // Fill remaining expert_ids with inactive_expert_id
const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
expert_ids[expert_ids_offset + i] = inactive_expert_id;
}
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid expert
if (expert_id == -1) continue;
}
int32_t rank_post_pad =
tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];
if (token_mask == nullptr || token_mask[i / topk_num]) {
sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i;
++tokens_cnts[tid * num_experts + expert_id];
}
}
}
template <typename scalar_t>
__device__ void _count_and_sort_expert_tokens(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask,
int32_t model_offset, int32_t topk_num, bool has_expert_map) {
const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.y;
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
if (has_expert_map) {
expert_id = expert_map[expert_id];
// filter invalid experts
if (expert_id == -1) continue;
}
if (token_mask == nullptr || token_mask[i / topk_num]) {
int32_t rank_post_pad = atomicAdd(
&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1);
sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] =
i;
}
}
}
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts,
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
int32_t topk_num, bool has_expert_map) {
_moe_align_block_size(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
0, 0, topk_num, nullptr, has_expert_map);
}
template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) {
_count_and_sort_expert_tokens(
topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts,
max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map);
}
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., topk, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
scalar_t x = 0.0;
#pragma unroll
for (int k = 0; k < TOPK; ++k) {
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
}
out[token_idx * d + idx] = x;
}
}
template <typename scalar_t, int32_t fill_threads>
__global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
size_t numel, int32_t max_num_tokens_padded, int32_t topk_num,
bool has_expert_map) {
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, block_size, numel, max_num_tokens_padded,
CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr,
has_expert_map);
}
template <typename scalar_t>
__global__ void moe_lora_align_block_size_kernel(
scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping,
int64_t block_size, int32_t* __restrict__ expert_map, int num_experts,
int max_loras, size_t numel, int max_num_tokens_padded,
int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids, int32_t topk_num,
int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
int32_t* __restrict__ cumsum, int32_t experts_per_warp,
int32_t padded_num_experts, int32_t* lora_ids,
int32_t* __restrict__ token_mask, bool has_expert_map) {
int lora_idx = blockIdx.x / 2;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
// Populate the token_mask based on the token-LoRA mapping
int num_tokens = numel / topk_num;
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
for (int i = 0; i < num_tokens; i++) {
token_mask[(lora_id * num_tokens) + i] =
(int)token_lora_mapping[i] == lora_id;
}
}
__syncthreads();
_moe_align_block_size(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num,
&token_mask[(lora_id * num_tokens)], has_expert_map);
}
template <typename scalar_t>
__global__ void lora_count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask,
int32_t* lora_ids, bool has_expert_map) {
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1) {
return;
}
int num_tokens = numel / topk_num;
_count_and_sort_expert_tokens(
topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts,
max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id,
topk_num, has_expert_map);
}
template <typename scalar_t, int32_t fill_threads>
__global__ void moe_lora_align_block_size_small_batch_expert_kernel(
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
int64_t block_size, int32_t* __restrict__ expert_map, int num_experts,
int max_loras, size_t numel, int max_num_tokens_padded,
int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids, int topk_num,
int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids,
int32_t* token_mask, bool has_expert_map) {
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
int num_tokens = numel / topk_num;
if (threadIdx.x == 0) {
total_tokens_post_pad[lora_id] = 0;
for (int i = 0; i < num_tokens; i++) {
token_mask[(lora_id * num_tokens) + i] =
(int)token_lora_mapping[i] == lora_id;
}
}
__syncthreads();
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks,
-1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)],
has_expert_map);
}
} // namespace moe
} // namespace vllm
// taken from
// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
std::optional<torch::Tensor> maybe_expert_map) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t padded_num_experts =
((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
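  // e.g. num_experts = 60 with WARP_SIZE = 32 pads up to 64.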
int experts_per_warp = WARP_SIZE;
int threads = 1024;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK(padded_num_experts < 1024,
"padded_num_experts must be less than 1024");
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
bool has_expert_map = maybe_expert_map.has_value();
torch::Tensor expert_map;
if (has_expert_map) {
expert_map = maybe_expert_map.value();
} else {
expert_map = torch::empty({0}, options_int);
}
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `cumsum` tensors
bool small_batch_expert_mode =
(topk_ids.numel() < 1024) && (num_experts <= 64);
if (small_batch_expert_mode) {
const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem_size =
((threads + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
// threadIdx.x >= fill_threads: counting experts and aligning
// threadIdx.x < fill_threads: filling sorted_token_ids
constexpr int32_t fill_threads = 256;
auto small_batch_expert_kernel =
vllm::moe::moe_align_block_size_small_batch_expert_kernel<
scalar_t, fill_threads>;
small_batch_expert_kernel<<<1, fill_threads + threads,
shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
expert_map.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1),
has_expert_map);
} else {
torch::Tensor cumsum_buffer =
torch::empty({num_experts + 1}, options_int);
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
size_t shared_mem_size =
num_warps * experts_per_warp * sizeof(int32_t);
// launch two threadblocks
// blockIdx.x == 0: counting experts and aligning
// blockIdx.x == 1: filling sorted_token_ids
align_kernel<<<2, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
expert_map.data_ptr<int32_t>(), num_experts, padded_num_experts,
experts_per_warp, block_size, topk_ids.numel(),
cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0),
topk_ids.size(1), has_expert_map);
const int block_threads = std::min(256, (int)threads);
const int num_blocks =
(topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
dim3 gridDims(1, actual_blocks);
auto sort_kernel =
vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<gridDims, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), expert_map.data_ptr<int32_t>(),
topk_ids.numel(), num_experts, sorted_token_ids.size(0),
topk_ids.size(1), has_expert_map);
}
});
}
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& batch_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor batch_ids,
torch::Tensor num_tokens_post_pad) {
namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t const B = batch_num_tokens.size(0);
int32_t const num_blocks_per_batch =
round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
int32_t const num_blocks = num_blocks_per_batch * B;
int64_t const sorted_ids_size = num_blocks * block_size;
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
TORCH_CHECK(B <= batched_kernel::num_threads);
batched_kernel::batched_moe_align_block_size_kernel<<<
batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>());
}
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
const int hidden_size = input.size(-1);
const auto num_tokens = output.numel() / hidden_size;
const int topk = input.size(1);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (topk) {
case 2:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 3:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
case 4:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default:
at::sum_out(output, input, 1);
break;
}
}
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map) {
const int topk_num = topk_ids.size(1);
  TORCH_CHECK(block_size > 0, "block_size must be greater than 0.");
int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t padded_num_experts =
((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK(padded_num_experts < 1024,
"padded_num_experts must be less than 1024");
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
torch::Tensor token_mask =
torch::empty({max_loras * topk_ids.size(0)}, options_int);
bool has_expert_map = maybe_expert_map.has_value();
torch::Tensor expert_map;
if (has_expert_map) {
expert_map = maybe_expert_map.value();
} else {
expert_map = torch::empty({0}, options_int);
}
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
bool small_batch_expert_mode =
(topk_ids.numel() < 1024) && (num_experts <= 64);
if (small_batch_expert_mode) {
const int32_t num_thread = max((int32_t)num_experts, 128);
const int32_t shared_mem =
(num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);
if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false, "Shared memory usage exceeds device limit.");
}
// threadIdx.x >= fill_threads: counting experts and aligning
// threadIdx.x < fill_threads: filling sorted_token_ids
constexpr int32_t fill_threads = 256;
dim3 blockDim(num_thread + fill_threads);
auto kernel =
vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
scalar_t, fill_threads>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size,
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks,
sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>(),
token_mask.data_ptr<int32_t>(), has_expert_map);
} else {
int num_thread = 1024;
dim3 blockDim(num_thread);
size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE);
size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t);
// cumsum buffer
torch::Tensor cumsum =
torch::zeros({max_loras * (num_experts + 1)}, options_int);
auto align_kernel =
vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
// launch two threadblocks for each lora
// blockIdx.x % 2 == 0: counting experts and aligning
// blockIdx.x % 2 == 1: filling sorted_token_ids
align_kernel<<<max_loras * 2, blockDim, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<int32_t>(), block_size,
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks,
sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
WARP_SIZE, padded_num_experts, lora_ids.data_ptr<int32_t>(),
token_mask.data_ptr<int32_t>(), has_expert_map);
const int block_threads = std::min(256, (int)num_thread);
const int num_blocks =
(topk_ids.numel() + block_threads - 1) / block_threads;
const int max_blocks = 65535;
const int actual_blocks = std::min(num_blocks, max_blocks);
dim3 gridDims(max_loras, actual_blocks);
auto sort_kernel =
vllm::moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;
sort_kernel<<<gridDims, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
expert_map.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
max_num_tokens_padded, topk_num, token_mask.data_ptr<int32_t>(),
lora_ids.data_ptr<int32_t>(), has_expert_map);
}
});
}

csrc/moe/moe_ops.h
@@ -0,0 +1,52 @@
#pragma once
#include <torch/all.h>
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output, bool renormalize);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad,
std::optional<torch::Tensor> maybe_expert_map);
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& expert_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
std::optional<torch::Tensor> b_qzeros,
std::optional<torch::Tensor> topk_weights,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit);
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
int64_t topk, bool renormalize, double routed_scaling_factor,
torch::Tensor const& bias, int64_t scoring_func);
#endif
bool moe_permute_unpermute_supported();
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);

@@ -0,0 +1,222 @@
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input, // [permuted_size, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size]
torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
"token_expert_indices must be int32");
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
"inv_permuted_idx must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
"token_expert_indices shape must be same as inv_permuted_idx");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto align_block_size_value =
align_block_size.has_value() ? align_block_size.value() : -1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
  // pre-process kernel for expert parallelism:
  // map local expert ids [n, ..., n + n_local_expert - 1] to
  // [0, n_local_expert - 1], and offset non-local expert ids by n_expert so
  // that local experts sort first.
  // For example, with 4 experts and ep_size=2, ep_rank=1 owns global expert
  // ids [2, 3] with expert_map = [-1, -1, 0, 1]. preprocess_topk_id maps
  // global expert ids [2, 3] to local expert ids [0, 1], and maps global
  // expert ids [0, 1] (not on ep_rank=1) to [4, 5] by adding n_expert. This
  // mapping gives local experts high priority in the subsequent topk_ids sort
  // and in the scan that produces each ep rank's expert_first_token_offset
  // for the next group gemm.
if (expert_map.has_value()) {
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream);
}
  // sort topk expert ids and scan them to obtain expert_first_token_offset
sortAndScanExpert(
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream);
});
// get m_indices and update expert_first_token_offset with align block
// this is only required for DeepGemm and not required for CUTLASS group gemm
if (align_block_size.has_value()) {
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, // [n_token, topk]
const torch::Tensor& inv_permuted_idx, // [n_token, topk]
const std::optional<torch::Tensor>&
expert_first_token_offset, // [n_local_expert+1]
int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden]
) {
TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"permuted_hidden_states dtype must be same as hidden_states");
auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream();
int64_t const* valid_ptr = nullptr;
if (expert_first_token_offset.has_value()) {
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
valid_ptr =
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
}
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
stream);
});
}
template <typename T>
__global__ void shuffleInputRowsKernel(const T* input,
const int32_t* dst2src_map, T* output,
int64_t num_src_rows,
int64_t num_dst_rows, int64_t num_cols) {
int64_t dest_row_idx = blockIdx.x;
int64_t const source_row_idx = dst2src_map[dest_row_idx];
if (blockIdx.x < num_dst_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor) {
TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
"Input and output tensors must have the same data type");
auto stream = at::cuda::getCurrentCUDAStream().stream();
int64_t const blocks = output_tensor.size(0);
int64_t const threads = 256;
int64_t const num_dest_rows = output_tensor.size(0);
int64_t const num_src_rows = input_tensor.size(0);
int64_t const num_cols = input_tensor.size(1);
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
"num_cols must be divisible by 128 / "
"sizeof(input_tensor.scalar_type()) / 8");
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
dst2src_map.data_ptr<int32_t>(),
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
num_dest_rows, num_cols);
});
}
#else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indices,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}
void moe_unpermute(
const torch::Tensor& permuted_hidden_states,
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
torch::Tensor& hidden_states) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}
#endif
bool moe_permute_unpermute_supported() {
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
return true;
#else
return false;
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}

csrc/moe/moe_wna16.cu
@@ -0,0 +1,342 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "moe_wna16_utils.h"
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
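// DIVIDE is ceiling division: e.g. DIVIDE(4097, 512) == 9, so a partial tail
// block still gets its own grid slot.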
template <typename scalar_t, int bit, int GROUPS>
__global__ void moe_wna16_gemm_kernel(
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
const uint32_t* __restrict__ qzeros,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_token_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ num_tokens_post_pad,
uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
bool mul_topk_weight) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
return;
} else {
#endif
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;
const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
const int32_t expert_id = expert_ids[blockIdx.x];
int32_t num_valid_tokens = 0;
extern __shared__ uint16_t block_input_tmp[];
scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
// load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
for (int m = 0; m < BLOCK_SIZE_M; m++) {
const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
const int32_t token_index = sorted_token_ids[offset_m];
if (token_index / top_k >= size_m) break;
num_valid_tokens = m + 1;
if (expert_id != -1) {
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
for (int i = 0; i < k_per_thread; i++) {
int k = BLOCK_SIZE_N * i + threadIdx.x;
if (k >= BLOCK_SIZE_K) break;
if (offset_k + k >= size_k) break;
// load input to shared memory
        // use a special layout to match the layout of the dequantized weights
int origin_k;
if constexpr (bit == 4) {
// [0, 4, 1, 5, 2, 6, 3, 7]
int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
} else {
// [0, 2, 1, 3]
int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
}
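        // e.g. for bit == 4, threads 0..7 read origin_k offsets
        // {0, 4, 1, 5, 2, 6, 3, 7} within their 8-wide chunk, matching the
        // pair order that dequant() emits for packed 4-bit values.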
origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
}
}
}
if (expert_id == -1) return;
__syncthreads();
if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
float res[64]; // assume BLOCK_SIZE_M <= 64
scalar_t2 res2;
scalar_t2 scale_f2;
scalar_t2 qzero_f2;
  // note that (size_n * size_k * expert_id) may be greater than 2 ** 31
constexpr int8_t pack_factor = 32 / bit;
const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
const scalar_t* expert_scales = scales + expert_offset / group_size;
const uint32_t* expert_qzeros =
qzeros + expert_offset / group_size / pack_factor;
  // load 4 int32 at a time: 4 x int32 = 128 bits = 1 float4;
  // the packed weights themselves are loaded inside the k loop below
uint32_t expert_qweight_tmp[4];
float4* expert_qweight_tmp_float4 =
reinterpret_cast<float4*>(expert_qweight_tmp);
  // load all required scales at once
scalar_t expert_scales_groups[GROUPS];
int scales_offset_tmp =
(offset_n * size_k + offset_k) / group_size / GROUPS;
if constexpr (GROUPS == 1) {
*expert_scales_groups = expert_scales[scales_offset_tmp];
} else if constexpr (GROUPS == 2) {
float* expert_scales_groups_tmp =
reinterpret_cast<float*>(expert_scales_groups);
*expert_scales_groups_tmp =
reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
} else if constexpr (GROUPS == 4) {
float2* expert_scales_groups_tmp =
reinterpret_cast<float2*>(expert_scales_groups);
*expert_scales_groups_tmp =
reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
} else if constexpr (GROUPS == 8) {
float4* expert_scales_groups_tmp =
reinterpret_cast<float4*>(expert_scales_groups);
*expert_scales_groups_tmp =
reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
}
  // load all required qzeros at once
uint8_t expert_qzeros_groups[GROUPS];
if (!has_zp) {
if constexpr (bit == 4) {
qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
} else {
qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
}
} else {
int qzeros_offset_tmp =
(offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
offset_k / group_size / GROUPS;
if constexpr (GROUPS == 1) {
uint8_t* expert_qzeros_groups_tmp =
reinterpret_cast<uint8_t*>(expert_qzeros_groups);
*expert_qzeros_groups_tmp =
reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
} else if constexpr (GROUPS == 2) {
uint16_t* expert_qzeros_groups_tmp =
reinterpret_cast<uint16_t*>(expert_qzeros_groups);
*expert_qzeros_groups_tmp =
reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
} else if constexpr (GROUPS == 4) {
uint32_t* expert_qzeros_groups_tmp =
reinterpret_cast<uint32_t*>(expert_qzeros_groups);
*expert_qzeros_groups_tmp =
reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
} else if constexpr (GROUPS == 8) {
uint64_t* expert_qzeros_groups_tmp =
reinterpret_cast<uint64_t*>(expert_qzeros_groups);
*expert_qzeros_groups_tmp =
reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
}
}
for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
int k = offset_k + tmp_k * pack_factor;
if (k >= size_k) break;
const int32_t weight_offset = offset_n * size_k + k;
if (tmp_k % 4 == 0) {
*expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
expert_qweight)[weight_offset / pack_factor / 4];
}
if (tmp_k % (group_size / pack_factor) == 0) {
scalar_t scale_f =
expert_scales_groups[tmp_k / (group_size / pack_factor)];
scale_f2 = Dtype::num2num2(scale_f);
if (has_zp) {
uint8_t qzero =
expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
if constexpr (bit == 4) {
qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
}
qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
}
}
scalar_t2 weight_half2[16 / bit];
dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);
for (int m = 0; m < num_valid_tokens; m++) {
res2 = {};
#pragma unroll
for (int i = 0; i < 16 / bit; i++) {
int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
block_input_half2[offset_input], res2);
}
if (tmp_k == 0) {
res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
} else {
res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
}
}
}
for (int m = 0; m < num_valid_tokens; ++m) {
const int32_t token_index =
sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
if (mul_topk_weight) {
res[m] *= topk_weights[token_index];
}
atomicAdd(&output[token_index * size_n + offset_n],
Dtype::float2num(res[m]));
}
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
}
#endif
}
template <typename scalar_t>
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
const uint32_t* b_qweight, const scalar_t* b_scales,
const uint32_t* b_qzeros, const float* topk_weights,
const int32_t* sorted_token_ids,
const int32_t* expert_ids,
const int32_t* num_tokens_post_pad, int num_experts,
int group_size, int num_token_blocks, int top_k,
int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
bool has_zp, bool mul_topk_weight) {
dim3 blockDim, gridDim;
blockDim.x = BLOCK_SIZE_N;
blockDim.y = 1;
blockDim.z = 1;
gridDim.x = num_token_blocks;
gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);
auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
if (bit == 4) {
if (BLOCK_SIZE_K / group_size == 2) {
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
} else if (BLOCK_SIZE_K / group_size == 4) {
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
} else if (BLOCK_SIZE_K / group_size == 8) {
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
}
} else {
if (BLOCK_SIZE_K / group_size == 1) {
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
} else if (BLOCK_SIZE_K / group_size == 2) {
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
} else if (BLOCK_SIZE_K / group_size == 4) {
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
} else if (BLOCK_SIZE_K / group_size == 8) {
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
}
}
const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;  // bytes (2-byte scalar_t elements)
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
input, output, b_qweight, b_scales, b_qzeros, topk_weights,
sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
BLOCK_SIZE_K, has_zp, mul_topk_weight);
}
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
std::optional<torch::Tensor> b_qzeros,
std::optional<torch::Tensor> topk_weights,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
output.zero_();
const int num_experts = b_qweight.size(0);
const int size_m = input.size(0);
const int size_n = b_qweight.size(1);
const int size_k = input.size(1);
const int group_size = size_k / b_scales.size(2);
int64_t EM = sorted_token_ids.size(0);
if (size_m <= BLOCK_SIZE_M) {
EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
}
const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
const uint32_t* b_qzeros_ptr = nullptr;
if (b_qzeros.has_value())
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
const float* topk_weights_ptr = nullptr;
if (topk_weights.has_value())
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
int groups_per_block_row = BLOCK_SIZE_K / group_size;
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
"size_k must be divisible by BLOCK_SIZE_K");
TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
"BLOCK_SIZE_K must be divisible by group_size");
TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must be less than or equal to 64");
TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
groups_per_block_row == 4 || groups_per_block_row == 8,
"BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");
if (input.scalar_type() == at::ScalarType::Half) {
run_moe_wna16_gemm<half>(
(const half*)input.data_ptr<at::Half>(),
(half*)output.data_ptr<at::Half>(),
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
(const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
b_qzeros.has_value(), topk_weights.has_value());
} else if (input.scalar_type() == at::ScalarType::BFloat16) {
run_moe_wna16_gemm<nv_bfloat16>(
(const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
(nv_bfloat16*)output.data_ptr<at::BFloat16>(),
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
(const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
b_qzeros.has_value(), topk_weights.has_value());
} else {
TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
}
return output;
}
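// Illustrative host-side call (a sketch, not part of this file; the tensor
// shapes and block-size values below are assumptions consistent with the
// checks above, e.g. group_size = size_k / b_scales.size(2) and
// BLOCK_SIZE_K / group_size in {1, 2, 4, 8}):
//
//   // input:     [size_m, size_k] (half or bfloat16)
//   // output:    [size_m * top_k, size_n]; zeroed by the op itself
//   // b_qweight: [num_experts, size_n, ...] uint8 storage viewed as uint32
//   // b_scales:  [num_experts, size_n, size_k / group_size]
//   auto out = moe_wna16_gemm(input, output, b_qweight, b_scales,
//                             /*b_qzeros=*/std::nullopt,
//                             /*topk_weights=*/std::nullopt,
//                             sorted_token_ids, expert_ids,
//                             num_tokens_post_pad, /*top_k=*/2,
//                             /*BLOCK_SIZE_M=*/16, /*BLOCK_SIZE_N=*/64,
//                             /*BLOCK_SIZE_K=*/64, /*bit=*/4);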

200
csrc/moe/moe_wna16_utils.h Normal file

@@ -0,0 +1,200 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
template <typename scalar_t>
class ScalarType {};
template <>
class ScalarType<half> {
public:
using scalar_t = half;
using scalar_t2 = half2;
static __device__ float inline num2float(const half x) {
return __half2float(x);
}
static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}
static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
static __host__ __device__ half inline int2num(const float x) {
return __int2half_rn(x);
}
static __host__ __device__ float2 inline num22float2(const half2 x) {
return __half22float2(x);
}
static __host__ __device__ half2 inline float22num2(const float2 x) {
return __float22half2_rn(x);
}
};
template <>
class ScalarType<nv_bfloat16> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
return __int2bfloat16_rn(x);
}
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
return __bfloat1622float2(x);
}
static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
return __float22bfloat162_rn(x);
}
#endif
};
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
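// Reference semantics (a host-side sketch for clarity; not used by the
// kernels): LOP3 evaluates an arbitrary three-input boolean function whose
// 8-entry truth table is the immediate `lut`. The value used below,
// (0xf0 & 0xcc) | 0xaa = 0xea, encodes f(a, b, c) = (a & b) | c.
//
//   constexpr int lop3_reference(int a, int b, int c, int lut) {
//     int res = 0;
//     for (int i = 0; i < 32; ++i) {
//       int idx = (((a >> i) & 1) << 2) | (((b >> i) & 1) << 1) |
//                 ((c >> i) & 1);
//       res |= ((lut >> idx) & 1) << i;
//     }
//     return res;
//   }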
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t2, int bit>
__device__ inline void dequant(int q, scalar_t2* res) {}
template <>
__device__ inline void dequant<half2, 4>(int q, half2* res) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
q >>= 8;
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
*reinterpret_cast<const half2*>(&SUB));
res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
*reinterpret_cast<const half2*>(&SUB));
res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
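// Why the constants above work (a worked sketch of the fp16 magic-number
// trick): OR-ing a 4-bit nibble n into the mantissa of a half whose bit
// pattern is 0x6400 (= 1024.0) yields:
//   low nibble  (mantissa bits 0-3): half = 1024 + n
//     -> subtract SUB (0x6400 = 1024.0)  => n
//   high nibble (mantissa bits 4-7): half = 1024 + 16 * n
//     -> fma with MUL (0x2c00 = 1/16) and ADD (0xd400 = -64):
//        (1024 + 16n) / 16 - 64 = n
// The per-group zero point is subtracted later in the kernel (qzero_f2), so
// no zero offset is baked into these constants.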
template <>
__device__ inline void dequant<half2, 8>(int q, half2* res) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
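// Why this works (worked sketch): __byte_perm places each source byte b into
// the low byte of fp32_base = 0x4B000000, producing the fp32 bit pattern of
// 8388608.0f + b (0x4B000000 encodes 2^23 = 8388608, and integers in
// [0, 2^23) live directly in the mantissa at that exponent). Subtracting
// 8388608.f leaves the exact integer b, and the final __byte_perm(..., 0x7632)
// takes the high 16 bits of each float, i.e. the bf16 truncation.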
#endif


@@ -0,0 +1,59 @@
#pragma once
#include <cuda_fp8.h>
#define MOE_SWITCH(TYPE, ...) \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK(false, "[moe permute] data type dispatch failed!") \
}
#define MOE_DISPATCH_CASE(enum_type, ...) \
case enum_type: { \
using scalar_t = ScalarType2CudaType<enum_type>::type; \
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
template <at::ScalarType type>
struct ScalarType2CudaType;
template <>
struct ScalarType2CudaType<at::ScalarType::Float> {
using type = float;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Half> {
using type = half;
};
template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16;
};
// uint8 for packed fp4
template <>
struct ScalarType2CudaType<at::ScalarType::Byte> {
using type = uint8_t;
};
// #if __CUDA_ARCH__ >= 890
// fp8
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
using type = __nv_fp8_e5m2;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
using type = __nv_fp8_e4m3;
};
// #endif
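// Usage sketch (assumption: `launch_my_kernel` is a hypothetical templated
// launcher; the macro binds `scalar_t` to the CUDA type matching the
// tensor's dtype and invokes the lambda):
//
//   MOE_DISPATCH(input.scalar_type(), [&] {
//     launch_my_kernel<scalar_t>(
//         reinterpret_cast<const scalar_t*>(input.data_ptr()), stream);
//   });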


@@ -0,0 +1,231 @@
#include "moe_permute_unpermute_kernel.h"
// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
int CubKeyValueSorter::expertsToBits(int num_experts) {
// The max value we need to represent is V = num_experts + (num_experts - 1)
// = 2 * num_experts - 1, so the number of bits required is floor(log2(V)) + 1.
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
}
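// Example: for num_experts = 64 the largest key is 2 * 64 - 1 = 127, so
// expertsToBits returns floor(log2(127)) + 1 = 7.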
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
num_experts_ = num_experts;
num_bits_ = expertsToBits(num_experts);
}
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts) {
int num_bits = expertsToBits(num_experts);
size_t required_storage = 0;
int* null_int = nullptr;
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
null_int, null_int, num_key_value_pairs, 0,
num_bits);
// For num_key_value_pairs = 64, num_experts = 4, num_bits = 3, CUB reports
// required_storage = 0, and the reported size seems to vary between 0 and 1
// for the same inputs, so clamp it to at least 1.
if (required_storage == 0) {
required_storage = 1;
}
return required_storage;
}
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
int const* keys_in, int* keys_out,
int const* values_in, int* values_out,
size_t const num_key_value_pairs,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
size_t actual_ws_size = workspace_size;
TORCH_CHECK(expected_ws_size <= workspace_size,
"[CubKeyValueSorter::run] The allocated workspace is too small "
"to run this problem.");
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
values_in, values_out, num_key_value_pairs, 0,
num_bits_, stream);
}
// CubKeyValueSorter definition end
static inline size_t pad_to_multiple_of_16(size_t const& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
int64_t const arr_length,
T const target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] >= target) {
high = mid - 1;
} else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
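// Example: for sorted_indices = {0, 0, 1, 2, 2, 2} and target = 2, the
// binary search leaves target_location at the last index whose value is
// less than 2, so the function returns 3: the count of elements < target.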
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens
__global__ void computeExpertFirstTokenOffsetKernel(
int const* sorted_experts, int64_t const sorted_experts_len,
int const num_experts, int64_t* expert_first_token_offset) {
// First, compute the global tid. We only need 1 thread per expert.
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
// Note that expert goes [0, num_experts] (inclusive) because we want a count
// for the total number of active tokens at the end of the scan.
if (expert >= num_experts + 1) {
return;
}
expert_first_token_offset[expert] =
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
}
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream) {
int const num_entries = num_experts + 1;
int const threads = std::min(1024, num_entries);
int const blocks = (num_entries + threads - 1) / threads;
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream) {
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
// We need to use the full num_experts because that is the sentinel value used
// by topk for disabled experts
sorter.updateNumExperts(num_experts);
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
stream);
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
num_experts_per_node, expert_first_token_offset,
stream);
}
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
const int* expert_map_ptr,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
// store expert_map in smem
for (int i = tidx; i < num_experts; i += blockDim.x) {
smem_expert_map[i] = expert_map_ptr[i];
}
__syncthreads();
// Query the global expert id in the expert map:
// if expert_map[global_expert_id] == -1 (not a local expert), add
// num_experts; otherwise remap to expert_map[global_expert_id].
if (offset + tidx < bound) {
auto topk_id = topk_id_ptr[offset + tidx];
auto local_expert_idx = smem_expert_map[topk_id];
if (local_expert_idx == -1) {
topk_id += num_experts;
} else {
topk_id = local_expert_idx;
}
__syncwarp();
topk_id_ptr[offset + tidx] = topk_id;
}
}
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream) {
int block = std::min(size, 1024);
int grid = (size + block - 1) / block;
int smem_size = (num_experts) * sizeof(int);
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
topk_id_ptr, size, expert_map_ptr, num_experts);
}
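// Worked example (hypothetical 4-expert map on a node owning experts 2-3):
// expert_map = {-1, -1, 0, 1}, num_experts = 4. A token routed to global
// expert 2 is remapped to local expert 0; a token routed to global expert 0
// (map value -1) becomes 0 + 4 = 4, which sorts after all local experts and
// is later treated as a non-local, skipped token.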
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset,
int* m_indices, const int num_local_expert,
const int align_block_size) {
int eidx = blockIdx.x;
int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
}
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
auto first_token_offset = smem_expert_first_token_offset[eidx];
int n_token_in_expert = last_token_offset - first_token_offset;
if constexpr (ALIGN_BLOCK_SIZE) {
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
// round up to ALIGN_BLOCK_SIZE
int64_t accumulate_align_offset = 0;
for (int i = 1; i <= eidx + 1; i++) {
int n_token = smem_expert_first_token_offset[i] -
smem_expert_first_token_offset[i - 1];
accumulate_align_offset =
accumulate_align_offset + (n_token + align_block_size - 1) /
align_block_size * align_block_size;
if (i == eidx) {
first_token_offset = accumulate_align_offset;
}
// last block store align_expert_first_token_offset
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
align_expert_first_token_offset[i] = accumulate_align_offset;
}
}
}
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
// update m_indice with expert id
m_indices[first_token_offset + idx] = eidx;
}
}
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream) {
int block = 256;
int grid = num_local_expert;
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
if (align_block_size == -1) {
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
} else {
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
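// Worked example (ALIGN_BLOCK_SIZE path, hypothetical sizes): with two local
// experts holding {3, 5} tokens and align_block_size = 4, the per-expert
// counts round up to {4, 8}. m_indices then gets expert id 0 written at rows
// [0, 4) and expert id 1 at rows [4, 12), and entries 1..2 of
// align_expert_first_token_offset become {4, 12} (entry 0 is expected to be
// zero-initialized by the caller).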
#endif


@@ -0,0 +1,83 @@
#pragma once
// Adapted from the TensorRT-LLM MoE kernel implementation archived at
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
template <typename T>
inline T* get_ptr(torch::Tensor& t) {
return reinterpret_cast<T*>(t.data_ptr());
}
template <typename T>
inline const T* get_ptr(const torch::Tensor& t) {
return reinterpret_cast<const T*>(t.data_ptr());
}
class CubKeyValueSorter {
public:
CubKeyValueSorter();
CubKeyValueSorter(int const num_experts);
void updateNumExperts(int const num_experts);
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts);
void run(void* workspace, size_t const workspace_size, int const* keys_in,
int* keys_out, int const* values_in, int* values_out,
size_t const num_key_value_pairs, cudaStream_t stream);
private:
static int expertsToBits(int experts);
int num_experts_;
int num_bits_;
};
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream);
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream);
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream);
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const* num_valid_ptr, cudaStream_t stream);
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream);
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream);
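// Typical call order for a permute/unpermute round trip (a sketch of how the
// declarations above fit together; buffer allocation is omitted and step 4 is
// an assumption about the surrounding MoE layer):
//   1. preprocessTopkIdLauncher        - remap global expert ids via expert_map
//   2. sortAndScanExpert               - radix-sort tokens by expert and fill
//                                        expert_first_token_offset
//   3. expandInputRowsKernelLauncher   - gather tokens into expert-major order
//   4. (grouped GEMMs / expert FFNs run on the permuted activations)
//   5. finalizeMoeRoutingKernelLauncher - scale by router weights and do the
//                                        k-way reduction back to token order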
#include "moe_permute_unpermute_kernel.inl"


@@ -0,0 +1,203 @@
#pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
// Build the reverse permutation map.
// We do this so that later we can use the source -> dest map for the k-way
// reduction and unpermuting. The reverse map allows each threadblock to do
// one k-way reduce without atomics later in MoE; a single thread block is
// responsible for all k summations.
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
extern __shared__ int64_t smem_expert_first_token_offset[];
if constexpr (ALIGN_BLOCK_SIZE) {
// Load expert_first_token_offset from global to shared memory.
for (int idx = threadIdx.x; idx < num_local_experts + 1;
idx += blockDim.x) {
smem_expert_first_token_offset[idx] =
__ldg(expert_first_token_offset + idx);
}
__syncthreads();
int lane_idx = threadIdx.x & 31;
if (lane_idx == 0) {
// set token_offset_in_expert = 0 if this expert is not local expert
int token_offset_in_expert =
expert_id >= num_local_experts
? 0
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
int64_t accumulate_align_offset = 0;
#pragma unroll 1
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
smem_expert_first_token_offset[eidx - 1];
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
}
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
}
// Lane 0 broadcasts the aligned expanded_dest_row to the warp via shuffle.
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row);
// skip non local expert token
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
permuted_idx[expanded_dest_row] = expanded_source_row;
}
}
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
int64_t const source_row = expanded_source_row / k;
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts, align_block_size);
}
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++) {
u[i] = static_cast<Type>(input[i]);
}
return u;
}
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
// num_valid_ptr may be null when CHECK_SKIPPED is false (see the launcher
// below), so only dereference it when the skip check is enabled.
int64_t const num_valid = CHECK_SKIPPED ? *num_valid_ptr : 0;
// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
cutlass::sizeof_bits<T>::value);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
auto const* expanded_permuted_rows_v =
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
#pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
ComputeElem thread_output;
thread_output.fill(0);
for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row * k + k_idx;
int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset];
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue;
}
auto const* expanded_permuted_rows_row_ptr =
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result);
}
OutputElem output_elem =
arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem;
}
}
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const* num_valid_ptr, cudaStream_t stream) {
int64_t const blocks = num_rows;
int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr;
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
&finalizeMoeRoutingKernel<T, OutputType, true>};
auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
}


@@ -0,0 +1,707 @@
/*
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
* Copyright (c) 2024, The vLLM team.
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <type_traits>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace vllm {
namespace moe {
/// Aligned array type
template <
typename T,
/// Number of elements in the array
int N,
/// Alignment requirement in bytes
int Alignment = sizeof(T) * N
>
struct alignas(Alignment) AlignedArray {
T data[N];
};
template <typename T>
__device__ __forceinline__ float toFloat(T value) {
if constexpr (std::is_same_v<T, float>) {
return value;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
return __bfloat162float(value);
} else if constexpr (std::is_same_v<T, __half>) {
return __half2float(value);
}
}
// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB, typename InputType>
__launch_bounds__(TPB) __global__
void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols)
{
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
__shared__ float normalizing_factor;
__shared__ float float_max;
const int thread_row_offset = blockIdx.x * num_cols;
float threadData(-FLT_MAX);
// Don't touch finished rows.
if ((finished != nullptr) && finished[blockIdx.x])
{
return;
}
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]);
threadData = max(val, threadData);
}
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
if (threadIdx.x == 0)
{
float_max = maxElem;
}
__syncthreads();
threadData = 0;
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]);
threadData += expf(val - float_max);
}
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
if (threadIdx.x == 0)
{
normalizing_factor = 1.f / Z;
}
__syncthreads();
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
{
const int idx = thread_row_offset + ii;
const float val = toFloat(input[idx]);
const float softmax_val = expf(val - float_max) * normalizing_factor;
output[idx] = softmax_val;
}
}
template <int TPB, typename IndType>
__launch_bounds__(TPB) __global__ void moeTopK(
const float* inputs_after_softmax,
const bool* finished,
float* output,
IndType* indices,
int* source_rows,
const int num_experts,
const int k,
const int start_expert,
const int end_expert,
const bool renormalize)
{
using cub_kvp = cub::KeyValuePair<int, float>;
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
__shared__ typename BlockReduce::TempStorage tmpStorage;
cub_kvp thread_kvp;
cub::ArgMax arg_max;
const int num_rows = gridDim.x;
const int block_row = blockIdx.x;
const bool row_is_active = finished ? !finished[block_row] : true;
const int thread_read_offset = blockIdx.x * num_experts;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
thread_kvp.key = 0;
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
cub_kvp inp_kvp;
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
{
const int idx = thread_read_offset + expert;
inp_kvp.key = expert;
inp_kvp.value = inputs_after_softmax[idx];
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
{
const int prior_winning_expert = indices[k * block_row + prior_k];
if (prior_winning_expert == expert)
{
inp_kvp = thread_kvp;
}
}
thread_kvp = arg_max(inp_kvp, thread_kvp);
}
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
if (threadIdx.x == 0)
{
// Ignore experts the node isn't responsible for with expert parallelism
const int expert = result_kvp.key;
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
const int idx = k * block_row + k_idx;
output[idx] = result_kvp.value;
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
assert(indices[idx] >= 0);
source_rows[idx] = k_idx * num_rows + block_row;
if (renormalize) {
selected_sum += result_kvp.value;
}
}
__syncthreads();
}
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (threadIdx.x == 0) {
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx) {
const int idx = k * block_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
// ====================== TopK softmax things ===============================
/*
A Top-K gating softmax written to exploit the case where the number of experts in the
MoE layers is a small power of 2. This allows us to cleanly share the rows among the
threads in a single warp and eliminate communication between warps (so there is no need
for shared memory).
It fuses the softmax, max and argmax into a single kernel.
Limitations:
1) This implementation is optimized for when the number of experts is a small power of 2.
It also supports the case where the number of experts is a multiple of 64, which is still
faster than computing the softmax and topK separately (only tested on CUDA so far).
2) This implementation assumes k is small, but will work for any k.
*/
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float>
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize)
{
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
std::is_same_v<InputType, __half>,
"InputType must be float, __nv_bfloat16, or __half");
// We begin by enforcing compile time assertions and setting up compile time constants.
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
// Number of bytes each thread pulls in per load
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
}
// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
// We have NUM_EXPERTS elements per row. We specialize for small #experts
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
// Restrictions for previous section.
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
// ===================== From this point, we finally start computing run-time variables. ========================
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
// Thus, each block processes a chunk of rows. We start by computing the start row for each block.
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
// Now, using the base row per thread block, we compute the base row per warp.
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
// The threads in a warp are split into sub-groups that will work on a row.
// We compute row offset for each thread sub-group
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
const int thread_row = warp_base_row + thread_row_in_warp;
// Threads with indices out of bounds should early exit here.
if (thread_row >= num_rows)
{
return;
}
const bool row_is_active = finished ? !finished[thread_row] : true;
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
// row it will read.
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
// Now, we compute the group each thread belongs to in order to determine the first column to start loads.
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
// Finally, we pull in the data from global mem
float row_chunk[VPT];
// NOTE(zhuhaoran): dispatch loading on the input type; BF16/FP16 values are converted to float.
if constexpr (std::is_same_v<InputType, float>) {
using VecType = AlignedArray<float, ELTS_PER_LDG>;
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
}
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __bfloat162float(*scalar_ptr);
}
}
} else if constexpr (std::is_same_v<InputType, __half>) {
if constexpr (ELTS_PER_LDG >= 2) {
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
#pragma unroll
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
*reinterpret_cast<const __half2*>(vec.data + jj * 2)
);
}
}
} else { // ELTS_PER_LDG == 1
#pragma unroll
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
row_chunk[ii] = __half2float(*scalar_ptr);
}
}
}
// First, we perform a max reduce within the thread. The row chunk has already been converted to float
// above, so the max and the subsequent exp + sum reduction are all done in fp32.
float thread_max = row_chunk[0];
#pragma unroll
for (int ii = 1; ii < VPT; ++ii)
{
thread_max = max(thread_max, row_chunk[ii]);
}
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
}
// From this point, thread max in all the threads have the max within the row.
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
float row_sum = 0;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
row_sum += row_chunk[ii];
}
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
}
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
// However, this kernel will likely not be a bottleneck, and it seems better to match torch more closely and find the
// argmax after computing the softmax.
const float reciprocal_row_sum = 1.f / row_sum;
#pragma unroll
for (int ii = 0; ii < VPT; ++ii)
{
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
}
// Now, row_chunk contains the softmax of the row chunk. Next, we find the top-k elements in each row, along
// with the max index.
int start_col = first_elt_read_by_thread;
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
float selected_sum = 0.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
// First, each thread does the local argmax
float max_val = row_chunk[0];
int expert = start_col;
#pragma unroll
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
{
#pragma unroll
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
{
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
// No check on the experts here since columns with the smallest index are processed first and only
// updated if > (not >=)
if (val > max_val)
{
max_val = val;
expert = col + ii;
}
}
}
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert))
{
max_val = other_max;
expert = other_expert;
}
}
// Write the max for this k iteration to global memory.
if (thread_group_idx == 0)
{
// Add a guard to ignore experts not included by this node
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
const bool should_process_row = row_is_active && node_uses_expert;
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
// single thread per row of the input/output matrices.)
const int idx = k * thread_row + k_idx;
output[idx] = max_val;
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
if (renormalize) {
selected_sum += max_val;
}
}
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
if (k_idx + 1 < k)
{
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
if (thread_group_idx == thread_to_clear_in_group)
{
const int offset_for_expert = expert % ELTS_PER_LDG;
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
}
}
}
// Renormalize the k weights for this row to sum to 1, if requested.
if (renormalize) {
if (thread_group_idx == 0)
{
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
for (int k_idx = 0; k_idx < k; ++k_idx)
{
const int idx = k * thread_row + k_idx;
output[idx] = output[idx] / denom;
}
}
}
}
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
};
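// Worked example: EXPERTS = 128, BYTES_PER_LDG = 16, WARP_SIZE_PARAM = 32,
// InputType = float gives ELTS_PER_LDG = 16 / 4 = 4,
// VECs_PER_THREAD = max(1, 128 / (4 * 32)) = 1, VPT = 1 * 4 = 4,
// THREADS_PER_ROW = 128 / 4 = 32, and ROWS_PER_WARP = 32 / 32 = 1.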
} // namespace detail
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
cudaStream_t stream)
{
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize);
}
#ifndef USE_ROCM
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
num_tokens, topk, 0, num_experts, renormalize, stream); \
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template <typename IndType, typename InputType>
void topkGatingSoftmaxKernelLauncher(
const InputType* gating_output,
float* topk_weights,
IndType* topk_indices,
int* token_expert_indices,
float* softmax_workspace,
const int num_tokens,
const int num_experts,
const int topk,
const bool renormalize,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
// For bfloat16 and half dtypes, we need 4-byte loads to make sure num_experts
// elements can be loaded by a warp.
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
(std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
#endif
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 2:
LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 4:
LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 8:
LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 16:
LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 32:
LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 64:
LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 128:
LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 256:
LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
case 512:
LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
break;
// (CUDA only) Support multiples of 64 when num_experts is not a power of 2.
// ROCm uses WARP_SIZE 64, so 8-byte loads do not fit for some values of
// num_experts; alternatively, we could test 4-byte loads and enable them in
// the future.
#ifndef USE_ROCM
case 192:
LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 320:
LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 384:
LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 448:
LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
case 576:
LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
break;
#endif
default: {
TORCH_CHECK(softmax_workspace != nullptr,
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
static constexpr int TPB = 256;
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
gating_output, nullptr, softmax_workspace, num_experts);
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
num_experts, topk, 0, num_experts, renormalize);
}
}
}
} // namespace moe
} // namespace vllm
template<typename ComputeType>
void dispatch_topk_softmax_launch(
torch::Tensor& gating_output,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& softmax_workspace,
int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
{
if (topk_indices.scalar_type() == at::ScalarType::Int) {
vllm::moe::topkGatingSoftmaxKernelLauncher<int, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
vllm::moe::topkGatingSoftmaxKernelLauncher<uint32_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<uint32_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
vllm::moe::topkGatingSoftmaxKernelLauncher<int64_t, ComputeType>(
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int64_t>(),
token_expert_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
num_tokens, num_experts, topk, renormalize, stream);
}
}
void topk_softmax(
torch::Tensor& topk_weights, // [num_tokens, topk]
torch::Tensor& topk_indices, // [num_tokens, topk]
torch::Tensor& token_expert_indices, // [num_tokens, topk]
torch::Tensor& gating_output, // [num_tokens, num_experts]
bool renormalize)
{
const int num_experts = gating_output.size(-1);
const auto num_tokens = gating_output.numel() / num_experts;
const int topk = topk_weights.size(-1);
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
const bool needs_workspace = !is_pow_2 || num_experts > 256;
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);
if (gating_output.scalar_type() == at::ScalarType::Float) {
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices,
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
} else {
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
}
}

122
csrc/moe/torch_bindings.cpp Normal file
View File

@@ -0,0 +1,122 @@
#include "core/registration.h"
#include "moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
  // Calculate the MoE result by summing the partial results
  // from all selected experts.
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
m.impl("moe_sum", torch::kCUDA, &moe_sum);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_align_block_size(Tensor topk_ids, int num_experts,"
" int block_size, Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad,"
" Tensor? maybe_expert_map) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size, but for the batched case.
m.def(
"batched_moe_align_block_size(int max_tokens_per_batch,"
" int block_size, Tensor expert_num_tokens,"
" Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size);
  // Same alignment as above, but with per-LoRA-adapter bookkeeping.
m.def(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" int max_num_tokens_padded, "
" int max_num_m_blocks, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids,"
" Tensor? maybe_expert_map) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
"Tensor b_scales, Tensor? b_qzeros, "
"Tensor? topk_weights, Tensor sorted_token_ids, "
"Tensor expert_ids, Tensor num_tokens_post_pad, "
"int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, "
"int bit) -> Tensor");
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
"Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
"b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float,"
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.def(
"moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx, Tensor! m_indices)->()");
m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
"int topk, Tensor! hidden_states)->()");
m.def("moe_permute_unpermute_supported() -> bool");
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
// Row shuffle for MoE
m.def(
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
"output_tensor) -> ()");
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
// Apply grouped topk routing to select experts.
m.def(
"grouped_topk(Tensor scores, int n_group, int "
"topk_group, int topk, bool renormalize, float "
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
#endif
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)