Sync from v0.13
This commit is contained in:
147
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
Normal file
147
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
// _dyn_quant_matmul_4bit is only available on AArch64.
|
||||
#if defined(__aarch64__)
|
||||
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
|
||||
#endif
|
||||
|
||||
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
|
||||
int64_t group_size_eff, int64_t in_features,
|
||||
int64_t out_features) {
|
||||
#if defined(__aarch64__)
|
||||
return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
|
||||
in_features, out_features);
|
||||
#else
|
||||
TORCH_CHECK(false,
|
||||
"dynamic 4-bit int MoE path requires AArch64 (ARM64); "
|
||||
"_dyn_quant_matmul_4bit is unavailable on this architecture");
|
||||
return {};
|
||||
#endif
|
||||
}
|
||||
|
||||
enum ActivationKind : int64_t {
|
||||
SwiGLU_Gu = 0, // act = SiLU(g) * u
|
||||
SwiGLUOAI = 1, // act = SiLU(u) * g
|
||||
SiLU = 2 // SiLU
|
||||
};
|
||||
|
||||
torch::Tensor dynamic_4bit_int_moe_cpu(
|
||||
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
|
||||
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
|
||||
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
|
||||
int64_t activation_kind) {
|
||||
TORCH_CHECK(x.dim() == 2, "x must be 2D");
|
||||
TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
|
||||
"topk tensors must be [T, K]");
|
||||
TORCH_CHECK(
|
||||
w13_packed.size(0) == w2_packed.size(0),
|
||||
"w13_packed and w2_packed must have same number of experts in dim 0");
|
||||
TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");
|
||||
|
||||
const int64_t T = x.size(0);
|
||||
const int64_t K = topk_ids.size(1);
|
||||
const int64_t E = w13_packed.size(0);
|
||||
const int64_t N = T * K;
|
||||
|
||||
auto x_c = x.contiguous();
|
||||
auto ids_c = topk_ids.contiguous();
|
||||
auto gates_c = topk_weights.to(at::kFloat).contiguous();
|
||||
|
||||
// bucketing tokens -> experts
|
||||
c10::SmallVector<int64_t, 64> counts(
|
||||
E, 0); // Small vector uses stack allocation
|
||||
{
|
||||
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
|
||||
for (int64_t i = 0; i < N; ++i) {
|
||||
const int64_t e_id = ids_ptr[i];
|
||||
TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
|
||||
counts[e_id]++;
|
||||
}
|
||||
}
|
||||
c10::SmallVector<int64_t, 65> offsets(E + 1, 0); // ( E +1 )
|
||||
for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];
|
||||
|
||||
auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
|
||||
auto expert_gates = at::empty({offsets[E]}, gates_c.options());
|
||||
{
|
||||
c10::SmallVector<int64_t, 64> cursor(E, 0);
|
||||
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
|
||||
const auto* gts_ptr = gates_c.data_ptr<float>();
|
||||
auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
|
||||
auto* gate_ptr = expert_gates.data_ptr<float>();
|
||||
|
||||
for (int64_t t = 0; t < T; ++t) {
|
||||
const int64_t base = t * K;
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const int64_t idx = base + k;
|
||||
const int64_t e = ids_ptr[idx];
|
||||
const int64_t p = offsets[e] + (cursor[e]++);
|
||||
tok_ptr[p] = t;
|
||||
gate_ptr[p] = gts_ptr[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
|
||||
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
|
||||
|
||||
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
|
||||
if (apply_router_weight_on_input) {
|
||||
X_all = X_all.mul(expert_gates.unsqueeze(1));
|
||||
}
|
||||
auto Y_all = at::empty({offsets[E], H}, x_c.options());
|
||||
|
||||
at::parallel_for(0, offsets[E], 0, [&](int64_t idx_begin, int64_t idx_end) {
|
||||
c10::InferenceMode guard;
|
||||
for (int64_t e = 0; e < E; ++e) {
|
||||
int64_t start = std::max(offsets[e], idx_begin);
|
||||
int64_t end = std::min(offsets[e + 1], idx_end);
|
||||
int64_t te = end - start;
|
||||
if (te <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
|
||||
|
||||
auto w13_e = w13_packed.select(/*dim=*/0, e);
|
||||
auto w2_e = w2_packed.select(/*dim=*/0, e);
|
||||
|
||||
// W13
|
||||
auto y13 =
|
||||
mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
|
||||
|
||||
auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
|
||||
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
|
||||
|
||||
torch::Tensor act;
|
||||
if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI
|
||||
constexpr double kAlpha = 1.702; // GPT-OSS default
|
||||
constexpr double kLimit = 7.0; // GPT-OSS default
|
||||
auto gate_c = at::clamp_max(g_part, kLimit);
|
||||
auto up_c = at::clamp(u_part, -kLimit, kLimit);
|
||||
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
|
||||
act = up_c.add(1.0).mul(glu);
|
||||
} else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
|
||||
act = at::silu(g_part).mul(u_part);
|
||||
}
|
||||
|
||||
// W2
|
||||
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
|
||||
|
||||
// Store per-expert result
|
||||
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
|
||||
}
|
||||
});
|
||||
|
||||
if (!apply_router_weight_on_input) {
|
||||
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
|
||||
}
|
||||
|
||||
auto out = at::zeros({T, H}, x.options());
|
||||
out =
|
||||
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
|
||||
|
||||
return out;
|
||||
}
|
||||
891
csrc/moe/grouped_topk_kernels.cu
Normal file
891
csrc/moe/grouped_topk_kernels.cu
Normal file
@@ -0,0 +1,891 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
template <int size, typename T>
|
||||
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
return ((len - 1) / size + 1) * size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr __host__ __device__ bool isPowerOf2(T v) {
|
||||
return (v && !(v & (v - 1)));
|
||||
}
|
||||
|
||||
template <bool greater, typename T>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
|
||||
return (val > baseline && greater) || (val < baseline && !greater);
|
||||
}
|
||||
|
||||
template <bool greater, typename T, typename idxT>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
|
||||
idxT baseline_index) {
|
||||
bool res = (val > baseline && greater) || (val < baseline && !greater);
|
||||
if (val == baseline) {
|
||||
res = (index < baseline_index && greater) ||
|
||||
(index < baseline_index && !greater);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T, typename idxT>
|
||||
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
||||
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
|
||||
return max(cache_topk,
|
||||
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
||||
}
|
||||
|
||||
template <int size, bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
constexpr int stride = arr_len / 2;
|
||||
for (int i = 0; i < stride; ++i) {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
|
||||
idx_arr[other_i]);
|
||||
} else {
|
||||
is_better = is_better_than<ascending>(val, other_val);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
|
||||
idxT tmp2 = idx_arr[i];
|
||||
idx_arr[i] = idx_arr[other_i];
|
||||
idx_arr[other_i] = tmp2;
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// ascending doesn't matter before merging since all we need is a bitonic
|
||||
// sequence
|
||||
for (int stage = 0; stage < 4; ++stage) {
|
||||
for (int stride = (1 << stage); stride > 0; stride /= 2) {
|
||||
bool reverse = (lane >> stage) & 2;
|
||||
bool is_second = lane & stride;
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) !=
|
||||
(reverse != is_second);
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) !=
|
||||
(reverse != is_second);
|
||||
}
|
||||
} else {
|
||||
is_better = (*val_arr != other &&
|
||||
(*val_arr > other) != (reverse != is_second));
|
||||
}
|
||||
if (is_better) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
|
||||
idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
bool is_second = lane & stride;
|
||||
T& val = *val_arr;
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) ==
|
||||
(reverse != is_second); // for min
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) ==
|
||||
(reverse != is_second); // for max
|
||||
}
|
||||
} else {
|
||||
is_better =
|
||||
(val != other && ((val > other) == (ascending != is_second)));
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSort {
|
||||
public:
|
||||
__device__ WarpSort(idxT k, T dummy)
|
||||
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
||||
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
||||
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
val_arr_[i] = dummy_;
|
||||
idx_arr_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// load and merge k sorted values
|
||||
__device__ void load_sorted(T const* __restrict__ in,
|
||||
idxT const* __restrict__ in_idx, idxT start) {
|
||||
idxT idx = start + WARP_SIZE - 1 - lane_;
|
||||
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
||||
if (idx < start + k_) {
|
||||
T t = in[idx];
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(t, val_arr_[i]);
|
||||
}
|
||||
if (is_better) {
|
||||
val_arr_[i] = t;
|
||||
idx_arr_[i] = in_idx[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
}
|
||||
|
||||
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
idxT out_i = i * WARP_SIZE + lane_;
|
||||
if (out_i < k_) {
|
||||
out[out_i] = val_arr_[i];
|
||||
out_idx[out_i] = idx_arr_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
idxT out_i = i * WARP_SIZE + lane_;
|
||||
if (out_i < k_) {
|
||||
out_idx[out_i] = idx_arr_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
||||
|
||||
T val_arr_[max_arr_len_];
|
||||
idxT idx_arr_[max_arr_len_];
|
||||
|
||||
int const lane_;
|
||||
idxT const k_;
|
||||
T const dummy_;
|
||||
|
||||
}; // end class WarpSort
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
public:
|
||||
__device__ WarpSelect(idxT k, T dummy)
|
||||
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
|
||||
k_th_(dummy),
|
||||
k_th_lane_((k - 1) % WARP_SIZE) {
|
||||
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
||||
|
||||
int const num_of_warp = blockDim.x / WARP_SIZE;
|
||||
int const warp_id = threadIdx.x / WARP_SIZE;
|
||||
val_smem_ = reinterpret_cast<T*>(smem_buf);
|
||||
val_smem_ += warp_id * WARP_SIZE;
|
||||
idx_smem_ = reinterpret_cast<idxT*>(
|
||||
smem_buf +
|
||||
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
|
||||
idx_smem_ += warp_id * WARP_SIZE;
|
||||
}
|
||||
|
||||
__device__ void add(T const* in, idxT start, idxT end) {
|
||||
idxT const end_for_fullwarp =
|
||||
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
|
||||
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
|
||||
T val = (i < end) ? in[i] : dummy_;
|
||||
add(val, i);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void add(T val, idxT idx) {
|
||||
bool do_add;
|
||||
if constexpr (is_stable) {
|
||||
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
|
||||
} else {
|
||||
do_add = is_better_than<greater>(val, k_th_);
|
||||
}
|
||||
|
||||
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
||||
if (mask == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
|
||||
if (do_add && pos < WARP_SIZE) {
|
||||
val_smem_[pos] = val;
|
||||
idx_smem_[pos] = idx;
|
||||
do_add = false;
|
||||
}
|
||||
smem_buf_len_ += __popc(mask);
|
||||
if (smem_buf_len_ >= WARP_SIZE) {
|
||||
__syncwarp();
|
||||
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
|
||||
smem_buf_len_ -= WARP_SIZE;
|
||||
}
|
||||
if (do_add) {
|
||||
pos -= WARP_SIZE;
|
||||
val_smem_[pos] = val;
|
||||
idx_smem_[pos] = idx;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
__device__ void done() {
|
||||
if (smem_buf_len_) {
|
||||
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
|
||||
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
|
||||
merge_buf_(val, idx);
|
||||
}
|
||||
|
||||
// after done(), smem is used for merging results among warps
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
private:
|
||||
__device__ void set_k_th_() {
|
||||
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
if constexpr (is_stable) {
|
||||
k_th_idx_ =
|
||||
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void merge_buf_(T val, idxT idx) {
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
|
||||
|
||||
T& old = val_arr_[max_arr_len_ - 1];
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(val, old);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
old = val;
|
||||
idx_arr_[max_arr_len_ - 1] = idx;
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
|
||||
set_k_th_();
|
||||
}
|
||||
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
|
||||
|
||||
T* val_smem_;
|
||||
idxT* idx_smem_;
|
||||
int smem_buf_len_ = 0;
|
||||
|
||||
T k_th_;
|
||||
idxT k_th_idx_;
|
||||
int const k_th_lane_;
|
||||
}; // end class WarpSelect
|
||||
} // namespace warp_topk
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T neg_inf() {
|
||||
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
|
||||
// so we need to cast from fp32
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline bool is_finite(const T val) {
|
||||
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800)
|
||||
return cuda::std::isfinite(val);
|
||||
#else
|
||||
return isfinite(cuda_cast<float, T>(val));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Scoring function enums
|
||||
enum ScoringFunc {
|
||||
SCORING_NONE = 0, // no activation function
|
||||
SCORING_SIGMOID = 1 // apply sigmoid
|
||||
};
|
||||
|
||||
// Efficient sigmoid approximation from TensorRT-LLM
|
||||
__device__ inline float sigmoid_accurate(float x) {
|
||||
return 0.5f * tanhf(0.5f * x) + 0.5f;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T apply_sigmoid(T val) {
|
||||
float f = cuda_cast<float, T>(val);
|
||||
return cuda_cast<T, float>(sigmoid_accurate(f));
|
||||
}
|
||||
|
||||
template <ScoringFunc SF, typename T>
|
||||
__device__ inline T apply_scoring(T val) {
|
||||
if constexpr (SF == SCORING_SIGMOID) {
|
||||
return apply_sigmoid(val);
|
||||
} else {
|
||||
return val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, ScoringFunc SF>
|
||||
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = neg_inf<T>();
|
||||
T second_largest = neg_inf<T>();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
T value = apply_scoring<SF>(input[i]);
|
||||
value = value + bias[i];
|
||||
|
||||
if (value > largest) {
|
||||
second_largest = largest;
|
||||
largest = value;
|
||||
} else if (value > second_largest) {
|
||||
second_largest = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
T value = apply_scoring<SF>(input[i]);
|
||||
value = value + bias[i];
|
||||
largest = value;
|
||||
}
|
||||
}
|
||||
// Get the top2 warpwise
|
||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
|
||||
T max2 = max1;
|
||||
bool equal_to_max1 = (max1 == largest);
|
||||
|
||||
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
|
||||
|
||||
if (count_max1 == 1) {
|
||||
largest = (largest == max1) ? second_largest : largest;
|
||||
max2 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
}
|
||||
|
||||
if (lane_id == 0) {
|
||||
*output = max1 + max2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, ScoringFunc SF>
|
||||
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_cases,
|
||||
int64_t const n_group,
|
||||
int64_t const num_experts_per_group) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
|
||||
if (case_id < num_cases) {
|
||||
input += case_id * num_experts_per_group;
|
||||
// bias is per expert group, offset to current group
|
||||
int32_t group_id = case_id % n_group;
|
||||
T const* group_bias = bias + group_id * num_experts_per_group;
|
||||
output += case_id;
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
|
||||
num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
|
||||
__global__ void group_idx_and_topk_idx_kernel(
|
||||
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
|
||||
T const* bias, int64_t const num_tokens, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
||||
int64_t const num_experts_per_group, bool renormalize,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
int32_t case_id =
|
||||
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
|
||||
scores += case_id * num_experts;
|
||||
group_scores += case_id * n_group;
|
||||
topk_values += case_id * topk;
|
||||
topk_indices += case_id * topk;
|
||||
|
||||
constexpr bool kUseStaticNGroup = (NGroup > 0);
|
||||
// use int32 to avoid implicit conversion
|
||||
int32_t const n_group_i32 =
|
||||
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
|
||||
|
||||
int32_t align_num_experts_per_group =
|
||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = neg_inf<T>();
|
||||
T topk_group_value = neg_inf<T>();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
||||
// acqbulk because it's ptr arithmetic
|
||||
#endif
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min =
|
||||
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
|
||||
// The check is necessary to avoid abnormal input
|
||||
if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
int count_equal_to_top_value = WARP_SIZE - n_group_i32;
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use loop to find the largset top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = neg_inf<T>();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value =
|
||||
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, neg_inf<T>());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
auto process_group = [&](int i_group) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates = neg_inf<T>();
|
||||
if (i < num_experts_per_group) {
|
||||
// apply scoring function (if any) and add bias
|
||||
T input = scores[offset + i];
|
||||
if (is_finite(input)) {
|
||||
T score = apply_scoring<SF>(input);
|
||||
candidates = score + bias[offset + i];
|
||||
}
|
||||
}
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
count_equalto_topkth_group++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr (kUseStaticNGroup) {
|
||||
#pragma unroll
|
||||
for (int i_group = 0; i_group < NGroup; ++i_group) {
|
||||
process_group(i_group);
|
||||
}
|
||||
} else {
|
||||
for (int i_group = 0; i_group < n_group_i32; ++i_group) {
|
||||
process_group(i_group);
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
T value = cuda_cast<T, float>(0.0f);
|
||||
if (i < topk) {
|
||||
// Load the score value (without bias) for normalization
|
||||
T input = scores[s_topk_idx[i]];
|
||||
value = apply_scoring<SF>(input);
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
if (renormalize) {
|
||||
topk_sum +=
|
||||
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float base = cuda_cast<float, T>(s_topk_value[i]);
|
||||
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
|
||||
: (base * routed_scaling_factor);
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = value;
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = 1.0f / topk;
|
||||
}
|
||||
}
|
||||
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
||||
// default result.
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT, ScoringFunc SF>
|
||||
inline void launch_group_idx_and_topk_kernel(
|
||||
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
|
||||
float* topk_values, IdxT* topk_indices, T const* bias,
|
||||
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
|
||||
int64_t const topk, int64_t const num_experts,
|
||||
int64_t const num_experts_per_group, bool const renormalize,
|
||||
double const routed_scaling_factor) {
|
||||
auto launch = [&](auto* kernel_instance2) {
|
||||
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
||||
topk_values, topk_indices, bias, num_tokens, n_group,
|
||||
topk_group, topk, num_experts, num_experts_per_group,
|
||||
renormalize, routed_scaling_factor);
|
||||
};
|
||||
|
||||
switch (n_group) {
|
||||
case 4: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
|
||||
break;
|
||||
}
|
||||
case 8: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
|
||||
break;
|
||||
}
|
||||
case 16: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
|
||||
break;
|
||||
}
|
||||
case 32: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
||||
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
|
||||
int64_t const num_experts, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk,
|
||||
bool const renormalize, double const routed_scaling_factor,
|
||||
int const scoring_func, bool enable_pdl = false,
|
||||
cudaStream_t const stream = 0) {
|
||||
int64_t num_cases = num_tokens * n_group;
|
||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = topk_with_k2_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
auto const sf = static_cast<ScoringFunc>(scoring_func);
|
||||
int64_t const num_experts_per_group = num_experts / n_group;
|
||||
auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
|
||||
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
|
||||
num_tokens, num_cases, n_group, num_experts_per_group);
|
||||
};
|
||||
switch (sf) {
|
||||
case SCORING_NONE: {
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
|
||||
launch_topk_with_k2(kernel_instance1);
|
||||
break;
|
||||
}
|
||||
case SCORING_SIGMOID: {
|
||||
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
|
||||
launch_topk_with_k2(kernel_instance1);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
// should be guarded by higher level checks.
|
||||
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
|
||||
}
|
||||
|
||||
int64_t topk_with_k_group_num_blocks =
|
||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||
size_t dynamic_smem_in_bytes =
|
||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||
topk);
|
||||
config.gridDim = topk_with_k_group_num_blocks;
|
||||
config.blockDim = BLOCK_SIZE;
|
||||
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
||||
config.stream = stream;
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
switch (sf) {
|
||||
case SCORING_NONE: {
|
||||
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
|
||||
config, scores, group_scores, topk_values, topk_indices, bias,
|
||||
num_tokens, n_group, topk_group, topk, num_experts,
|
||||
num_experts_per_group, renormalize, routed_scaling_factor);
|
||||
break;
|
||||
}
|
||||
case SCORING_SIGMOID: {
|
||||
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
|
||||
config, scores, group_scores, topk_values, topk_indices, bias,
|
||||
num_tokens, n_group, topk_group, topk, num_experts,
|
||||
num_experts_per_group, renormalize, routed_scaling_factor);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
template void invokeNoAuxTc<T, IdxT>( \
|
||||
T * scores, T * group_scores, float* topk_values, IdxT* topk_indices, \
|
||||
T const* bias, int64_t const num_tokens, int64_t const num_experts, \
|
||||
int64_t const n_group, int64_t const topk_group, int64_t const topk, \
|
||||
bool const renormalize, double const routed_scaling_factor, \
|
||||
int const scoring_func, bool enable_pdl, cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_NOAUX_TC(float, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(half, int32_t);
|
||||
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
|
||||
} // end namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
|
||||
int64_t topk, bool renormalize, double routed_scaling_factor,
|
||||
torch::Tensor const& bias, int64_t scoring_func = 0) {
|
||||
auto data_type = scores.scalar_type();
|
||||
auto input_size = scores.sizes();
|
||||
int64_t num_tokens = input_size[0];
|
||||
int64_t num_experts = input_size[1];
|
||||
TORCH_CHECK(input_size.size() == 2, "scores must be a 2D Tensor");
|
||||
TORCH_CHECK(num_experts % n_group == 0,
|
||||
"num_experts should be divisible by n_group");
|
||||
TORCH_CHECK(n_group <= 32,
|
||||
"n_group should be smaller than or equal to 32 for now");
|
||||
TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");
|
||||
TORCH_CHECK(scoring_func == vllm::moe::SCORING_NONE ||
|
||||
scoring_func == vllm::moe::SCORING_SIGMOID,
|
||||
"scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)");
|
||||
|
||||
torch::Tensor group_scores = torch::empty(
|
||||
{num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
|
||||
// Always output float32 for topk_values (eliminates Python-side conversion)
|
||||
torch::Tensor topk_values = torch::empty(
|
||||
{num_tokens, topk}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
|
||||
torch::Tensor topk_indices = torch::empty(
|
||||
{num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream(scores.get_device());
|
||||
|
||||
switch (data_type) {
|
||||
case torch::kFloat16:
|
||||
// Handle Float16
|
||||
vllm::moe::invokeNoAuxTc<half, int32_t>(
|
||||
reinterpret_cast<half*>(scores.mutable_data_ptr()),
|
||||
reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
||||
reinterpret_cast<half const*>(bias.data_ptr()), num_tokens,
|
||||
num_experts, n_group, topk_group, topk, renormalize,
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
||||
break;
|
||||
case torch::kFloat32:
|
||||
// Handle Float32
|
||||
vllm::moe::invokeNoAuxTc<float, int32_t>(
|
||||
reinterpret_cast<float*>(scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
||||
reinterpret_cast<float const*>(bias.data_ptr()), num_tokens,
|
||||
num_experts, n_group, topk_group, topk, renormalize,
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
||||
break;
|
||||
case torch::kBFloat16:
|
||||
// Handle BFloat16
|
||||
vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
|
||||
reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
|
||||
reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
|
||||
reinterpret_cast<__nv_bfloat16 const*>(bias.data_ptr()), num_tokens,
|
||||
num_experts, n_group, topk_group, topk, renormalize,
|
||||
routed_scaling_factor, static_cast<int>(scoring_func), false, stream);
|
||||
break;
|
||||
default:
|
||||
// Handle other data types
|
||||
throw std::invalid_argument(
|
||||
"Invalid dtype, only supports float16, float32, and bfloat16");
|
||||
break;
|
||||
}
|
||||
return {topk_values, topk_indices};
|
||||
}
|
||||
2
csrc/moe/marlin_moe_wna16/.gitignore
vendored
Normal file
2
csrc/moe/marlin_moe_wna16/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
sm*_kernel_*.cu
|
||||
kernel_selector.h
|
||||
286
csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
286
csrc/moe/marlin_moe_wna16/generate_kernels.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import glob
|
||||
import itertools
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import jinja2
|
||||
|
||||
ARCHS = []
|
||||
SUPPORT_FP8 = False
|
||||
for arch in sys.argv[1].split(","):
|
||||
arch = arch[: arch.index(".") + 2].replace(".", "")
|
||||
arch = int(arch)
|
||||
# only SM89 and SM120 fully support
|
||||
# mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
|
||||
# SM90 and SM100 can use this PTX, but it’s simulated
|
||||
# with FP16 MMA, so it cannot achieve any acceleration.
|
||||
if arch in [89, 120]:
|
||||
SUPPORT_FP8 = True
|
||||
|
||||
FILE_HEAD_COMMENT = """
|
||||
// auto generated by generate_kernels.py
|
||||
// clang-format off
|
||||
""".lstrip()
|
||||
|
||||
FILE_HEAD = (
|
||||
FILE_HEAD_COMMENT
|
||||
+ """
|
||||
#include "kernel.h"
|
||||
#include "marlin_template.h"
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
"""
|
||||
)
|
||||
|
||||
TEMPLATE = (
|
||||
"template __global__ void Marlin<"
|
||||
"{{a_type_id}}, "
|
||||
"{{b_type_id}}, "
|
||||
"{{c_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
"{{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, "
|
||||
"{{stages}}, "
|
||||
"{{group_blocks}}, "
|
||||
"{{is_zp_float}}>"
|
||||
"( MARLIN_KERNEL_PARAMS );"
|
||||
)
|
||||
|
||||
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||
|
||||
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||
|
||||
QUANT_CONFIGS = [
|
||||
# AWQ-INT4
|
||||
{
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4
|
||||
{
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 0, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT8
|
||||
{
|
||||
"b_type": "kU8B128",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 0, 2, 4, 8],
|
||||
},
|
||||
# FP8
|
||||
{
|
||||
"b_type": "kFE4M3fn",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [-1, 8],
|
||||
},
|
||||
# NVFP4
|
||||
{
|
||||
"b_type": "kFE2M1f",
|
||||
"s_type": "kFE4M3fn",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [1],
|
||||
},
|
||||
# MXFP4
|
||||
{
|
||||
"a_type": ["kBFloat16"],
|
||||
"b_type": "kFE2M1f",
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": THREAD_M_BLOCKS,
|
||||
"group_blocks": [2],
|
||||
},
|
||||
# AWQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with INT8 activation
|
||||
{
|
||||
"a_type": ["kS8"],
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# GPTQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kU4B8",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# AWQ-INT4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kU4",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [-1, 2, 4, 8],
|
||||
},
|
||||
# MXFP4 with FP8 activation
|
||||
{
|
||||
"a_type": ["kFE4M3fn"],
|
||||
"b_type": "kFE2M1f",
|
||||
"c_type": ["kBFloat16"],
|
||||
"s_type": "kFE8M0fnu",
|
||||
"thread_configs": THREAD_CONFIGS,
|
||||
"thread_m_blocks": [1, 2, 3, 4],
|
||||
"group_blocks": [2],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def remove_old_kernels():
|
||||
for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"):
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
filename = os.path.dirname(__file__) + "/kernel_selector.h"
|
||||
subprocess.call(["rm", "-f", filename])
|
||||
|
||||
|
||||
def generate_new_kernels():
|
||||
result_dict = {}
|
||||
|
||||
for quant_config in QUANT_CONFIGS:
|
||||
c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
|
||||
a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
|
||||
b_type = quant_config["b_type"]
|
||||
all_group_blocks = quant_config["group_blocks"]
|
||||
all_m_blocks = quant_config["thread_m_blocks"]
|
||||
all_thread_configs = quant_config["thread_configs"]
|
||||
|
||||
for a_type, c_type in itertools.product(a_types, c_types):
|
||||
if not SUPPORT_FP8 and a_type == "kFE4M3fn":
|
||||
continue
|
||||
if "16" in a_type and "16" in c_type and a_type != c_type:
|
||||
continue
|
||||
s_type = quant_config.get("s_type", c_type)
|
||||
if (a_type, b_type, c_type) not in result_dict:
|
||||
result_dict[(a_type, b_type, c_type)] = []
|
||||
|
||||
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||
all_group_blocks, all_m_blocks, all_thread_configs
|
||||
):
|
||||
thread_k, thread_n, threads = thread_configs
|
||||
|
||||
if threads == 256:
|
||||
# for small batch (m_blocks == 1),
|
||||
# we only need (128, 128, 256)
|
||||
# for large batch (m_blocks > 1),
|
||||
# we only need (64, 256, 256)
|
||||
if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
|
||||
continue
|
||||
if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
|
||||
continue
|
||||
|
||||
config = {
|
||||
"threads": threads,
|
||||
"s_type": s_type,
|
||||
"thread_m_blocks": max(m_blocks, 1),
|
||||
"thread_k_blocks": thread_k // 16,
|
||||
"thread_n_blocks": thread_n // 16,
|
||||
"m_block_size_8": "true" if m_blocks == 0.5 else "false",
|
||||
"stages": "pipe_stages",
|
||||
"group_blocks": group_blocks,
|
||||
"is_zp_float": "false",
|
||||
}
|
||||
|
||||
result_dict[(a_type, b_type, c_type)].append(config)
|
||||
|
||||
kernel_selector_str = FILE_HEAD_COMMENT
|
||||
|
||||
for (a_type, b_type, c_type), config_list in result_dict.items():
|
||||
all_template_str_list = []
|
||||
for config in config_list:
|
||||
s_type = config["s_type"]
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
all_template_str_list.append(template_str)
|
||||
|
||||
conditions = [
|
||||
f"a_type == vllm::{a_type}",
|
||||
f"b_type == vllm::{b_type}",
|
||||
f"c_type == vllm::{c_type}",
|
||||
f"s_type == vllm::{s_type}",
|
||||
f"threads == {config['threads']}",
|
||||
f"thread_m_blocks == {config['thread_m_blocks']}",
|
||||
f"thread_n_blocks == {config['thread_n_blocks']}",
|
||||
f"thread_k_blocks == {config['thread_k_blocks']}",
|
||||
f"m_block_size_8 == {config['m_block_size_8']}",
|
||||
f"group_blocks == {config['group_blocks']}",
|
||||
f"is_zp_float == {config['is_zp_float']}",
|
||||
]
|
||||
conditions = " && ".join(conditions)
|
||||
|
||||
if kernel_selector_str == FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += f"if ({conditions})\n kernel = "
|
||||
else:
|
||||
kernel_selector_str += f"else if ({conditions})\n kernel = "
|
||||
|
||||
kernel_template2 = (
|
||||
"Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
|
||||
"{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, {{thread_k_blocks}}, "
|
||||
"{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
|
||||
"{{is_zp_float}}>;"
|
||||
)
|
||||
|
||||
kernel_selector_str += (
|
||||
jinja2.Template(kernel_template2).render(
|
||||
a_type_id=f"vllm::{a_type}.id()",
|
||||
b_type_id=f"vllm::{b_type}.id()",
|
||||
c_type_id=f"vllm::{c_type}.id()",
|
||||
s_type_id=f"vllm::{s_type}.id()",
|
||||
**config,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
file_content = FILE_HEAD + "\n\n"
|
||||
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||
if a_type == "kFE4M3fn":
|
||||
filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
else:
|
||||
filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu"
|
||||
|
||||
filename = filename.lower()
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||
f.write(file_content)
|
||||
|
||||
if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT:
|
||||
kernel_selector_str += (
|
||||
"else if (a_type == vllm::kFE4M3fn)\n"
|
||||
" TORCH_CHECK(false, "
|
||||
'"marlin kernel with fp8 activation is not built.");'
|
||||
)
|
||||
|
||||
with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f:
|
||||
f.write(kernel_selector_str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
remove_old_kernels()
|
||||
generate_new_kernels()
|
||||
47
csrc/moe/marlin_moe_wna16/kernel.h
Normal file
47
csrc/moe/marlin_moe_wna16/kernel.h
Normal file
@@ -0,0 +1,47 @@
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "quantization/gptq_marlin/marlin.cuh"
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const float *__restrict__ a_scales_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ global_scale_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
|
||||
bool use_fp32_reduce
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <const vllm::ScalarTypeId a_type_id, // A ScalarType id
|
||||
const vllm::ScalarTypeId b_type_id, // B ScalarType id
|
||||
const vllm::ScalarTypeId c_type_id, // C ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // B_SCALE ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
// threadblock
|
||||
const int thread_n_blocks, // same for n dimension (output)
|
||||
const int thread_k_blocks, // same for k dimension (reduction)
|
||||
const bool m_block_size_8, // whether m_block_size == 8
|
||||
// only works when thread_m_blocks == 1
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const int group_blocks, // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
const bool is_zp_float // is zero point of float16 type?
|
||||
>
|
||||
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
}
|
||||
2352
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
2352
csrc/moe/marlin_moe_wna16/marlin_template.h
Normal file
File diff suppressed because it is too large
Load Diff
863
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
863
csrc/moe/marlin_moe_wna16/ops.cu
Normal file
@@ -0,0 +1,863 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||
// on the given "perm" indices.
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
|
||||
int32_t block_sorted_ids[moe_block_size];
|
||||
int block_num_valid_tokens = 0;
|
||||
int64_t old_expert_id = 0;
|
||||
int64_t expert_id = 0;
|
||||
int row_stride = size_k * sizeof(half) / 16;
|
||||
|
||||
auto read_moe_block_data = [&](int block_id) {
|
||||
block_num_valid_tokens = moe_block_size;
|
||||
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
|
||||
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||
tmp_block_sorted_ids[i] =
|
||||
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||
}
|
||||
for (int i = 0; i < moe_block_size; i++) {
|
||||
if (block_sorted_ids[i] >= size_m * top_k) {
|
||||
block_num_valid_tokens = i;
|
||||
break;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
auto permute_row = [&](int row) {
|
||||
int iters = size_k / default_threads;
|
||||
int rest = size_k % default_threads;
|
||||
|
||||
int in_offset = (row / top_k) * row_stride;
|
||||
int out_offset = row * row_stride;
|
||||
|
||||
half const* a_row_half =
|
||||
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
|
||||
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
|
||||
base_k += default_threads;
|
||||
}
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
auto cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
|
||||
old_expert_id = expert_id;
|
||||
int tmp_expert_id = expert_ids_ptr[index];
|
||||
if (tmp_expert_id == -1) continue;
|
||||
expert_id = tmp_expert_id;
|
||||
perm_int_ptr += (expert_id - old_expert_id) * size_k;
|
||||
read_moe_block_data(index);
|
||||
|
||||
for (int i = 0; i < block_num_valid_tokens; i++)
|
||||
permute_row(block_sorted_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
int thread_k;
|
||||
int thread_n;
|
||||
int num_threads;
|
||||
} thread_config_t;
|
||||
|
||||
thread_config_t small_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{128, 128, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
thread_config_t large_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{64, 256, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
typedef struct {
|
||||
int blocks_per_sm;
|
||||
thread_config_t tb_cfg;
|
||||
} exec_config_t;
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_k = th_config.thread_k;
|
||||
|
||||
// Get max scale groups per thread-block
|
||||
int tb_groups;
|
||||
if (group_size == -1) {
|
||||
tb_groups = 1;
|
||||
} else if (group_size == 0) {
|
||||
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
||||
} else {
|
||||
tb_groups = div_ceil(tb_k, group_size);
|
||||
}
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n,
|
||||
int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full, int has_zp,
|
||||
int is_zp_float, bool is_a_8bit) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2);
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
|
||||
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
sh_zp_size = sh_s_size;
|
||||
else if (num_bits == 4)
|
||||
sh_zp_size = sh_s_size / 4;
|
||||
else if (num_bits == 8)
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, int has_zp, int is_zp_float,
|
||||
int max_shared_mem, bool is_a_8bit) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify K/N are divisible by thread K/N
|
||||
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify min for thread K/N
|
||||
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// num_threads must be at least 128 (= 4 warps)
|
||||
if (th_config.num_threads < 128) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size =
|
||||
get_kernel_cache_size(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
MarlinFuncPtr get_marlin_kernel(
|
||||
const vllm::ScalarType a_type, const vllm::ScalarType b_type,
|
||||
const vllm::ScalarType c_type, const vllm::ScalarType s_type,
|
||||
int thread_m_blocks, int thread_n_blocks, int thread_k_blocks,
|
||||
bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks,
|
||||
int threads, bool is_zp_float) {
|
||||
int num_bits = b_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
|
||||
#include "kernel_selector.h"
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
exec_config_t determine_exec_config(
|
||||
const vllm::ScalarType& a_type, const vllm::ScalarType& b_type,
|
||||
const vllm::ScalarType& c_type, const vllm::ScalarType& s_type, int prob_m,
|
||||
int prob_n, int prob_k, int num_experts, int top_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits, int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, bool is_zp_float, int max_shared_mem, int sms,
|
||||
bool is_a_8bit) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
: small_batch_thread_configs;
|
||||
int thread_configs_size =
|
||||
thread_m_blocks > 1
|
||||
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
|
||||
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
|
||||
|
||||
int count = 0;
|
||||
constexpr int device_max_reg_size = 255 * 1024;
|
||||
for (int i = 0; i < thread_configs_size; i++) {
|
||||
thread_config_t th_config = thread_configs[i];
|
||||
|
||||
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem - 512,
|
||||
is_a_8bit)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
is_a_8bit);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
group_blocks = group_size == -1 ? -1 : (group_size / 16);
|
||||
}
|
||||
|
||||
auto kernel =
|
||||
get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
|
||||
th_config.thread_n / 16, th_config.thread_k / 16,
|
||||
m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
cudaFuncAttributes attr;
|
||||
cudaFuncGetAttributes(&attr, kernel);
|
||||
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
|
||||
int allow_count = min(device_max_reg_size / reg_size,
|
||||
max_shared_mem / (cache_size + 1536));
|
||||
if (thread_m_blocks == 1)
|
||||
allow_count = max(min(allow_count, 4), 1);
|
||||
else
|
||||
allow_count = max(min(allow_count, 2), 1);
|
||||
|
||||
if (prob_n / th_config.thread_n * prob_m * top_k * 4 < sms * allow_count) {
|
||||
allow_count =
|
||||
max(prob_n / th_config.thread_n * prob_m * top_k * 4 / sms, 1);
|
||||
}
|
||||
|
||||
if (allow_count > count) {
|
||||
count = allow_count;
|
||||
exec_cfg = {count, th_config};
|
||||
};
|
||||
}
|
||||
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* a_s, void* b_s, void* g_s, void* zp, void* g_idx,
|
||||
void* perm, void* a_tmp, void* sorted_token_ids,
|
||||
void* expert_ids, void* num_tokens_past_padded,
|
||||
void* topk_weights, int moe_block_size, int num_experts,
|
||||
int top_k, bool mul_topk_weights, bool is_ep, int prob_m,
|
||||
int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& a_type, vllm::ScalarType const& b_type,
|
||||
vllm::ScalarType const& c_type, vllm::ScalarType const& s_type,
|
||||
bool has_bias, bool has_act_order, bool is_k_full, bool has_zp,
|
||||
int num_groups, int group_size, int dev, cudaStream_t stream,
|
||||
int thread_k, int thread_n, int sms, int blocks_per_sm,
|
||||
bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
bool is_a_8bit = a_type.size_bits() == 8;
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
int group_blocks = 0;
|
||||
if (has_act_order) {
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(group_size != -1);
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
} else {
|
||||
TORCH_CHECK(group_size == 0);
|
||||
group_blocks = 0;
|
||||
}
|
||||
} else {
|
||||
if (group_size == -1) {
|
||||
group_blocks = -1;
|
||||
} else {
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = b_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const float* a_s_ptr = (const float*)a_s;
|
||||
const int4* b_s_ptr = (const int4*)b_s;
|
||||
const uint16_t* g_s_ptr = (const uint16_t*)g_s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
|
||||
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
|
||||
const int32_t* num_tokens_past_padded_ptr =
|
||||
(const int32_t*)num_tokens_past_padded;
|
||||
const float* topk_weights_ptr = (const float*)topk_weights;
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
// Permute A columns
|
||||
auto kernel = permute_cols_kernel<8>;
|
||||
if (moe_block_size == 8) {
|
||||
} else if (moe_block_size == 16)
|
||||
kernel = permute_cols_kernel<16>;
|
||||
else if (moe_block_size == 32)
|
||||
kernel = permute_cols_kernel<32>;
|
||||
else if (moe_block_size == 48)
|
||||
kernel = permute_cols_kernel<48>;
|
||||
else if (moe_block_size == 64)
|
||||
kernel = permute_cols_kernel<64>;
|
||||
else
|
||||
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
|
||||
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<sms, default_threads, 0, stream>>>(
|
||||
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
|
||||
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
|
||||
// clang-format on
|
||||
A_ptr = a_tmp_ptr;
|
||||
prob_m = prob_m * top_k;
|
||||
top_k = 1;
|
||||
|
||||
// If we have a full K, then we can run the non-act-order version of Marlin
|
||||
// (since the weight rows are reordered by increasing group ids, and by
|
||||
// having a full K, we have full original groups)
|
||||
if (is_k_full) has_act_order = false;
|
||||
}
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
int major_capability, minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
dev);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
dev);
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 80,
|
||||
"marlin kernel only support Ampere or newer GPUs.");
|
||||
if (a_type == vllm::kFE4M3fn) {
|
||||
TORCH_CHECK(major_capability * 10 + minor_capability >= 89,
|
||||
"FP8 only support Ada Lovelace or newer GPUs.");
|
||||
TORCH_CHECK(
|
||||
major_capability * 10 + minor_capability == 89 ||
|
||||
major_capability * 10 + minor_capability == 120,
|
||||
"Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
|
||||
"Marlin W4A16 on other devices).");
|
||||
}
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
thread_config_t thread_tfg;
|
||||
if (thread_k != -1 && thread_n != -1) {
|
||||
thread_tfg = thread_config_t{thread_k, thread_n, thread_k * thread_n / 64};
|
||||
if (blocks_per_sm == -1) blocks_per_sm = 1;
|
||||
exec_cfg = exec_config_t{blocks_per_sm, thread_tfg};
|
||||
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
||||
" is not divisible by thread_n = ", thread_n);
|
||||
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config(
|
||||
a_type, b_type, c_type, s_type, prob_m, prob_n, prob_k, num_experts,
|
||||
top_k, thread_m_blocks, m_block_size_8, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float, max_shared_mem, sms,
|
||||
is_a_8bit);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
int num_threads = thread_tfg.num_threads;
|
||||
thread_k = thread_tfg.thread_k;
|
||||
thread_n = thread_tfg.thread_n;
|
||||
int blocks = sms * exec_cfg.blocks_per_sm;
|
||||
if (exec_cfg.blocks_per_sm > 1)
|
||||
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
|
||||
|
||||
int thread_k_blocks = thread_k / 16;
|
||||
int thread_n_blocks = thread_n / 16;
|
||||
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks,
|
||||
prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem, is_a_8bit),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
|
||||
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
||||
", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem = ", max_shared_mem);
|
||||
|
||||
int sh_cache_size =
|
||||
get_kernel_cache_size(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
|
||||
prob_n, prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, is_a_8bit);
|
||||
|
||||
auto kernel = get_marlin_kernel(
|
||||
a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks,
|
||||
thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks,
|
||||
num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||
", thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_n_blocks = ", thread_n_blocks,
|
||||
", thread_k_blocks = ", thread_k_blocks,
|
||||
", num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& a_scales_or_none,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float, int64_t thread_k, int64_t thread_n,
|
||||
int64_t blocks_per_sm) {
|
||||
vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
|
||||
|
||||
auto c_dtype = a.dtype();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
a_type_id = vllm::kFloat16.id();
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
a_type_id = vllm::kBFloat16.id();
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
c_dtype = b_scales.dtype();
|
||||
if (b_scales.scalar_type() == at::ScalarType::Half) {
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (b_scales.scalar_type() == at::ScalarType::BFloat16) {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
|
||||
TORCH_CHECK(c_or_none.has_value(), "c must be passed for W4A8-FP4");
|
||||
torch::Tensor c = c_or_none.value();
|
||||
c_dtype = c.dtype();
|
||||
|
||||
if (c.scalar_type() == at::ScalarType::Half) {
|
||||
c_type_id = vllm::kFloat16.id();
|
||||
} else if (c.scalar_type() == at::ScalarType::BFloat16) {
|
||||
c_type_id = vllm::kBFloat16.id();
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported c dtype");
|
||||
}
|
||||
}
|
||||
|
||||
if (a.scalar_type() == at::ScalarType::Float8_e4m3fn) {
|
||||
a_type_id = vllm::kFE4M3fn.id();
|
||||
} else if (a.scalar_type() == at::ScalarType::Char) {
|
||||
a_type_id = vllm::kS8.id();
|
||||
} else {
|
||||
TORCH_CHECK(false, "unsupported `a` scalar_type");
|
||||
}
|
||||
}
|
||||
|
||||
s_type_id = c_type_id;
|
||||
if (b_type_id == vllm::kFE2M1f.id()) {
|
||||
if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) {
|
||||
s_type_id = vllm::kFE4M3fn.id();
|
||||
} else if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) {
|
||||
s_type_id = vllm::kFE8M0fnu.id();
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"When b_type = float4_e2m1f, b_scale scalar type must be",
|
||||
"float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
|
||||
}
|
||||
}
|
||||
|
||||
vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
|
||||
vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
|
||||
vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
|
||||
vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
|
||||
|
||||
int pack_factor = 32 / b_type.size_bits();
|
||||
int num_experts = b_q_weight.size(0);
|
||||
|
||||
if (moe_block_size != 8) {
|
||||
TORCH_CHECK(moe_block_size % 16 == 0,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
}
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
", size_m = ", size_m);
|
||||
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
", size_k = ", size_k,
|
||||
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK(
|
||||
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
|
||||
"b_q_weight.size(2) = ", b_q_weight.size(2),
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
||||
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
torch::Tensor a_scales;
|
||||
auto options = torch::TensorOptions().dtype(c_dtype).device(a.device());
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
|
||||
if (a_scales_or_none.has_value()) {
|
||||
a_scales = a_scales_or_none.value();
|
||||
TORCH_CHECK(a_type.size_bits() == 8,
|
||||
"a_scales can only be used for 8bit activation.");
|
||||
} else {
|
||||
a_scales = torch::empty({0}, options_fp32);
|
||||
TORCH_CHECK(a_type.size_bits() != 8,
|
||||
"the a_scales parameter must be passed for 8bit activation.");
|
||||
}
|
||||
|
||||
// sms: number of SMs to use for the kernel
|
||||
int sms = -1;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
|
||||
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
|
||||
TORCH_CHECK(c.size(0) == size_m * top_k,
|
||||
"Shape mismatch: c.size(0) = ", c.size(0),
|
||||
", size_m * topk = ", size_m * top_k);
|
||||
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
|
||||
", size_n = ", size_n);
|
||||
} else {
|
||||
c = torch::empty({size_m * top_k, size_n}, options);
|
||||
}
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
if (use_fp32_reduce && !use_atomic_add) {
|
||||
// max num of threadblocks is sms * 4
|
||||
long max_c_tmp_size = min(
|
||||
(long)size_n * sorted_token_ids.size(0),
|
||||
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
|
||||
if (moe_block_size == 8) max_c_tmp_size *= 2;
|
||||
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
|
||||
} else {
|
||||
c_tmp = torch::empty({0}, options_fp32);
|
||||
}
|
||||
|
||||
// Detect groupsize and act_order
|
||||
int num_groups = -1;
|
||||
int group_size = -1;
|
||||
|
||||
int rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
|
||||
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
num_groups = b_scales.size(1);
|
||||
|
||||
torch::Tensor g_idx, perm, a_tmp;
|
||||
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||
g_idx = g_idx_or_none.value();
|
||||
perm = perm_or_none.value();
|
||||
|
||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
|
||||
// Verify g_idx and perm
|
||||
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
|
||||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
|
||||
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
|
||||
" and perm.size(-1) = ", perm.size(-1),
|
||||
", where size_k = ", size_k);
|
||||
} else {
|
||||
g_idx = torch::empty({0}, options);
|
||||
perm = torch::empty({0}, options);
|
||||
a_tmp = torch::empty({0}, options);
|
||||
}
|
||||
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
|
||||
|
||||
if (has_act_order) {
|
||||
a_tmp = torch::empty({size_m * top_k, size_k}, options);
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
|
||||
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by num_groups = ", num_groups);
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = 0;
|
||||
}
|
||||
|
||||
} else {
|
||||
a_tmp = torch::empty({0}, options);
|
||||
if (num_groups > 1) {
|
||||
TORCH_CHECK(
|
||||
size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
|
||||
bool has_bias = b_bias_or_none.has_value();
|
||||
torch::Tensor b_bias;
|
||||
if (has_bias) {
|
||||
b_bias = b_bias_or_none.value();
|
||||
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
|
||||
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
|
||||
TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n");
|
||||
TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
|
||||
} else {
|
||||
b_bias = torch::empty({0}, options);
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||
} else {
|
||||
b_zeros = torch::empty({0}, options);
|
||||
}
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_type == vllm::kU4 || b_type == vllm::kU8,
|
||||
"b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 ||
|
||||
b_type == vllm::kS4 || b_type == vllm::kS8 ||
|
||||
b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
|
||||
"b_type must be uint4b8, uint8b128, int4, int8, "
|
||||
"float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
|
||||
b_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
||||
"Computation type must be float16 (half) when using float zero "
|
||||
"points.");
|
||||
}
|
||||
|
||||
// Verify b_zeros
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
|
||||
if (is_zp_float) {
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
TORCH_CHECK(num_groups == b_zeros.size(1),
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
||||
} else {
|
||||
TORCH_CHECK(b_zeros.size(1) == num_groups,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
|
||||
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
|
||||
MARLIN_NAMESPACE_NAME::min_thread_n);
|
||||
|
||||
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
|
||||
int min_workspace_size = min(
|
||||
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
|
||||
TORCH_CHECK(a_scales.scalar_type() == at::ScalarType::Float,
|
||||
"scalar type of a_scales must be float");
|
||||
TORCH_CHECK(global_scale.scalar_type() == c.scalar_type(),
|
||||
"scalar type of global_scale must be the same with c");
|
||||
if (a_type.size_bits() == 16) {
|
||||
TORCH_CHECK(
|
||||
a.scalar_type() == c.scalar_type(),
|
||||
"scalar type of a must be the same with c for 16 bit activation");
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm(
|
||||
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(),
|
||||
b_bias.data_ptr(), a_scales.data_ptr(), b_scales.data_ptr(),
|
||||
global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, num_experts, top_k,
|
||||
mul_topk_weights, is_ep, size_m, size_n, size_k, workspace.data_ptr(),
|
||||
a_type, b_type, c_type, s_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, blocks_per_sm, use_atomic_add, use_fp32_reduce,
|
||||
is_zp_float);
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
759
csrc/moe/moe_align_sum_kernels.cu
Normal file
759
csrc/moe/moe_align_sum_kernels.cu
Normal file
@@ -0,0 +1,759 @@
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/Atomic.cuh>
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "../dispatch_utils.h"
|
||||
#include "core/math.hpp"
|
||||
|
||||
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
namespace batched_moe_align_block_size {
|
||||
|
||||
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
|
||||
static constexpr int32_t num_threads = 1024;
|
||||
static constexpr int32_t num_blocks = 1;
|
||||
__global__ void batched_moe_align_block_size_kernel(
|
||||
int32_t const num_batches, int32_t const max_tokens_per_batch,
|
||||
int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
|
||||
int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
|
||||
int32_t* __restrict__ num_tokens_post_pad) {
|
||||
// TODO(varun): This is a naive implementation. Could be optimized.
|
||||
|
||||
size_t const batch_id = threadIdx.x;
|
||||
size_t const stride = blockDim.x * gridDim.x;
|
||||
int32_t const num_blocks_per_batch =
|
||||
CEILDIV(max_tokens_per_batch, block_size);
|
||||
int32_t const sorted_ids_size =
|
||||
num_blocks_per_batch * num_batches * block_size;
|
||||
int32_t const block_ids_size = sorted_ids_size / block_size;
|
||||
int32_t const SENTINEL =
|
||||
num_batches * max_tokens_per_batch; // To denote invalid entries.
|
||||
// Intialize sorted_ids
|
||||
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
|
||||
sorted_ids[i] = SENTINEL;
|
||||
}
|
||||
// Intialize expert_ids with -1
|
||||
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
|
||||
block_ids[i] = -1;
|
||||
}
|
||||
|
||||
int32_t b_num_tokens = 0;
|
||||
if (batch_id < num_batches) {
|
||||
b_num_tokens = batch_num_tokens[batch_id];
|
||||
}
|
||||
int32_t const ceil_b_num_tokens =
|
||||
CEILDIV(b_num_tokens, block_size) * block_size;
|
||||
|
||||
// Compute prefix sum over token counts per expert
|
||||
using BlockScan = cub::BlockScan<int32_t, 1024>;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
int cumsum_val;
|
||||
BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
|
||||
__syncthreads();
|
||||
|
||||
bool const is_last_batch = batch_id == (num_batches - 1);
|
||||
if (is_last_batch) {
|
||||
*num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
|
||||
}
|
||||
|
||||
if (batch_id < num_batches) {
|
||||
int32_t const batch_offset = batch_id * max_tokens_per_batch;
|
||||
for (size_t i = 0; i < b_num_tokens; ++i) {
|
||||
sorted_ids[cumsum_val + i] = batch_offset + i;
|
||||
}
|
||||
|
||||
int32_t const block_start = cumsum_val / block_size;
|
||||
int32_t const num_blocks = ceil_b_num_tokens / block_size;
|
||||
for (size_t i = 0; i < num_blocks; ++i) {
|
||||
block_ids[block_start + i] = batch_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace batched_moe_align_block_size
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ void _moe_align_block_size(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t* __restrict__ expert_map, int32_t num_experts,
|
||||
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
|
||||
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
|
||||
int32_t max_num_m_blocks, int32_t model_offset, int32_t inactive_expert_id,
|
||||
int32_t topk_num, int32_t* token_mask, bool has_expert_map) {
|
||||
extern __shared__ int32_t shared_counts[];
|
||||
|
||||
// Compute input buffer offsets. Typically these will all be 0, except when
|
||||
// using Multi LoRA.
|
||||
int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
|
||||
int expert_ids_offset = max_num_m_blocks * model_offset;
|
||||
int cumsum_offset = (num_experts + 1) * model_offset;
|
||||
|
||||
// Use separate threadblocks to fill sorted_token_ids.
|
||||
// This is safe since the current kernel does not use sorted_token_ids.
|
||||
if (blockIdx.x % 2) {
|
||||
// Initialize sorted_token_ids with numel
|
||||
for (size_t it = threadIdx.x; it < max_num_tokens_padded;
|
||||
it += blockDim.x) {
|
||||
sorted_token_ids[sorted_token_ids_offset + it] = numel;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
for (int i = 0; i < experts_per_warp; ++i) {
|
||||
if (my_expert_start + i < padded_num_experts) {
|
||||
shared_counts[warp_id * experts_per_warp + i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t stride = blockDim.x;
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int expert_id = topk_ids[i];
|
||||
if (expert_id >= num_experts) {
|
||||
continue;
|
||||
}
|
||||
if (has_expert_map) {
|
||||
expert_id = expert_map[expert_id];
|
||||
// filter invalid experts
|
||||
if (expert_id == -1) continue;
|
||||
}
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
|
||||
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset],
|
||||
mask);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute prefix sum over token counts per expert
|
||||
using BlockScan = cub::BlockScan<int32_t, 1024>;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
|
||||
int expert_count = 0;
|
||||
int expert_id = threadIdx.x;
|
||||
if (expert_id < num_experts) {
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
expert_count = CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
|
||||
int cumsum_val;
|
||||
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
|
||||
if (expert_id <= num_experts) {
|
||||
cumsum[cumsum_offset + expert_id] = cumsum_val;
|
||||
}
|
||||
|
||||
if (expert_id == num_experts) {
|
||||
total_tokens_post_pad[model_offset] = cumsum_val;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum[cumsum_offset + threadIdx.x];
|
||||
i < cumsum[cumsum_offset + threadIdx.x + 1]; i += block_size) {
|
||||
expert_ids[expert_ids_offset + i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
const size_t fill_start_idx =
|
||||
cumsum[cumsum_offset + num_experts] / block_size + threadIdx.x;
|
||||
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += blockDim.x) {
|
||||
expert_ids[expert_ids_offset + i] = inactive_expert_id;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int32_t fill_threads>
|
||||
__device__ void _moe_align_block_size_small_batch_expert(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
|
||||
size_t numel, int32_t max_num_tokens_padded, int32_t max_num_m_blocks,
|
||||
int32_t inactive_expert_id, int32_t model_offset, int32_t topk_num,
|
||||
int32_t* token_mask, bool has_expert_map) {
|
||||
// Compute input buffer offsets. Typically these will all be 0, except when
|
||||
// using Multi LoRA.
|
||||
int sorted_token_ids_offset = max_num_tokens_padded * model_offset;
|
||||
int expert_ids_offset = max_num_m_blocks * model_offset;
|
||||
|
||||
// Use an additional group of threads to fill sorted_token_ids.
|
||||
// Since the current kernel will use sorted_token_ids afterward,
|
||||
// we fill sorted_token_ids within the same threadblock to make
|
||||
// synchronization easier.
|
||||
if (threadIdx.x < fill_threads) {
|
||||
// Initialize sorted_token_ids with numel
|
||||
for (size_t it = threadIdx.x; it < max_num_tokens_padded;
|
||||
it += fill_threads) {
|
||||
sorted_token_ids[sorted_token_ids_offset + it] = numel;
|
||||
}
|
||||
// Three __syncthreads() corresponding to the other threads
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
__syncthreads();
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t tid = threadIdx.x - fill_threads;
|
||||
const size_t stride = blockDim.x - fill_threads;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
int32_t* cumsum = shared_mem;
|
||||
int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[(tid + 1) * num_experts + i] = 0;
|
||||
}
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
if (has_expert_map) {
|
||||
expert_id = expert_map[expert_id];
|
||||
// filter invalid expert
|
||||
if (expert_id == -1) continue;
|
||||
}
|
||||
int mask = token_mask == nullptr ? 1 : token_mask[i / topk_num];
|
||||
tokens_cnts[(tid + 1) * num_experts + expert_id] += mask;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tid < num_experts) {
|
||||
tokens_cnts[tid] = 0;
|
||||
for (int i = 1; i <= stride; ++i) {
|
||||
tokens_cnts[i * num_experts + tid] +=
|
||||
tokens_cnts[(i - 1) * num_experts + tid];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] =
|
||||
cumsum[i - 1] +
|
||||
CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) *
|
||||
block_size;
|
||||
}
|
||||
total_tokens_post_pad[model_offset] =
|
||||
static_cast<int32_t>(cumsum[num_experts]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tid < num_experts) {
|
||||
for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) {
|
||||
expert_ids[expert_ids_offset + i / block_size] = tid;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
const size_t fill_start_idx = cumsum[num_experts] / block_size + tid;
|
||||
for (size_t i = fill_start_idx; i < max_num_m_blocks; i += stride) {
|
||||
expert_ids[expert_ids_offset + i] = inactive_expert_id;
|
||||
}
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
if (has_expert_map) {
|
||||
expert_id = expert_map[expert_id];
|
||||
// filter invalid expert
|
||||
if (expert_id == -1) continue;
|
||||
}
|
||||
int32_t rank_post_pad =
|
||||
tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];
|
||||
|
||||
if (token_mask == nullptr || token_mask[i / topk_num]) {
|
||||
sorted_token_ids[sorted_token_ids_offset + rank_post_pad] = i;
|
||||
++tokens_cnts[tid * num_experts + expert_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ void _count_and_sort_expert_tokens(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
|
||||
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
|
||||
int32_t max_num_tokens_padded, int32_t* __restrict__ token_mask,
|
||||
int32_t model_offset, int32_t topk_num, bool has_expert_map) {
|
||||
const size_t tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
const size_t stride = blockDim.x * gridDim.y;
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
if (expert_id >= num_experts) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (has_expert_map) {
|
||||
expert_id = expert_map[expert_id];
|
||||
// filter invalid experts
|
||||
if (expert_id == -1) continue;
|
||||
}
|
||||
|
||||
if (token_mask == nullptr || token_mask[i / topk_num]) {
|
||||
int32_t rank_post_pad = atomicAdd(
|
||||
&cumsum_buffer[(model_offset * (num_experts + 1)) + expert_id], 1);
|
||||
sorted_token_ids[max_num_tokens_padded * model_offset + rank_post_pad] =
|
||||
i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t* __restrict__ expert_map, int32_t num_experts,
|
||||
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
|
||||
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded,
|
||||
int32_t topk_num, bool has_expert_map) {
|
||||
_moe_align_block_size(
|
||||
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
|
||||
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
|
||||
cumsum, max_num_tokens_padded, CEILDIV(max_num_tokens_padded, block_size),
|
||||
0, 0, topk_num, nullptr, has_expert_map);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void count_and_sort_expert_tokens_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
|
||||
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
|
||||
int32_t max_num_tokens_padded, int32_t topk_num, bool has_expert_map) {
|
||||
_count_and_sort_expert_tokens(
|
||||
topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts,
|
||||
max_num_tokens_padded, nullptr, 0, topk_num, has_expert_map);
|
||||
}
|
||||
|
||||
template <typename scalar_t, int TOPK>
|
||||
__global__ void moe_sum_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., topk, d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
scalar_t x = 0.0;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < TOPK; ++k) {
|
||||
x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]);
|
||||
}
|
||||
out[token_idx * d + idx] = x;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int32_t fill_threads>
|
||||
__global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad,
|
||||
int32_t* __restrict__ expert_map, int32_t num_experts, int32_t block_size,
|
||||
size_t numel, int32_t max_num_tokens_padded, int32_t topk_num,
|
||||
bool has_expert_map) {
|
||||
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
|
||||
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
|
||||
num_experts, block_size, numel, max_num_tokens_padded,
|
||||
CEILDIV(max_num_tokens_padded, block_size), 0, 0, topk_num, nullptr,
|
||||
has_expert_map);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_lora_align_block_size_kernel(
|
||||
scalar_t* __restrict__ topk_ids, int32_t* __restrict__ token_lora_mapping,
|
||||
int64_t block_size, int32_t* __restrict__ expert_map, int num_experts,
|
||||
int max_loras, size_t numel, int max_num_tokens_padded,
|
||||
int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ expert_ids, int32_t topk_num,
|
||||
int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
|
||||
int32_t* __restrict__ cumsum, int32_t experts_per_warp,
|
||||
int32_t padded_num_experts, int32_t* lora_ids,
|
||||
int32_t* __restrict__ token_mask, bool has_expert_map) {
|
||||
int lora_idx = blockIdx.x / 2;
|
||||
int lora_id = lora_ids[lora_idx];
|
||||
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Populate the token_mask based on the token-LoRA mapping
|
||||
int num_tokens = numel / topk_num;
|
||||
if (threadIdx.x == 0) {
|
||||
total_tokens_post_pad[lora_id] = 0;
|
||||
|
||||
for (int i = 0; i < num_tokens; i++) {
|
||||
token_mask[(lora_id * num_tokens) + i] =
|
||||
(int)token_lora_mapping[i] == lora_id;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
_moe_align_block_size(
|
||||
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
|
||||
num_experts, padded_num_experts, experts_per_warp, block_size, numel,
|
||||
cumsum, max_num_tokens_padded, max_num_m_blocks, lora_id, -1, topk_num,
|
||||
&token_mask[(lora_id * num_tokens)], has_expert_map);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void lora_count_and_sort_expert_tokens_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
|
||||
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
|
||||
int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask,
|
||||
int32_t* lora_ids, bool has_expert_map) {
|
||||
int lora_idx = blockIdx.x;
|
||||
int lora_id = lora_ids[lora_idx];
|
||||
if (lora_id == -1) {
|
||||
return;
|
||||
}
|
||||
|
||||
int num_tokens = numel / topk_num;
|
||||
|
||||
_count_and_sort_expert_tokens(
|
||||
topk_ids, sorted_token_ids, cumsum_buffer, expert_map, numel, num_experts,
|
||||
max_num_tokens_padded, &token_mask[(lora_id * num_tokens)], lora_id,
|
||||
topk_num, has_expert_map);
|
||||
}
|
||||
|
||||
template <typename scalar_t, int32_t fill_threads>
|
||||
__global__ void moe_lora_align_block_size_small_batch_expert_kernel(
|
||||
scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
|
||||
int64_t block_size, int32_t* __restrict__ expert_map, int num_experts,
|
||||
int max_loras, size_t numel, int max_num_tokens_padded,
|
||||
int max_num_m_blocks, int32_t* __restrict__ sorted_token_ids,
|
||||
int32_t* __restrict__ expert_ids, int topk_num,
|
||||
int32_t* total_tokens_post_pad, int32_t* adapter_enabled, int32_t* lora_ids,
|
||||
int32_t* token_mask, bool has_expert_map) {
|
||||
int lora_idx = blockIdx.x;
|
||||
int lora_id = lora_ids[lora_idx];
|
||||
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int num_tokens = numel / topk_num;
|
||||
if (threadIdx.x == 0) {
|
||||
total_tokens_post_pad[lora_id] = 0;
|
||||
|
||||
for (int i = 0; i < num_tokens; i++) {
|
||||
token_mask[(lora_id * num_tokens) + i] =
|
||||
(int)token_lora_mapping[i] == lora_id;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
_moe_align_block_size_small_batch_expert<scalar_t, fill_threads>(
|
||||
topk_ids, sorted_token_ids, expert_ids, total_tokens_post_pad, expert_map,
|
||||
num_experts, block_size, numel, max_num_tokens_padded, max_num_m_blocks,
|
||||
-1, lora_id, topk_num, &token_mask[(lora_id * num_tokens)],
|
||||
has_expert_map);
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
// taken from
|
||||
// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad,
|
||||
std::optional<torch::Tensor> maybe_expert_map) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int64_t padded_num_experts =
|
||||
((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
int experts_per_warp = WARP_SIZE;
|
||||
int threads = 1024;
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
// BlockScan uses 1024 threads and assigns one thread per expert.
|
||||
TORCH_CHECK(padded_num_experts < 1024,
|
||||
"padded_num_experts must be less than 1024");
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
||||
bool has_expert_map = maybe_expert_map.has_value();
|
||||
torch::Tensor expert_map;
|
||||
if (has_expert_map) {
|
||||
expert_map = maybe_expert_map.value();
|
||||
} else {
|
||||
expert_map = torch::empty({0}, options_int);
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `cumsum` tensors
|
||||
bool small_batch_expert_mode =
|
||||
(topk_ids.numel() < 1024) && (num_experts <= 64);
|
||||
|
||||
if (small_batch_expert_mode) {
|
||||
const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem_size =
|
||||
((threads + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// threadIdx.x >= fill_threads: counting experts and aligning
|
||||
// threadIdx.x < fill_threads: filling sorted_token_ids
|
||||
constexpr int32_t fill_threads = 256;
|
||||
auto small_batch_expert_kernel =
|
||||
vllm::moe::moe_align_block_size_small_batch_expert_kernel<
|
||||
scalar_t, fill_threads>;
|
||||
small_batch_expert_kernel<<<1, fill_threads + threads,
|
||||
shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
expert_map.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel(), sorted_token_ids.size(0), topk_ids.size(1),
|
||||
has_expert_map);
|
||||
} else {
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
|
||||
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||
size_t shared_mem_size =
|
||||
num_warps * experts_per_warp * sizeof(int32_t);
|
||||
|
||||
// launch two threadblocks
|
||||
// blockIdx.x == 0: counting experts and aligning
|
||||
// blockIdx.x == 1: filling sorted_token_ids
|
||||
align_kernel<<<2, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
expert_map.data_ptr<int32_t>(), num_experts, padded_num_experts,
|
||||
experts_per_warp, block_size, topk_ids.numel(),
|
||||
cumsum_buffer.data_ptr<int32_t>(), sorted_token_ids.size(0),
|
||||
topk_ids.size(1), has_expert_map);
|
||||
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks =
|
||||
(topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
dim3 gridDims(1, actual_blocks);
|
||||
|
||||
auto sort_kernel =
|
||||
vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
sort_kernel<<<gridDims, block_threads, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>(), expert_map.data_ptr<int32_t>(),
|
||||
topk_ids.numel(), num_experts, sorted_token_ids.size(0),
|
||||
topk_ids.size(1), has_expert_map);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
|
||||
int64_t block_size,
|
||||
torch::Tensor const& batch_num_tokens,
|
||||
torch::Tensor sorted_ids,
|
||||
torch::Tensor batch_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
int32_t const B = batch_num_tokens.size(0);
|
||||
int32_t const num_blocks_per_batch =
|
||||
round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
|
||||
int32_t const num_blocks = num_blocks_per_batch * B;
|
||||
int64_t const sorted_ids_size = num_blocks * block_size;
|
||||
|
||||
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
|
||||
TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
|
||||
TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
|
||||
TORCH_CHECK(B <= batched_kernel::num_threads);
|
||||
|
||||
batched_kernel::batched_moe_align_block_size_kernel<<<
|
||||
batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
|
||||
B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
|
||||
sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>());
|
||||
}
|
||||
|
||||
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const auto num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
switch (topk) {
|
||||
case 2:
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
vllm::moe::moe_sum_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
case 3:
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
vllm::moe::moe_sum_kernel<scalar_t, 3><<<grid, block, 0, stream>>>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
case 4:
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
|
||||
vllm::moe::moe_sum_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
hidden_size);
|
||||
});
|
||||
break;
|
||||
|
||||
default:
|
||||
at::sum_out(output, input, 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void moe_lora_align_block_size(
|
||||
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
|
||||
int64_t num_experts, int64_t block_size, int64_t max_loras,
|
||||
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
|
||||
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
|
||||
torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map) {
|
||||
const int topk_num = topk_ids.size(1);
|
||||
|
||||
TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
|
||||
|
||||
int device_max_shared_mem;
|
||||
auto dev = topk_ids.get_device();
|
||||
cudaDeviceGetAttribute(&device_max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int64_t padded_num_experts =
|
||||
((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
// BlockScan uses 1024 threads and assigns one thread per expert.
|
||||
TORCH_CHECK(padded_num_experts < 1024,
|
||||
"padded_num_experts must be less than 1024");
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
||||
torch::Tensor token_mask =
|
||||
torch::empty({max_loras * topk_ids.size(0)}, options_int);
|
||||
bool has_expert_map = maybe_expert_map.has_value();
|
||||
torch::Tensor expert_map;
|
||||
if (has_expert_map) {
|
||||
expert_map = maybe_expert_map.value();
|
||||
} else {
|
||||
expert_map = torch::empty({0}, options_int);
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
|
||||
bool small_batch_expert_mode =
|
||||
(topk_ids.numel() < 1024) && (num_experts <= 64);
|
||||
|
||||
if (small_batch_expert_mode) {
|
||||
const int32_t num_thread = max((int32_t)num_experts, 128);
|
||||
const int32_t shared_mem =
|
||||
(num_thread + 1) * num_experts * sizeof(int32_t) +
|
||||
(num_experts + 1) * sizeof(int32_t);
|
||||
if (shared_mem > device_max_shared_mem) {
|
||||
TORCH_CHECK(false, "Shared memory usage exceeds device limit.");
|
||||
}
|
||||
|
||||
// threadIdx.x >= fill_threads: counting experts and aligning
|
||||
// threadIdx.x < fill_threads: filling sorted_token_ids
|
||||
constexpr int32_t fill_threads = 256;
|
||||
|
||||
dim3 blockDim(num_thread + fill_threads);
|
||||
auto kernel =
|
||||
vllm::moe::moe_lora_align_block_size_small_batch_expert_kernel<
|
||||
scalar_t, fill_threads>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
token_lora_mapping.data_ptr<int32_t>(), block_size,
|
||||
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
|
||||
topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), topk_num,
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>(),
|
||||
token_mask.data_ptr<int32_t>(), has_expert_map);
|
||||
} else {
|
||||
int num_thread = 1024;
|
||||
dim3 blockDim(num_thread);
|
||||
size_t num_warps = CEILDIV(padded_num_experts, WARP_SIZE);
|
||||
|
||||
size_t shared_mem_size = num_warps * WARP_SIZE * sizeof(int32_t);
|
||||
|
||||
// cumsum buffer
|
||||
torch::Tensor cumsum =
|
||||
torch::zeros({max_loras * (num_experts + 1)}, options_int);
|
||||
|
||||
auto align_kernel =
|
||||
vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
|
||||
|
||||
// launch two threadblocks for each lora
|
||||
// blockIdx.x % 2 == 0: counting experts and aligning
|
||||
// blockIdx.x % 2 == 1: filling sorted_token_ids
|
||||
align_kernel<<<max_loras * 2, blockDim, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
token_lora_mapping.data_ptr<int32_t>(), block_size,
|
||||
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
|
||||
topk_ids.numel(), max_num_tokens_padded, max_num_m_blocks,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), topk_num,
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
adapter_enabled.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
|
||||
WARP_SIZE, padded_num_experts, lora_ids.data_ptr<int32_t>(),
|
||||
token_mask.data_ptr<int32_t>(), has_expert_map);
|
||||
|
||||
const int block_threads = std::min(256, (int)num_thread);
|
||||
const int num_blocks =
|
||||
(topk_ids.numel() + block_threads - 1) / block_threads;
|
||||
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
dim3 gridDims(max_loras, actual_blocks);
|
||||
auto sort_kernel =
|
||||
vllm::moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
|
||||
sort_kernel<<<gridDims, block_threads, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
|
||||
expert_map.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
|
||||
max_num_tokens_padded, topk_num, token_mask.data_ptr<int32_t>(),
|
||||
lora_ids.data_ptr<int32_t>(), has_expert_map);
|
||||
}
|
||||
});
|
||||
}
|
||||
52
csrc/moe/moe_ops.h
Normal file
52
csrc/moe/moe_ops.h
Normal file
@@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output, bool renormalize);
|
||||
|
||||
void moe_sum(torch::Tensor& input, torch::Tensor& output);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad,
|
||||
std::optional<torch::Tensor> maybe_expert_map);
|
||||
|
||||
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
|
||||
int64_t block_size,
|
||||
torch::Tensor const& expert_num_tokens,
|
||||
torch::Tensor sorted_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
void moe_lora_align_block_size(
|
||||
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
|
||||
int64_t num_experts, int64_t block_size, int64_t max_loras,
|
||||
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
|
||||
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
|
||||
torch::Tensor lora_ids, std::optional<torch::Tensor> maybe_expert_map);
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
std::optional<torch::Tensor> b_qzeros,
|
||||
std::optional<torch::Tensor> topk_weights,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
|
||||
torch::Tensor const& scores, int64_t n_group, int64_t topk_group,
|
||||
int64_t topk, bool renormalize, double routed_scaling_factor,
|
||||
torch::Tensor const& bias, int64_t scoring_func);
|
||||
#endif
|
||||
|
||||
bool moe_permute_unpermute_supported();
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
const torch::Tensor& dst2src_map,
|
||||
torch::Tensor& output_tensor);
|
||||
222
csrc/moe/moe_permute_unpermute_op.cu
Normal file
222
csrc/moe/moe_permute_unpermute_op.cu
Normal file
@@ -0,0 +1,222 @@
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
|
||||
#include "permute_unpermute_kernels/dispatch.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
// moe_permute kernels require at least CUDA 12.0
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||
|
||||
void moe_permute(
|
||||
const torch::Tensor& input, // [n_token, hidden]
|
||||
const torch::Tensor& topk_ids, // [n_token, topk]
|
||||
const torch::Tensor& token_expert_indices, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>& expert_map, // [n_expert]
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor& permuted_input, // [permuted_size, hidden]
|
||||
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
|
||||
torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
torch::Tensor& permuted_idx, // [permute_size]
|
||||
torch::Tensor& m_indices) { // [align_expand_m]
|
||||
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
|
||||
"topk_ids must be int32");
|
||||
TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int,
|
||||
"token_expert_indices must be int32");
|
||||
TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int,
|
||||
"inv_permuted_idx must be int32");
|
||||
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
|
||||
"expert_first_token_offset shape != n_local_expert+1")
|
||||
TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(),
|
||||
"token_expert_indices shape must be same as inv_permuted_idx");
|
||||
auto n_token = input.sizes()[0];
|
||||
auto n_hidden = input.sizes()[1];
|
||||
auto align_block_size_value =
|
||||
align_block_size.has_value() ? align_block_size.value() : -1;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const long sorter_size =
|
||||
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
|
||||
auto sort_workspace = torch::empty(
|
||||
{sorter_size},
|
||||
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
|
||||
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
// pre-process kernel for expert-parallelism:
|
||||
// no local expert id plus "n_expert" offset for priority to local expert
|
||||
// map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1]
|
||||
// For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id
|
||||
// [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids
|
||||
// and map global expert id [2, 3] to local_expert id [0, 1] and map global
|
||||
// expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map
|
||||
// operation is to make local expert high priority in following sort topk_ids
|
||||
// and scan local expert_first_token_offset for each ep rank for next group
|
||||
// gemm.
|
||||
if (expert_map.has_value()) {
|
||||
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
|
||||
valid_num_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
|
||||
preprocessTopkIdLauncher(get_ptr<int>(copy_topk_ids), n_token * topk,
|
||||
expert_map_ptr, n_expert, stream);
|
||||
}
|
||||
// expert sort topk expert id and scan expert id get expert_first_token_offset
|
||||
sortAndScanExpert(
|
||||
get_ptr<int>(copy_topk_ids), get_ptr<int>(token_expert_indices),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
|
||||
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
|
||||
|
||||
// dispatch expandInputRowsKernelLauncher
|
||||
MOE_DISPATCH(input.scalar_type(), [&] {
|
||||
expandInputRowsKernelLauncher<scalar_t>(
|
||||
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
|
||||
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
|
||||
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
|
||||
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
|
||||
n_hidden, topk, n_local_expert, align_block_size_value, stream);
|
||||
});
|
||||
|
||||
// get m_indices and update expert_first_token_offset with align block
|
||||
// this is only required for DeepGemm and not required for CUTLASS group gemm
|
||||
if (align_block_size.has_value()) {
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
|
||||
const torch::Tensor& topk_weights, // [n_token, topk]
|
||||
const torch::Tensor& inv_permuted_idx, // [n_token, topk]
|
||||
const std::optional<torch::Tensor>&
|
||||
expert_first_token_offset, // [n_local_expert+1]
|
||||
int64_t topk,
|
||||
torch::Tensor& hidden_states // [n_token, hidden]
|
||||
) {
|
||||
TORCH_CHECK(
|
||||
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
|
||||
"permuted_hidden_states dtype must be same as hidden_states");
|
||||
auto n_token = hidden_states.size(0);
|
||||
auto n_hidden = hidden_states.size(1);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
int64_t const* valid_ptr = nullptr;
|
||||
if (expert_first_token_offset.has_value()) {
|
||||
int n_local_expert = expert_first_token_offset.value().size(0) - 1;
|
||||
valid_ptr =
|
||||
get_ptr<int64_t>(expert_first_token_offset.value()) + n_local_expert;
|
||||
}
|
||||
|
||||
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
|
||||
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
|
||||
get_ptr<scalar_t>(permuted_hidden_states),
|
||||
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
|
||||
get_ptr<int>(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr,
|
||||
stream);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void shuffleInputRowsKernel(const T* input,
|
||||
const int32_t* dst2src_map, T* output,
|
||||
int64_t num_src_rows,
|
||||
int64_t num_dst_rows, int64_t num_cols) {
|
||||
int64_t dest_row_idx = blockIdx.x;
|
||||
int64_t const source_row_idx = dst2src_map[dest_row_idx];
|
||||
|
||||
if (blockIdx.x < num_dst_rows) {
|
||||
// Load 128-bits per thread
|
||||
constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
auto const* source_row_ptr =
|
||||
reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
|
||||
auto* dest_row_ptr =
|
||||
reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
|
||||
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void shuffle_rows(const torch::Tensor& input_tensor,
|
||||
const torch::Tensor& dst2src_map,
|
||||
torch::Tensor& output_tensor) {
|
||||
TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
|
||||
"Input and output tensors must have the same data type");
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
int64_t const blocks = output_tensor.size(0);
|
||||
int64_t const threads = 256;
|
||||
int64_t const num_dest_rows = output_tensor.size(0);
|
||||
int64_t const num_src_rows = input_tensor.size(0);
|
||||
int64_t const num_cols = input_tensor.size(1);
|
||||
|
||||
TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
|
||||
"num_cols must be divisible by 128 / "
|
||||
"sizeof(input_tensor.scalar_type()) / 8");
|
||||
|
||||
MOE_DISPATCH(input_tensor.scalar_type(), [&] {
|
||||
shuffleInputRowsKernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
reinterpret_cast<scalar_t*>(input_tensor.data_ptr()),
|
||||
dst2src_map.data_ptr<int32_t>(),
|
||||
reinterpret_cast<scalar_t*>(output_tensor.data_ptr()), num_src_rows,
|
||||
num_dest_rows, num_cols);
|
||||
});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map,
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor& permuted_input,
|
||||
torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& src_row_id2dst_row_id_map,
|
||||
torch::Tensor& m_indices) {
|
||||
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states,
|
||||
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
|
||||
const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
|
||||
torch::Tensor& hidden_states) {
|
||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
bool moe_permute_unpermute_supported() {
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_permute", &moe_permute);
|
||||
m.impl("moe_unpermute", &moe_unpermute);
|
||||
}
|
||||
342
csrc/moe/moe_wna16.cu
Normal file
342
csrc/moe/moe_wna16.cu
Normal file
@@ -0,0 +1,342 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include "moe_wna16_utils.h"
|
||||
|
||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_token_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ num_tokens_post_pad,
|
||||
|
||||
uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
|
||||
uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
|
||||
uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
|
||||
bool mul_topk_weight) {
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
||||
if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
return;
|
||||
} else {
|
||||
#endif
|
||||
|
||||
using Dtype = ScalarType<scalar_t>;
|
||||
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
||||
|
||||
if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;
|
||||
|
||||
const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
|
||||
const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
|
||||
|
||||
const int32_t expert_id = expert_ids[blockIdx.x];
|
||||
|
||||
int32_t num_valid_tokens = 0;
|
||||
extern __shared__ uint16_t block_input_tmp[];
|
||||
scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
|
||||
scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
|
||||
|
||||
// load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
|
||||
for (int m = 0; m < BLOCK_SIZE_M; m++) {
|
||||
const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
|
||||
const int32_t token_index = sorted_token_ids[offset_m];
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
for (int i = 0; i < k_per_thread; i++) {
|
||||
int k = BLOCK_SIZE_N * i + threadIdx.x;
|
||||
if (k >= BLOCK_SIZE_K) break;
|
||||
if (offset_k + k >= size_k) break;
|
||||
|
||||
// load input to shared memory
|
||||
// use a special layout to fit the layout of dequanted-weight
|
||||
int origin_k;
|
||||
if constexpr (bit == 4) {
|
||||
// [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
|
||||
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
|
||||
} else {
|
||||
// [0, 2, 1, 3]
|
||||
int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
|
||||
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
|
||||
}
|
||||
|
||||
origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
|
||||
block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (expert_id == -1) return;
|
||||
__syncthreads();
|
||||
if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
|
||||
|
||||
float res[64]; // assume BLOCK_SIZE_M <= 64
|
||||
scalar_t2 res2;
|
||||
scalar_t2 scale_f2;
|
||||
scalar_t2 qzero_f2;
|
||||
|
||||
// note that (size_n * size_k * expert_id) may greater than 2 ** 31
|
||||
constexpr int8_t pack_factor = 32 / bit;
|
||||
const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
|
||||
const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
|
||||
const scalar_t* expert_scales = scales + expert_offset / group_size;
|
||||
const uint32_t* expert_qzeros =
|
||||
qzeros + expert_offset / group_size / pack_factor;
|
||||
|
||||
// load 4*int32 one time: 4 int32 = 128 bit = 1 float4
|
||||
// weight would be loaded in loop
|
||||
uint32_t expert_qweight_tmp[4];
|
||||
float4* expert_qweight_tmp_float4 =
|
||||
reinterpret_cast<float4*>(expert_qweight_tmp);
|
||||
|
||||
// load all required scales one time
|
||||
scalar_t expert_scales_groups[GROUPS];
|
||||
int scales_offset_tmp =
|
||||
(offset_n * size_k + offset_k) / group_size / GROUPS;
|
||||
if constexpr (GROUPS == 1) {
|
||||
*expert_scales_groups = expert_scales[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 2) {
|
||||
float* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 4) {
|
||||
float2* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float2*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 8) {
|
||||
float4* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float4*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
|
||||
}
|
||||
|
||||
// load all required qzeros one time
|
||||
uint8_t expert_qzeros_groups[GROUPS];
|
||||
if (!has_zp) {
|
||||
if constexpr (bit == 4) {
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
|
||||
} else {
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
|
||||
}
|
||||
} else {
|
||||
int qzeros_offset_tmp =
|
||||
(offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
|
||||
offset_k / group_size / GROUPS;
|
||||
if constexpr (GROUPS == 1) {
|
||||
uint8_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint8_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 2) {
|
||||
uint16_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint16_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 4) {
|
||||
uint32_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint32_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 8) {
|
||||
uint64_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint64_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
}
|
||||
}
|
||||
|
||||
for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
|
||||
int k = offset_k + tmp_k * pack_factor;
|
||||
if (k >= size_k) break;
|
||||
const int32_t weight_offset = offset_n * size_k + k;
|
||||
|
||||
if (tmp_k % 4 == 0) {
|
||||
*expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
|
||||
expert_qweight)[weight_offset / pack_factor / 4];
|
||||
}
|
||||
|
||||
if (tmp_k % (group_size / pack_factor) == 0) {
|
||||
scalar_t scale_f =
|
||||
expert_scales_groups[tmp_k / (group_size / pack_factor)];
|
||||
scale_f2 = Dtype::num2num2(scale_f);
|
||||
|
||||
if (has_zp) {
|
||||
uint8_t qzero =
|
||||
expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
|
||||
if constexpr (bit == 4) {
|
||||
qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
|
||||
}
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
|
||||
}
|
||||
}
|
||||
|
||||
scalar_t2 weight_half2[16 / bit];
|
||||
dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);
|
||||
|
||||
for (int m = 0; m < num_valid_tokens; m++) {
|
||||
res2 = {};
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16 / bit; i++) {
|
||||
int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
|
||||
res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
|
||||
block_input_half2[offset_input], res2);
|
||||
}
|
||||
|
||||
if (tmp_k == 0) {
|
||||
res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
||||
} else {
|
||||
res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < num_valid_tokens; ++m) {
|
||||
const int32_t token_index =
|
||||
sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
|
||||
if (mul_topk_weight) {
|
||||
res[m] *= topk_weights[token_index];
|
||||
}
|
||||
atomicAdd(&output[token_index * size_n + offset_n],
|
||||
Dtype::float2num(res[m]));
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
|
||||
const uint32_t* b_qweight, const scalar_t* b_scales,
|
||||
const uint32_t* b_qzeros, const float* topk_weights,
|
||||
const int32_t* sorted_token_ids,
|
||||
const int32_t* expert_ids,
|
||||
const int32_t* num_tokens_post_pad, int num_experts,
|
||||
int group_size, int num_token_blocks, int top_k,
|
||||
int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
|
||||
int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
|
||||
bool has_zp, bool mul_topk_weight) {
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_SIZE_N;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = num_token_blocks;
|
||||
gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);
|
||||
|
||||
auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
|
||||
if (bit == 4) {
|
||||
if (BLOCK_SIZE_K / group_size == 2) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 4) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 8) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
|
||||
}
|
||||
} else {
|
||||
if (BLOCK_SIZE_K / group_size == 1) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 2) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 4) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 8) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
|
||||
}
|
||||
}
|
||||
|
||||
const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
|
||||
input, output, b_qweight, b_scales, b_qzeros, topk_weights,
|
||||
sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
|
||||
group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K, has_zp, mul_topk_weight);
|
||||
}
|
||||
|
||||
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
std::optional<torch::Tensor> b_qzeros,
|
||||
std::optional<torch::Tensor> topk_weights,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
output.zero_();
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
const int size_n = b_qweight.size(1);
|
||||
const int size_k = input.size(1);
|
||||
const int group_size = size_k / b_scales.size(2);
|
||||
|
||||
int64_t EM = sorted_token_ids.size(0);
|
||||
if (size_m <= BLOCK_SIZE_M) {
|
||||
EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
|
||||
}
|
||||
const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
|
||||
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr = nullptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
|
||||
"size_k must divisible by BLOCK_SIZE_K");
|
||||
TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
|
||||
"BLOCK_SIZE_K must divisible by group_size");
|
||||
TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64");
|
||||
TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
|
||||
groups_per_block_row == 4 || groups_per_block_row == 8,
|
||||
"BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::Half) {
|
||||
run_moe_wna16_gemm<half>(
|
||||
(const half*)input.data_ptr<at::Half>(),
|
||||
(half*)output.data_ptr<at::Half>(),
|
||||
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
||||
(const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
|
||||
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
||||
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
||||
b_qzeros.has_value(), topk_weights.has_value());
|
||||
} else if (input.scalar_type() == at::ScalarType::BFloat16) {
|
||||
run_moe_wna16_gemm<nv_bfloat16>(
|
||||
(const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
|
||||
(nv_bfloat16*)output.data_ptr<at::BFloat16>(),
|
||||
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
||||
(const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
|
||||
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
||||
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
||||
b_qzeros.has_value(), topk_weights.has_value());
|
||||
} else {
|
||||
TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
return output;
|
||||
}
|
||||
200
csrc/moe/moe_wna16_utils.h
Normal file
200
csrc/moe/moe_wna16_utils.h
Normal file
@@ -0,0 +1,200 @@
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
|
||||
template <>
|
||||
class ScalarType<half> {
|
||||
public:
|
||||
using scalar_t = half;
|
||||
using scalar_t2 = half2;
|
||||
|
||||
static __device__ float inline num2float(const half x) {
|
||||
return __half2float(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline nums2num2(const half x1, const half x2) {
|
||||
return __halves2half2(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline float2num(const float x) {
|
||||
return __float2half(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline int2num(const float x) {
|
||||
return __int2half_rn(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const half2 x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ half2 inline float22num2(const float2 x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ScalarType<nv_bfloat16> {
|
||||
public:
|
||||
using scalar_t = nv_bfloat16;
|
||||
using scalar_t2 = nv_bfloat162;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
return __bfloat162bfloat162(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
|
||||
const nv_bfloat16 x2) {
|
||||
return __halves2bfloat162(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
|
||||
return __int2bfloat16_rn(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, int bit>
|
||||
__device__ inline void dequant(int q, scalar_t2* res) {}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, 4>(int q, half2* res) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
|
||||
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
q >>= 8;
|
||||
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, 8>(int q, half2* res) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388608.f;
|
||||
fp32_intermediates[1] -= 8388608.f;
|
||||
fp32_intermediates[2] -= 8388608.f;
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
#endif
|
||||
59
csrc/moe/permute_unpermute_kernels/dispatch.h
Normal file
59
csrc/moe/permute_unpermute_kernels/dispatch.h
Normal file
@@ -0,0 +1,59 @@
|
||||
#pragma once
|
||||
#include <cuda_fp8.h>
|
||||
#define MOE_SWITCH(TYPE, ...) \
|
||||
at::ScalarType _st = ::detail::scalar_type(TYPE); \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
|
||||
}
|
||||
|
||||
#define MOE_DISPATCH_CASE(enum_type, ...) \
|
||||
case enum_type: { \
|
||||
using scalar_t = ScalarType2CudaType<enum_type>::type; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
}
|
||||
#define MOE_DISPATCH_FLOAT_CASE(...) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define MOE_DISPATCH(TYPE, ...) \
|
||||
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
|
||||
|
||||
template <at::ScalarType type>
|
||||
struct ScalarType2CudaType;
|
||||
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float> {
|
||||
using type = float;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Half> {
|
||||
using type = half;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
// uint8 for packed fp4
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Byte> {
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
// #if __CUDA_ARCH__ >= 890
|
||||
// fp8
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
|
||||
using type = __nv_fp8_e5m2;
|
||||
};
|
||||
template <>
|
||||
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
|
||||
using type = __nv_fp8_e4m3;
|
||||
};
|
||||
// #endif
|
||||
@@ -0,0 +1,231 @@
|
||||
|
||||
#include "moe_permute_unpermute_kernel.h"
|
||||
|
||||
// moe_permute kernels require at least CUDA 12.0
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
|
||||
|
||||
// CubKeyValueSorter definition begin
|
||||
CubKeyValueSorter::CubKeyValueSorter()
|
||||
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
|
||||
|
||||
int CubKeyValueSorter::expertsToBits(int num_experts) {
|
||||
// Max value we represent is V = num_experts + (num_experts - 1) = 2 *
|
||||
// num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1
|
||||
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
|
||||
}
|
||||
|
||||
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
|
||||
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
|
||||
|
||||
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
|
||||
num_experts_ = num_experts;
|
||||
num_bits_ = expertsToBits(num_experts);
|
||||
}
|
||||
|
||||
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
|
||||
int const num_experts) {
|
||||
int num_bits = expertsToBits(num_experts);
|
||||
size_t required_storage = 0;
|
||||
int* null_int = nullptr;
|
||||
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
|
||||
null_int, null_int, num_key_value_pairs, 0,
|
||||
num_bits);
|
||||
|
||||
// when num_key_value_pairs, num_experts, num_bits, required_storage = 64,
|
||||
// 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same
|
||||
// inputs
|
||||
if (required_storage == 0) {
|
||||
required_storage = 1;
|
||||
}
|
||||
return required_storage;
|
||||
}
|
||||
|
||||
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
|
||||
int const* keys_in, int* keys_out,
|
||||
int const* values_in, int* values_out,
|
||||
size_t const num_key_value_pairs,
|
||||
cudaStream_t stream) {
|
||||
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
|
||||
size_t actual_ws_size = workspace_size;
|
||||
|
||||
TORCH_CHECK(expected_ws_size <= workspace_size,
|
||||
"[CubKeyValueSorter::run] The allocated workspace is too small "
|
||||
"to run this problem.");
|
||||
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
|
||||
values_in, values_out, num_key_value_pairs, 0,
|
||||
num_bits_, stream);
|
||||
}
|
||||
// CubKeyValueSorter definition end
|
||||
|
||||
static inline size_t pad_to_multiple_of_16(size_t const& input) {
|
||||
static constexpr int ALIGNMENT = 16;
|
||||
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
|
||||
}
|
||||
template <class T>
|
||||
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
|
||||
int64_t const arr_length,
|
||||
T const target) {
|
||||
int64_t low = 0, high = arr_length - 1, target_location = -1;
|
||||
while (low <= high) {
|
||||
int64_t mid = (low + high) / 2;
|
||||
|
||||
if (sorted_indices[mid] >= target) {
|
||||
high = mid - 1;
|
||||
} else {
|
||||
low = mid + 1;
|
||||
target_location = mid;
|
||||
}
|
||||
}
|
||||
return target_location + 1;
|
||||
}
|
||||
|
||||
// Calculates the start offset of the tokens for a given expert. The last
|
||||
// element is the total number of valid tokens
|
||||
__global__ void computeExpertFirstTokenOffsetKernel(
|
||||
int const* sorted_experts, int64_t const sorted_experts_len,
|
||||
int const num_experts, int64_t* expert_first_token_offset) {
|
||||
// First, compute the global tid. We only need 1 thread per expert.
|
||||
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Note that expert goes [0, num_experts] (inclusive) because we want a count
|
||||
// for the total number of active tokens at the end of the scan.
|
||||
if (expert >= num_experts + 1) {
|
||||
return;
|
||||
}
|
||||
expert_first_token_offset[expert] =
|
||||
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
|
||||
}
|
||||
|
||||
void computeExpertFirstTokenOffset(int const* sorted_indices,
|
||||
int const total_indices,
|
||||
int const num_experts,
|
||||
int64_t* expert_first_token_offset,
|
||||
cudaStream_t stream) {
|
||||
int const num_entries = num_experts + 1;
|
||||
int const threads = std::min(1024, num_entries);
|
||||
int const blocks = (num_entries + threads - 1) / threads;
|
||||
|
||||
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
|
||||
sorted_indices, total_indices, num_experts, expert_first_token_offset);
|
||||
}
|
||||
|
||||
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
int* permuted_experts, int* permuted_rows,
|
||||
int64_t* expert_first_token_offset, int num_rows,
|
||||
int num_experts, int num_experts_per_node, int k,
|
||||
CubKeyValueSorter& sorter, void* sorter_ws,
|
||||
cudaStream_t stream) {
|
||||
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
|
||||
// We need to use the full num_experts because that is the sentinel value used
|
||||
// by topk for disabled experts
|
||||
sorter.updateNumExperts(num_experts);
|
||||
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
|
||||
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
|
||||
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
|
||||
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
|
||||
stream);
|
||||
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
|
||||
num_experts_per_node, expert_first_token_offset,
|
||||
stream);
|
||||
}
|
||||
|
||||
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr,
|
||||
int num_experts) {
|
||||
auto tidx = threadIdx.x;
|
||||
auto bidx = blockIdx.x;
|
||||
auto offset = bidx * blockDim.x;
|
||||
auto bound = min(offset + blockDim.x, size);
|
||||
extern __shared__ int smem_expert_map[];
|
||||
// store expert_map in smem
|
||||
for (int i = tidx; i < num_experts; i += blockDim.x) {
|
||||
smem_expert_map[i] = expert_map_ptr[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// query global expert id in expert map.
|
||||
// if global expert id = -1 in exert map, plus n_expert
|
||||
// else set global expert id = exert map[global expert id]
|
||||
if (offset + tidx < bound) {
|
||||
auto topk_id = topk_id_ptr[offset + tidx];
|
||||
auto local_expert_idx = smem_expert_map[topk_id];
|
||||
if (local_expert_idx == -1) {
|
||||
topk_id += num_experts;
|
||||
} else {
|
||||
topk_id = local_expert_idx;
|
||||
}
|
||||
__syncwarp();
|
||||
topk_id_ptr[offset + tidx] = topk_id;
|
||||
}
|
||||
}
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream) {
|
||||
int block = std::min(size, 1024);
|
||||
int grid = (size + block - 1) / block;
|
||||
int smem_size = (num_experts) * sizeof(int);
|
||||
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
|
||||
topk_id_ptr, size, expert_map_ptr, num_experts);
|
||||
}
|
||||
|
||||
template <bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset,
|
||||
int* m_indices, const int num_local_expert,
|
||||
const int align_block_size) {
|
||||
int eidx = blockIdx.x;
|
||||
int tidx = threadIdx.x;
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
|
||||
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
|
||||
}
|
||||
__syncthreads();
|
||||
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
|
||||
auto first_token_offset = smem_expert_first_token_offset[eidx];
|
||||
int n_token_in_expert = last_token_offset - first_token_offset;
|
||||
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
// round up to ALIGN_BLOCK_SIZE
|
||||
int64_t accumulate_align_offset = 0;
|
||||
for (int i = 1; i <= eidx + 1; i++) {
|
||||
int n_token = smem_expert_first_token_offset[i] -
|
||||
smem_expert_first_token_offset[i - 1];
|
||||
accumulate_align_offset =
|
||||
accumulate_align_offset + (n_token + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
if (i == eidx) {
|
||||
first_token_offset = accumulate_align_offset;
|
||||
}
|
||||
// last block store align_expert_first_token_offset
|
||||
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
|
||||
align_expert_first_token_offset[i] = accumulate_align_offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
|
||||
// update m_indice with expert id
|
||||
m_indices[first_token_offset + idx] = eidx;
|
||||
}
|
||||
}
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream) {
|
||||
int block = 256;
|
||||
int grid = num_local_expert;
|
||||
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
|
||||
if (align_block_size == -1) {
|
||||
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
} else {
|
||||
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
|
||||
expert_first_token_offset, align_expert_first_token_offset, m_indices,
|
||||
num_local_expert, align_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,83 @@
|
||||
#pragma once
|
||||
// reference from tensorrt_llm moe kernel implementation archive in
|
||||
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/all.h>
|
||||
#include "dispatch.h"
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include "cutlass/numeric_size.h"
|
||||
#include "cutlass/array.h"
|
||||
|
||||
template <typename T>
|
||||
inline T* get_ptr(torch::Tensor& t) {
|
||||
return reinterpret_cast<T*>(t.data_ptr());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline const T* get_ptr(const torch::Tensor& t) {
|
||||
return reinterpret_cast<const T*>(t.data_ptr());
|
||||
}
|
||||
|
||||
class CubKeyValueSorter {
|
||||
public:
|
||||
CubKeyValueSorter();
|
||||
|
||||
CubKeyValueSorter(int const num_experts);
|
||||
|
||||
void updateNumExperts(int const num_experts);
|
||||
|
||||
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
|
||||
int const num_experts);
|
||||
|
||||
void run(void* workspace, size_t const workspace_size, int const* keys_in,
|
||||
int* keys_out, int const* values_in, int* values_out,
|
||||
size_t const num_key_value_pairs, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
static int expertsToBits(int experts);
|
||||
int num_experts_;
|
||||
int num_bits_;
|
||||
};
|
||||
|
||||
void computeExpertFirstTokenOffset(int const* sorted_indices,
|
||||
int const total_indices,
|
||||
int const num_experts,
|
||||
int64_t* expert_first_token_offset,
|
||||
cudaStream_t stream);
|
||||
|
||||
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
|
||||
int* permuted_experts, int* permuted_rows,
|
||||
int64_t* expert_first_token_offset, int num_rows,
|
||||
int num_experts, int num_experts_per_node, int k,
|
||||
CubKeyValueSorter& sorter, void* sorter_ws,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream);
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t const num_rows, int64_t const cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr, cudaStream_t stream);
|
||||
|
||||
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
|
||||
const int* expert_map_ptr, int num_experts,
|
||||
cudaStream_t stream);
|
||||
|
||||
void getMIndices(int64_t* expert_first_token_offset,
|
||||
int64_t* align_expert_first_token_offset, int* m_indices,
|
||||
int num_local_expert, const int align_block_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
#include "moe_permute_unpermute_kernel.inl"
|
||||
@@ -0,0 +1,203 @@
|
||||
#pragma once
|
||||
|
||||
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
|
||||
__global__ void expandInputRowsKernel(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
|
||||
int num_local_experts, int align_block_size) {
|
||||
// Reverse permutation map.
|
||||
// I do this so that later, we can use the source -> dest map to do the k-way
|
||||
// reduction and unpermuting. I need the reverse map for that reduction to
|
||||
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
|
||||
// thread block will be responsible for all k summations.
|
||||
int64_t expanded_dest_row = blockIdx.x;
|
||||
int64_t const expanded_source_row =
|
||||
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
|
||||
int expert_id = sorted_experts[expanded_dest_row];
|
||||
|
||||
extern __shared__ int64_t smem_expert_first_token_offset[];
|
||||
if constexpr (ALIGN_BLOCK_SIZE) {
|
||||
// load g2s
|
||||
for (int idx = threadIdx.x; idx < num_local_experts + 1;
|
||||
idx += blockDim.x) {
|
||||
smem_expert_first_token_offset[idx] =
|
||||
__ldg(expert_first_token_offset + idx);
|
||||
}
|
||||
__syncthreads();
|
||||
int lane_idx = threadIdx.x & 31;
|
||||
|
||||
if (lane_idx == 0) {
|
||||
// set token_offset_in_expert = 0 if this expert is not local expert
|
||||
int token_offset_in_expert =
|
||||
expert_id >= num_local_experts
|
||||
? 0
|
||||
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
|
||||
int64_t accumulate_align_offset = 0;
|
||||
#pragma unroll 1
|
||||
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
|
||||
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
|
||||
smem_expert_first_token_offset[eidx - 1];
|
||||
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
|
||||
align_block_size * align_block_size;
|
||||
}
|
||||
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
|
||||
}
|
||||
// lane0 shuffle broadcast align_expanded_dest_row
|
||||
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
assert(expanded_dest_row <= INT32_MAX);
|
||||
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
|
||||
static_cast<int>(expanded_dest_row);
|
||||
// skip non local expert token
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
permuted_idx[expanded_dest_row] = expanded_source_row;
|
||||
}
|
||||
}
|
||||
|
||||
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
|
||||
// Load 128-bits per thread
|
||||
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
|
||||
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
|
||||
|
||||
// Duplicate and permute rows
|
||||
int64_t const source_row = expanded_source_row / k;
|
||||
|
||||
auto const* source_row_ptr =
|
||||
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
|
||||
auto* dest_row_ptr =
|
||||
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
|
||||
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void expandInputRowsKernelLauncher(
|
||||
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
|
||||
int const* expanded_dest_row_to_expanded_source_row,
|
||||
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
|
||||
int64_t* expert_first_token_offset, int64_t const num_rows,
|
||||
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
|
||||
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows * k;
|
||||
int64_t const threads = 256;
|
||||
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
|
||||
FuncPtr func_map[2][2] = {
|
||||
{&expandInputRowsKernel<T, false, false>,
|
||||
&expandInputRowsKernel<T, false, true>},
|
||||
{&expandInputRowsKernel<T, true, false>,
|
||||
&expandInputRowsKernel<T, true, true>},
|
||||
};
|
||||
bool is_check_skip = num_valid_tokens_ptr != nullptr;
|
||||
bool is_align_block_size = align_block_size != -1;
|
||||
auto func = func_map[is_check_skip][is_align_block_size];
|
||||
|
||||
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
|
||||
|
||||
func<<<blocks, threads, smem_size, stream>>>(
|
||||
unpermuted_input, permuted_output, sorted_experts,
|
||||
expanded_dest_row_to_expanded_source_row,
|
||||
expanded_source_row_to_expanded_dest_row, permuted_idx,
|
||||
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
|
||||
num_local_experts, align_block_size);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
__host__ __device__ constexpr static U arrayConvert(T const& input) {
|
||||
using Type = typename U::Element;
|
||||
static_assert(T::kElements == U::kElements);
|
||||
U u;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < U::kElements; i++) {
|
||||
u[i] = static_cast<Type>(input[i]);
|
||||
}
|
||||
return u;
|
||||
}
|
||||
|
||||
template <typename T, typename OutputType, bool CHECK_SKIPPED>
|
||||
__global__ void finalizeMoeRoutingKernel(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
|
||||
assert(orig_cols % 4 == 0);
|
||||
int64_t const original_row = blockIdx.x;
|
||||
auto const offset = original_row * orig_cols;
|
||||
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
|
||||
int64_t const num_valid = *num_valid_ptr;
|
||||
|
||||
// Load 128-bits per thread, according to the smallest data type we read/write
|
||||
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
|
||||
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
|
||||
cutlass::sizeof_bits<T>::value);
|
||||
|
||||
int64_t const start_offset = threadIdx.x;
|
||||
int64_t const stride = blockDim.x;
|
||||
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
|
||||
|
||||
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
|
||||
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
|
||||
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
|
||||
auto const* expanded_permuted_rows_v =
|
||||
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
|
||||
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
|
||||
|
||||
#pragma unroll
|
||||
for (int elem_index = start_offset; elem_index < num_elems_in_col;
|
||||
elem_index += stride) {
|
||||
ComputeElem thread_output;
|
||||
thread_output.fill(0);
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
int64_t const expanded_original_row = original_row * k + k_idx;
|
||||
int64_t const expanded_permuted_row =
|
||||
expanded_source_row_to_expanded_dest_row[expanded_original_row];
|
||||
|
||||
int64_t const k_offset = original_row * k + k_idx;
|
||||
float const row_scale = scales[k_offset];
|
||||
|
||||
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto const* expanded_permuted_rows_row_ptr =
|
||||
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
|
||||
|
||||
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
|
||||
expanded_permuted_rows_row_ptr[elem_index]);
|
||||
thread_output = thread_output + row_scale * (expert_result);
|
||||
}
|
||||
|
||||
OutputElem output_elem =
|
||||
arrayConvert<ComputeElem, OutputElem>(thread_output);
|
||||
reduced_row_ptr_v[elem_index] = output_elem;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class OutputType>
|
||||
void finalizeMoeRoutingKernelLauncher(
|
||||
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
|
||||
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
|
||||
int64_t const num_rows, int64_t const cols, int64_t const k,
|
||||
int64_t const* num_valid_ptr, cudaStream_t stream) {
|
||||
int64_t const blocks = num_rows;
|
||||
int64_t const threads = 256;
|
||||
bool const check_finished = num_valid_ptr != nullptr;
|
||||
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
|
||||
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
|
||||
&finalizeMoeRoutingKernel<T, OutputType, true>};
|
||||
auto* const kernel = func_map[check_finished];
|
||||
kernel<<<blocks, threads, 0, stream>>>(
|
||||
expanded_permuted_rows, reduced_unpermuted_output, scales,
|
||||
expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
|
||||
}
|
||||
707
csrc/moe/topk_softmax_kernels.cu
Normal file
707
csrc/moe/topk_softmax_kernels.cu
Normal file
@@ -0,0 +1,707 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||
* Copyright (c) 2024, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <type_traits>
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
#include "../cub_helpers.h"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
/// Number of elements in the array
|
||||
int N,
|
||||
/// Alignment requirement in bytes
|
||||
int Alignment = sizeof(T) * N
|
||||
>
|
||||
struct alignas(Alignment) AlignedArray {
|
||||
T data[N];
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ float toFloat(T value) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return value;
|
||||
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
||||
return __bfloat162float(value);
|
||||
} else if constexpr (std::is_same_v<T, __half>) {
|
||||
return __half2float(value);
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing the output
|
||||
// in the softmax kernel when we extend this module to support expert-choice routing.
|
||||
template <int TPB, typename InputType>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moeSoftmax(const InputType* input, const bool* finished, float* output, const int num_cols)
|
||||
{
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
if ((finished != nullptr) && finished[blockIdx.x])
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = toFloat(input[idx]);
|
||||
threadData = max(val, threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = toFloat(input[idx]);
|
||||
threadData += expf(val - float_max);
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = toFloat(input[idx]);
|
||||
const float softmax_val = expf(val - float_max) * normalizing_factor;
|
||||
output[idx] = softmax_val;
|
||||
}
|
||||
}
|
||||
|
||||
template <int TPB, typename IndType>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(
|
||||
const float* inputs_after_softmax,
|
||||
const bool* finished,
|
||||
float* output,
|
||||
IndType* indices,
|
||||
int* source_rows,
|
||||
const int num_experts,
|
||||
const int k,
|
||||
const int start_expert,
|
||||
const int end_expert,
|
||||
const bool renormalize)
|
||||
{
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int num_rows = gridDim.x;
|
||||
const int block_row = blockIdx.x;
|
||||
|
||||
const bool row_is_active = finished ? !finished[block_row] : true;
|
||||
const int thread_read_offset = blockIdx.x * num_experts;
|
||||
float selected_sum = 0.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
|
||||
{
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
|
||||
{
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert)
|
||||
{
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
// Ignore experts the node isn't responsible for with expert parallelism
|
||||
const int expert = result_kvp.key;
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = result_kvp.value;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
||||
assert(indices[idx] >= 0);
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
if (renormalize) {
|
||||
selected_sum += result_kvp.value;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Renormalize the k weights for this row to sum to 1, if requested.
|
||||
if (renormalize) {
|
||||
if (threadIdx.x == 0) {
|
||||
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = output[idx] / denom;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== TopK softmax things ===============================
|
||||
|
||||
/*
|
||||
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
|
||||
are a small power of 2. This allows us to cleanly share the rows among the threads in
|
||||
a single warp and eliminate communication between warps (so no need to use shared mem).
|
||||
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is optimized for when the number of experts is a small power of 2.
|
||||
Additionally it also supports when number of experts is multiple of 64 which is still
|
||||
faster than the computing softmax and topK separately (only tested on CUDA yet).
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType, typename InputType = float>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
void topkGatingSoftmax(const InputType* input, const bool* finished, float* output, const int num_rows, IndType* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert, const bool renormalize)
|
||||
{
|
||||
static_assert(std::is_same_v<InputType, float> || std::is_same_v<InputType, __nv_bfloat16> ||
|
||||
std::is_same_v<InputType, __half>,
|
||||
"InputType must be float, __nv_bfloat16, or __half");
|
||||
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
|
||||
if constexpr (std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) {
|
||||
static_assert(ELTS_PER_LDG == 1 || ELTS_PER_LDG % 2 == 0,
|
||||
"ELTS_PER_LDG must be 1 or even for 16-bit conversion");
|
||||
}
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
// Restrictions for previous section.
|
||||
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
|
||||
|
||||
// ===================== From this point, we finally start computing run-time variables. ========================
|
||||
|
||||
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
|
||||
// This, each block processes a chunk of rows. We start by computing the start row for each block.
|
||||
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||
|
||||
// Now, using the base row per thread block, we compute the base row per warp.
|
||||
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||
|
||||
// The threads in a warp are split into sub-groups that will work on a row.
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows)
|
||||
{
|
||||
return;
|
||||
}
|
||||
const bool row_is_active = finished ? !finished[thread_row] : true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
||||
// row it will read.
|
||||
const InputType* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const InputType* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
float row_chunk[VPT];
|
||||
|
||||
// NOTE(zhuhaoran): dispatch different input types loading, BF16/FP16 convert to float
|
||||
if constexpr (std::is_same_v<InputType, float>) {
|
||||
using VecType = AlignedArray<float, ELTS_PER_LDG>;
|
||||
VecType* row_chunk_vec_ptr = reinterpret_cast<VecType*>(&row_chunk);
|
||||
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
} else if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
|
||||
if constexpr (ELTS_PER_LDG >= 2) {
|
||||
using VecType = AlignedArray<__nv_bfloat16, ELTS_PER_LDG>;
|
||||
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
|
||||
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
|
||||
row_chunk_f2[base_idx_f2 + jj] = __bfloat1622float2(
|
||||
*reinterpret_cast<const __nv_bfloat162*>(vec.data + jj * 2)
|
||||
);
|
||||
}
|
||||
}
|
||||
} else { // ELTS_PER_LDG == 1
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
const __nv_bfloat16* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
|
||||
row_chunk[ii] = __bfloat162float(*scalar_ptr);
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<InputType, __half>) {
|
||||
if constexpr (ELTS_PER_LDG >= 2) {
|
||||
using VecType = AlignedArray<__half, ELTS_PER_LDG>;
|
||||
float2* row_chunk_f2 = reinterpret_cast<float2*>(row_chunk);
|
||||
const VecType* vec_thread_read_ptr = reinterpret_cast<const VecType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
VecType vec = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
int base_idx_f2 = ii * ELTS_PER_LDG / 2;
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < ELTS_PER_LDG / 2; ++jj) {
|
||||
row_chunk_f2[base_idx_f2 + jj] = __half22float2(
|
||||
*reinterpret_cast<const __half2*>(vec.data + jj * 2)
|
||||
);
|
||||
}
|
||||
}
|
||||
} else { // ELTS_PER_LDG == 1
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
||||
const __half* scalar_ptr = thread_read_ptr + ii * THREADS_PER_ROW;
|
||||
row_chunk[ii] = __half2float(*scalar_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
||||
// convert to float afterwards for the exp + sum reduction.
|
||||
float thread_max = row_chunk[0];
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < VPT; ++ii)
|
||||
{
|
||||
thread_max = max(thread_max, row_chunk[ii]);
|
||||
}
|
||||
|
||||
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread max in all the threads have the max within the row.
|
||||
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
|
||||
float row_sum = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
|
||||
row_sum += row_chunk[ii];
|
||||
}
|
||||
|
||||
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
|
||||
}
|
||||
|
||||
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
||||
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
|
||||
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
|
||||
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
|
||||
// argmax after computing the softmax.
|
||||
const float reciprocal_row_sum = 1.f / row_sum;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||
}
|
||||
|
||||
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
|
||||
// with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
float selected_sum = 0.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
|
||||
{
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index are processed first and only
|
||||
// updated if > (not >=)
|
||||
if (val > max_val)
|
||||
{
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
|
||||
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
|
||||
// then blank out their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this way
|
||||
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||
{
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0)
|
||||
{
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
|
||||
// single) thread per row of the input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
if (renormalize) {
|
||||
selected_sum += max_val;
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
||||
if (k_idx + 1 < k)
|
||||
{
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group)
|
||||
{
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Renormalize the k weights for this row to sum to 1, if requested.
|
||||
if (renormalize) {
|
||||
if (thread_group_idx == 0)
|
||||
{
|
||||
const float denom = selected_sum > 0.f ? selected_sum : 1.f;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = output[idx] / denom;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename InputType>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(InputType);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType, typename InputType>
|
||||
void topkGatingSoftmaxLauncherHelper(const InputType* input, const bool* finished, float* output, IndType* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool renormalize,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(InputType) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM, InputType>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM, IndType, InputType><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, renormalize);
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
static_assert(WARP_SIZE == 32, \
|
||||
"Unsupported warp size. Only 32 is supported for CUDA"); \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
|
||||
num_tokens, topk, 0, num_experts, renormalize, stream);
|
||||
#else
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
if (WARP_SIZE == 64) { \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
|
||||
num_tokens, topk, 0, num_experts, renormalize, stream); \
|
||||
} else if (WARP_SIZE == 32) { \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, token_expert_indices, \
|
||||
num_tokens, topk, 0, num_experts, renormalize, stream); \
|
||||
} else { \
|
||||
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename IndType, typename InputType>
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
const InputType* gating_output,
|
||||
float* topk_weights,
|
||||
IndType* topk_indices,
|
||||
int* token_expert_indices,
|
||||
float* softmax_workspace,
|
||||
const int num_tokens,
|
||||
const int num_experts,
|
||||
const int topk,
|
||||
const bool renormalize,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
|
||||
#ifndef USE_ROCM
|
||||
// for bfloat16 dtype, we need 4 bytes loading to make sure num_experts
|
||||
// elements can be loaded by a warp
|
||||
static constexpr int BYTES_PER_LDG_MULTIPLE_64 =
|
||||
(std::is_same_v<InputType, __nv_bfloat16> || std::is_same_v<InputType, __half>) ? 4 : 8;
|
||||
#endif
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
case 512:
|
||||
LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
|
||||
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
|
||||
// alternatively we can test 4 bytes loading and enable it in future.
|
||||
#ifndef USE_ROCM
|
||||
case 192:
|
||||
LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 320:
|
||||
LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 384:
|
||||
LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 448:
|
||||
LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 576:
|
||||
LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
#endif
|
||||
default: {
|
||||
TORCH_CHECK(softmax_workspace != nullptr,
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
|
||||
static constexpr int TPB = 256;
|
||||
moeSoftmax<TPB, InputType><<<num_tokens, TPB, 0, stream>>>(
|
||||
gating_output, nullptr, softmax_workspace, num_experts);
|
||||
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices,
|
||||
num_experts, topk, 0, num_experts, renormalize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
|
||||
template<typename ComputeType>
|
||||
void dispatch_topk_softmax_launch(
|
||||
torch::Tensor& gating_output,
|
||||
torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& softmax_workspace,
|
||||
int num_tokens, int num_experts, int topk, bool renormalize, cudaStream_t stream)
|
||||
{
|
||||
if (topk_indices.scalar_type() == at::ScalarType::Int) {
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher<int, ComputeType>(
|
||||
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens, num_experts, topk, renormalize, stream);
|
||||
} else if (topk_indices.scalar_type() == at::ScalarType::UInt32) {
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher<uint32_t, ComputeType>(
|
||||
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<uint32_t>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens, num_experts, topk, renormalize, stream);
|
||||
} else {
|
||||
TORCH_CHECK(topk_indices.scalar_type() == at::ScalarType::Long);
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher<int64_t, ComputeType>(
|
||||
reinterpret_cast<const ComputeType*>(gating_output.data_ptr()),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int64_t>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens, num_experts, topk, renormalize, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights, // [num_tokens, topk]
|
||||
torch::Tensor& topk_indices, // [num_tokens, topk]
|
||||
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
||||
torch::Tensor& gating_output, // [num_tokens, num_experts]
|
||||
bool renormalize)
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const auto num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
||||
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const auto workspace_options = gating_output.options().dtype(at::ScalarType::Float);
|
||||
torch::Tensor softmax_workspace = torch::empty({workspace_size}, workspace_options);
|
||||
|
||||
if (gating_output.scalar_type() == at::ScalarType::Float) {
|
||||
dispatch_topk_softmax_launch<float>(gating_output, topk_weights, topk_indices,
|
||||
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
|
||||
} else if (gating_output.scalar_type() == at::ScalarType::Half) {
|
||||
dispatch_topk_softmax_launch<__half>(gating_output, topk_weights, topk_indices,
|
||||
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
|
||||
} else if (gating_output.scalar_type() == at::ScalarType::BFloat16) {
|
||||
dispatch_topk_softmax_launch<__nv_bfloat16>(gating_output, topk_weights, topk_indices,
|
||||
token_expert_indices, softmax_workspace, num_tokens, num_experts, topk, renormalize, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported gating_output data type: ", gating_output.scalar_type());
|
||||
}
|
||||
}
|
||||
122
csrc/moe/torch_bindings.cpp
Normal file
122
csrc/moe/torch_bindings.cpp
Normal file
@@ -0,0 +1,122 @@
|
||||
#include "core/registration.h"
|
||||
#include "moe_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
// Apply topk softmax to the gating outputs.
|
||||
m.def(
|
||||
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
// Calculate the result of moe by summing up the partial results
|
||||
// from all selected experts.
|
||||
m.def("moe_sum(Tensor input, Tensor! output) -> ()");
|
||||
m.impl("moe_sum", torch::kCUDA, &moe_sum);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size.
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts,"
|
||||
" int block_size, Tensor! sorted_token_ids,"
|
||||
" Tensor! experts_ids,"
|
||||
" Tensor! num_tokens_post_pad,"
|
||||
" Tensor? maybe_expert_map) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size, but for the batched case.
|
||||
m.def(
|
||||
"batched_moe_align_block_size(int max_tokens_per_batch,"
|
||||
" int block_size, Tensor expert_num_tokens,"
|
||||
" Tensor! sorted_token_ids,"
|
||||
" Tensor! experts_ids,"
|
||||
" Tensor! num_tokens_post_pad) -> ()");
|
||||
m.impl("batched_moe_align_block_size", torch::kCUDA,
|
||||
&batched_moe_align_block_size);
|
||||
|
||||
// Aligning the number of tokens to be processed by each expert such
|
||||
// that it is divisible by the block size.
|
||||
m.def(
|
||||
"moe_lora_align_block_size(Tensor topk_ids,"
|
||||
" Tensor token_lora_mapping,"
|
||||
" int num_experts,"
|
||||
" int block_size, int max_loras, "
|
||||
" int max_num_tokens_padded, "
|
||||
" int max_num_m_blocks, "
|
||||
" Tensor !sorted_token_ids,"
|
||||
" Tensor !experts_ids,"
|
||||
" Tensor !num_tokens_post_pad,"
|
||||
" Tensor !adapter_enabled,"
|
||||
" Tensor !lora_ids,"
|
||||
" Tensor? maybe_expert_map) -> () ");
|
||||
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
|
||||
"Tensor b_scales, Tensor? b_qzeros, "
|
||||
"Tensor? topk_weights, Tensor sorted_token_ids, "
|
||||
"Tensor expert_ids, Tensor num_tokens_post_pad, "
|
||||
"int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, "
|
||||
"int bit) -> Tensor");
|
||||
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
|
||||
"Tensor! b_scales, Tensor? a_scales, Tensor? global_scale, Tensor? "
|
||||
"b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float,"
|
||||
"int thread_k, int thread_n, int blocks_per_sm) -> Tensor");
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
|
||||
"int b_q_type, SymInt size_m, "
|
||||
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
|
||||
"topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
|
||||
m.def(
|
||||
"moe_permute(Tensor input, Tensor topk_ids,"
|
||||
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
|
||||
"int n_local_expert,"
|
||||
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
|
||||
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
|
||||
"permuted_idx, Tensor! m_indices)->()");
|
||||
|
||||
m.def(
|
||||
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
|
||||
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
|
||||
"int topk, Tensor! hidden_states)->()");
|
||||
|
||||
m.def("moe_permute_unpermute_supported() -> bool");
|
||||
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
|
||||
|
||||
// Row shuffle for MoE
|
||||
m.def(
|
||||
"shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
|
||||
"output_tensor) -> ()");
|
||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||
|
||||
// Apply grouped topk routing to select experts.
|
||||
m.def(
|
||||
"grouped_topk(Tensor scores, int n_group, int "
|
||||
"topk_group, int topk, bool renormalize, float "
|
||||
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
|
||||
"Tensor)");
|
||||
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
Reference in New Issue
Block a user