|
|
|
|
@@ -41,15 +41,29 @@ template <
|
|
|
|
|
/// Alignment requirement in bytes
|
|
|
|
|
int Alignment = sizeof(T) * N>
|
|
|
|
|
class alignas(Alignment) AlignedArray {
|
|
|
|
|
float data[N];
|
|
|
|
|
T data[N];
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// ========================== Util functions to convert types ==========================
|
|
|
|
|
template <typename T>
|
|
|
|
|
__device__ float convert_to_float(T x) {
|
|
|
|
|
if constexpr (std::is_same_v<T, __half>) {
|
|
|
|
|
return __half2float(x);
|
|
|
|
|
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
|
|
|
|
return __bfloat162float(x);
|
|
|
|
|
} else if constexpr (std::is_same_v<T, float>) {
|
|
|
|
|
return x;
|
|
|
|
|
} else {
|
|
|
|
|
return static_cast<float>(x);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ====================== Softmax things ===============================
|
|
|
|
|
// We have our own implementation of softmax here so we can support transposing the output
|
|
|
|
|
// in the softmax kernel when we extend this module to support expert-choice routing.
|
|
|
|
|
template <int TPB>
|
|
|
|
|
template <typename T, int TPB>
|
|
|
|
|
__launch_bounds__(TPB) __global__
|
|
|
|
|
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) {
|
|
|
|
|
void moeSoftmax(const T* input, const bool* finished, float* output, const int num_cols) {
|
|
|
|
|
using BlockReduce = cub::BlockReduce<float, TPB>;
|
|
|
|
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
|
|
|
|
|
|
|
|
|
@@ -68,7 +82,7 @@ __launch_bounds__(TPB) __global__
|
|
|
|
|
|
|
|
|
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
|
|
|
|
const int idx = thread_row_offset + ii;
|
|
|
|
|
threadData = max(static_cast<float>(input[idx]), threadData);
|
|
|
|
|
threadData = max(convert_to_float<T>(input[idx]), threadData);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
|
|
|
|
@@ -82,7 +96,7 @@ __launch_bounds__(TPB) __global__
|
|
|
|
|
|
|
|
|
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
|
|
|
|
const int idx = thread_row_offset + ii;
|
|
|
|
|
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
|
|
|
|
threadData += exp((convert_to_float<T>(input[idx]) - float_max));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
|
|
|
|
@@ -94,7 +108,7 @@ __launch_bounds__(TPB) __global__
|
|
|
|
|
|
|
|
|
|
for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
|
|
|
|
|
const int idx = thread_row_offset + ii;
|
|
|
|
|
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
|
|
|
|
const float val = exp((convert_to_float<T>(input[idx]) - float_max)) * normalizing_factor;
|
|
|
|
|
output[idx] = val;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -105,11 +119,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
|
|
|
|
const bool* finished,
|
|
|
|
|
float* output,
|
|
|
|
|
int* indices,
|
|
|
|
|
int* source_rows,
|
|
|
|
|
const int num_experts,
|
|
|
|
|
const int k,
|
|
|
|
|
const int start_expert,
|
|
|
|
|
const int end_expert) {
|
|
|
|
|
const int end_expert,
|
|
|
|
|
const bool renormalize) {
|
|
|
|
|
using cub_kvp = cub::KeyValuePair<int, float>;
|
|
|
|
|
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
|
|
|
|
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
|
|
|
|
@@ -117,11 +131,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
|
|
|
|
cub_kvp thread_kvp;
|
|
|
|
|
cub::ArgMax arg_max;
|
|
|
|
|
|
|
|
|
|
const int num_rows = gridDim.x;
|
|
|
|
|
const int block_row = blockIdx.x;
|
|
|
|
|
|
|
|
|
|
const bool row_is_active = finished ? !finished[block_row] : true;
|
|
|
|
|
const int thread_read_offset = blockIdx.x * num_experts;
|
|
|
|
|
float row_sum_for_renormalize = 0;
|
|
|
|
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
|
|
|
|
thread_kvp.key = 0;
|
|
|
|
|
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
|
|
|
|
@@ -154,10 +168,18 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
|
|
|
|
output[idx] = result_kvp.value;
|
|
|
|
|
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
|
|
|
|
assert(indices[idx] >= 0);
|
|
|
|
|
source_rows[idx] = k_idx * num_rows + block_row;
|
|
|
|
|
row_sum_for_renormalize += result_kvp.value;
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (renormalize && threadIdx.x == 0) {
|
|
|
|
|
float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
|
|
|
|
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
|
|
|
|
const int idx = k * block_row + k_idx;
|
|
|
|
|
output[idx] = output[idx] * row_sum_for_renormalize_inv;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ====================== TopK softmax things ===============================
|
|
|
|
|
@@ -174,17 +196,17 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
|
|
|
|
2) This implementation assumes k is small, but will work for any k.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
|
|
|
|
template <typename T, int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
|
|
|
|
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
|
|
|
|
|
const float* input,
|
|
|
|
|
const T* input,
|
|
|
|
|
const bool* finished,
|
|
|
|
|
float* output,
|
|
|
|
|
const int num_rows,
|
|
|
|
|
int* indices,
|
|
|
|
|
int* source_rows,
|
|
|
|
|
const int k,
|
|
|
|
|
const int start_expert,
|
|
|
|
|
const int end_expert) {
|
|
|
|
|
const int end_expert,
|
|
|
|
|
const bool renormalize) {
|
|
|
|
|
// We begin by enforcing compile time assertions and setting up compile time constants.
|
|
|
|
|
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
|
|
|
|
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
|
|
|
|
@@ -192,7 +214,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
|
|
|
|
|
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
|
|
|
|
|
|
|
|
|
// Number of bytes each thread pulls in per load
|
|
|
|
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
|
|
|
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
|
|
|
|
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
|
|
|
|
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
|
|
|
|
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
|
|
|
|
@@ -233,28 +255,34 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
|
|
|
|
|
|
|
|
|
|
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
|
|
|
|
// row it will read.
|
|
|
|
|
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
|
|
|
|
const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
|
|
|
|
|
|
|
|
|
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
|
|
|
|
|
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
|
|
|
|
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
|
|
|
|
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
|
|
|
|
const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
|
|
|
|
|
|
|
|
|
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
|
|
|
|
|
// this can support all powers of 2 up to 16.
|
|
|
|
|
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
|
|
|
|
|
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
|
|
|
|
|
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
|
|
|
|
|
using AccessType = AlignedArray<T, ELTS_PER_LDG>;
|
|
|
|
|
|
|
|
|
|
// Finally, we pull in the data from global mem
|
|
|
|
|
float row_chunk[VPT];
|
|
|
|
|
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
|
|
|
|
|
T row_chunk_temp[VPT];
|
|
|
|
|
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk_temp);
|
|
|
|
|
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
|
|
|
|
|
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float row_chunk[VPT];
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int ii = 0; ii < VPT; ++ii) {
|
|
|
|
|
row_chunk[ii] = convert_to_float<T>(row_chunk_temp[ii]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
|
|
|
|
// convert to float afterwards for the exp + sum reduction.
|
|
|
|
|
float thread_max = row_chunk[0];
|
|
|
|
|
@@ -301,6 +329,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
|
|
|
|
|
int start_col = first_elt_read_by_thread;
|
|
|
|
|
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
|
|
|
|
|
|
|
|
|
float row_sum_for_renormalize = 0;
|
|
|
|
|
|
|
|
|
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
|
|
|
|
// First, each thread does the local argmax
|
|
|
|
|
float max_val = row_chunk[0];
|
|
|
|
|
@@ -346,7 +376,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
|
|
|
|
|
const int idx = k * thread_row + k_idx;
|
|
|
|
|
output[idx] = max_val;
|
|
|
|
|
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
|
|
|
|
source_rows[idx] = k_idx * num_rows + thread_row;
|
|
|
|
|
row_sum_for_renormalize += max_val;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
|
|
|
|
@@ -362,13 +392,23 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Fuse renormalization of topk_weights into this kernel
|
|
|
|
|
if (renormalize && thread_group_idx == 0) {
|
|
|
|
|
float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize;
|
|
|
|
|
#pragma unroll
|
|
|
|
|
for (int k_idx = 0; k_idx < k; ++k_idx) {
|
|
|
|
|
const int idx = k * thread_row + k_idx;
|
|
|
|
|
output[idx] = output[idx] * row_sum_for_renormalize_inv;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
namespace detail {
|
|
|
|
|
// Constructs some constants needed to partition the work across threads at compile time.
|
|
|
|
|
template <int EXPERTS, int BYTES_PER_LDG>
|
|
|
|
|
template <typename T, int EXPERTS, int BYTES_PER_LDG>
|
|
|
|
|
struct TopkConstants {
|
|
|
|
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
|
|
|
|
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
|
|
|
|
|
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
|
|
|
|
static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
|
|
|
|
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
|
|
|
|
@@ -377,113 +417,120 @@ struct TopkConstants {
|
|
|
|
|
};
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
|
|
template <int EXPERTS, int WARPS_PER_TB>
|
|
|
|
|
template <typename T, int EXPERTS, int WARPS_PER_TB>
|
|
|
|
|
void topkGatingSoftmaxLauncherHelper(
|
|
|
|
|
const float* input,
|
|
|
|
|
const T* input,
|
|
|
|
|
const bool* finished,
|
|
|
|
|
float* output,
|
|
|
|
|
int* indices,
|
|
|
|
|
int* source_row,
|
|
|
|
|
const int num_rows,
|
|
|
|
|
const int k,
|
|
|
|
|
const int start_expert,
|
|
|
|
|
const int end_expert,
|
|
|
|
|
const bool renormalize,
|
|
|
|
|
cudaStream_t stream) {
|
|
|
|
|
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
|
|
|
|
|
|
|
|
|
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
|
|
|
|
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
|
|
|
|
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
|
|
|
|
|
using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
|
|
|
|
|
static constexpr int VPT = Constants::VPT;
|
|
|
|
|
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
|
|
|
|
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
|
|
|
|
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
|
|
|
|
|
|
|
|
|
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
|
|
|
|
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
|
|
|
|
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
|
|
|
|
topkGatingSoftmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
|
|
|
|
input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
|
|
|
|
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
|
|
|
|
gating_output, \
|
|
|
|
|
nullptr, \
|
|
|
|
|
topk_weights, \
|
|
|
|
|
topk_indices, \
|
|
|
|
|
token_expert_indices, \
|
|
|
|
|
num_tokens, \
|
|
|
|
|
topk, \
|
|
|
|
|
0, \
|
|
|
|
|
num_experts, \
|
|
|
|
|
stream);
|
|
|
|
|
#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \
|
|
|
|
|
topkGatingSoftmaxLauncherHelper<TYPE, NUM_EXPERTS, WARPS_PER_TB>( \
|
|
|
|
|
gating_output, nullptr, topk_weights, topk_indices, num_tokens, topk, 0, num_experts, renormalize, stream);
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void topkGatingSoftmaxKernelLauncher(
|
|
|
|
|
const float* gating_output,
|
|
|
|
|
const T* gating_output,
|
|
|
|
|
float* topk_weights,
|
|
|
|
|
int* topk_indices,
|
|
|
|
|
int* token_expert_indices,
|
|
|
|
|
float* softmax_workspace,
|
|
|
|
|
const int num_tokens,
|
|
|
|
|
const int num_experts,
|
|
|
|
|
const int topk,
|
|
|
|
|
const bool renormalize,
|
|
|
|
|
cudaStream_t stream) {
|
|
|
|
|
static constexpr int WARPS_PER_TB = 4;
|
|
|
|
|
switch (num_experts) {
|
|
|
|
|
case 1:
|
|
|
|
|
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 2:
|
|
|
|
|
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 4:
|
|
|
|
|
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 8:
|
|
|
|
|
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 16:
|
|
|
|
|
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 32:
|
|
|
|
|
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 64:
|
|
|
|
|
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 128:
|
|
|
|
|
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
case 256:
|
|
|
|
|
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
|
|
|
|
LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB);
|
|
|
|
|
break;
|
|
|
|
|
default: {
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
softmax_workspace != nullptr,
|
|
|
|
|
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
|
|
|
|
static constexpr int TPB = 256;
|
|
|
|
|
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
|
|
|
|
|
moeSoftmax<T, TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);
|
|
|
|
|
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
|
|
|
|
softmax_workspace,
|
|
|
|
|
nullptr,
|
|
|
|
|
topk_weights,
|
|
|
|
|
topk_indices,
|
|
|
|
|
token_expert_indices,
|
|
|
|
|
num_experts,
|
|
|
|
|
topk,
|
|
|
|
|
0,
|
|
|
|
|
num_experts);
|
|
|
|
|
softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void topk_softmax(
|
|
|
|
|
torch::Tensor& topk_weights, // [num_tokens, topk]
|
|
|
|
|
torch::Tensor& topk_indices, // [num_tokens, topk]
|
|
|
|
|
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
|
|
|
|
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
|
|
|
|
torch::Tensor& topk_weights, // [num_tokens, topk]
|
|
|
|
|
torch::Tensor& topk_indices, // [num_tokens, topk]
|
|
|
|
|
torch::Tensor& gating_output,
|
|
|
|
|
const bool renormalize) // [num_tokens, num_experts]
|
|
|
|
|
{
|
|
|
|
|
const int num_experts = gating_output.size(-1);
|
|
|
|
|
const int num_tokens = gating_output.numel() / num_experts;
|
|
|
|
|
const int topk = topk_weights.size(-1);
|
|
|
|
|
// Check data type
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half ||
|
|
|
|
|
gating_output.scalar_type() == at::ScalarType::BFloat16,
|
|
|
|
|
"gating_output must be float32, float16, or bfloat16");
|
|
|
|
|
|
|
|
|
|
// Check dimensions
|
|
|
|
|
TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]");
|
|
|
|
|
TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]");
|
|
|
|
|
TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]");
|
|
|
|
|
|
|
|
|
|
// Check shapes
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
gating_output.size(0) == topk_weights.size(0),
|
|
|
|
|
"First dimension of topk_weights must match num_tokens in gating_output");
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
gating_output.size(0) == topk_indices.size(0),
|
|
|
|
|
"First dimension of topk_indices must match num_tokens in gating_output");
|
|
|
|
|
TORCH_CHECK(
|
|
|
|
|
topk_weights.size(-1) == topk_indices.size(-1),
|
|
|
|
|
"Second dimension of topk_indices must match topk in topk_weights");
|
|
|
|
|
TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts");
|
|
|
|
|
|
|
|
|
|
const int num_experts = static_cast<int>(gating_output.size(-1));
|
|
|
|
|
const int num_tokens = static_cast<int>(gating_output.size(0));
|
|
|
|
|
const int topk = static_cast<int>(topk_weights.size(-1));
|
|
|
|
|
|
|
|
|
|
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
|
|
|
|
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
|
|
|
|
@@ -491,15 +538,44 @@ void topk_softmax(
|
|
|
|
|
|
|
|
|
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
|
|
|
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
|
|
|
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
|
|
|
|
topkGatingSoftmaxKernelLauncher(
|
|
|
|
|
gating_output.data_ptr<float>(),
|
|
|
|
|
topk_weights.data_ptr<float>(),
|
|
|
|
|
topk_indices.data_ptr<int>(),
|
|
|
|
|
token_expert_indices.data_ptr<int>(),
|
|
|
|
|
softmax_workspace.data_ptr<float>(),
|
|
|
|
|
num_tokens,
|
|
|
|
|
num_experts,
|
|
|
|
|
topk,
|
|
|
|
|
stream);
|
|
|
|
|
torch::Tensor softmax_workspace =
|
|
|
|
|
torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float));
|
|
|
|
|
|
|
|
|
|
const at::ScalarType dtype = gating_output.scalar_type();
|
|
|
|
|
if (dtype == at::ScalarType::Float) {
|
|
|
|
|
topkGatingSoftmaxKernelLauncher<float>(
|
|
|
|
|
gating_output.data_ptr<float>(),
|
|
|
|
|
topk_weights.data_ptr<float>(),
|
|
|
|
|
topk_indices.data_ptr<int>(),
|
|
|
|
|
softmax_workspace.data_ptr<float>(),
|
|
|
|
|
num_tokens,
|
|
|
|
|
num_experts,
|
|
|
|
|
topk,
|
|
|
|
|
renormalize,
|
|
|
|
|
stream);
|
|
|
|
|
} else if (dtype == at::ScalarType::Half) {
|
|
|
|
|
topkGatingSoftmaxKernelLauncher<__half>(
|
|
|
|
|
reinterpret_cast<const __half*>(gating_output.data_ptr<at::Half>()),
|
|
|
|
|
topk_weights.data_ptr<float>(),
|
|
|
|
|
topk_indices.data_ptr<int>(),
|
|
|
|
|
softmax_workspace.data_ptr<float>(),
|
|
|
|
|
num_tokens,
|
|
|
|
|
num_experts,
|
|
|
|
|
topk,
|
|
|
|
|
renormalize,
|
|
|
|
|
stream);
|
|
|
|
|
} else if (dtype == at::ScalarType::BFloat16) {
|
|
|
|
|
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
|
|
|
|
|
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
|
|
|
|
topk_weights.data_ptr<float>(),
|
|
|
|
|
topk_indices.data_ptr<int>(),
|
|
|
|
|
softmax_workspace.data_ptr<float>(),
|
|
|
|
|
num_tokens,
|
|
|
|
|
num_experts,
|
|
|
|
|
topk,
|
|
|
|
|
renormalize,
|
|
|
|
|
stream);
|
|
|
|
|
} else {
|
|
|
|
|
TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|