Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)
This commit is contained in:
@@ -85,6 +85,32 @@ void fused_experts_int8_kernel_impl(
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
// Declaration of the fp8 (e4m3 weights, 16-bit activations) fused-experts kernel.
// Buffer roles as used by the definition:
//   output        : final result, one row of K elements per token (sum of ic2 over topk).
//   ic0           : stage-1 GEMM result (input @ w1), written with row stride 2 * N.
//   ic1           : silu_and_mul result computed from ic0, row stride N; stage-2 GEMM input.
//   ic2           : stage-2 GEMM result scaled by topk_weights, row stride K.
//   A_tmp         : per-thread scratch of BLOCK_M * K elements holding gathered input rows.
//   input         : activations, row stride K; rows are selected with sorted_ids[i] / topk.
//   packed_w1/w2  : packed fp8 weights; expert strides are 2 * N * K and K * N elements.
//   w1s / w2s     : float scales for packed_w1 / packed_w2, grouped by block_size_N / block_size_K.
//   topk_weights  : routing weights, indexed with the values stored in sorted_ids.
//   sorted_ids    : token slots, consumed in chunks of BLOCK_M per M-block.
//   expert_ids    : one expert id per M-block.
//   offsets       : offsets[mb + 1] - offsets[mb] is the number of rows in M-block mb.
//   num_tokens_post_pad : determines the number of M-blocks (div_up by BLOCK_M).
// E is accepted but not read by the definition.
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
|
||||
@@ -932,6 +932,40 @@ void shared_expert_kernel_impl(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// common checks
|
||||
static inline void check_moe_scales(
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale) {
|
||||
if (use_int8_w8a8) {
|
||||
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
|
||||
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
|
||||
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
|
||||
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
|
||||
TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2.");
|
||||
}
|
||||
}
|
||||
|
||||
// Unpacks and validates the block-wise fp8 weight scales.
//
// Must be expanded in a scope that provides `w1_scale`, `w2_scale`, `block_size`, `N` and `K`.
// It introduces `w1s`, `w2s`, `block_size_val`, `block_size_N` and `block_size_K` into that scope.
// DIM0 / DIM1 are the scale-tensor dimensions holding the output-channel and input-channel
// block counts. Block counts are rounded up so that they agree with the div_up-based scale
// indexing used by the fp8 kernels; block sizes must be positive because they are divisors.
#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)                                                         \
  auto w1s = w1_scale.value();                                                                   \
  auto w2s = w2_scale.value();                                                                   \
  auto block_size_val = block_size.value();                                                      \
  int64_t block_size_N = block_size_val[0];                                                      \
  int64_t block_size_K = block_size_val[1];                                                      \
  TORCH_CHECK(block_size_N > 0 && block_size_K > 0, "expect block_size entries to be positive."); \
  TORCH_CHECK(                                                                                   \
      w1s.size(DIM0) == (2 * N + block_size_N - 1) / block_size_N,                               \
      "w1_scale: unexpected number of N blocks.");                                               \
  TORCH_CHECK(                                                                                   \
      w1s.size(DIM1) == (K + block_size_K - 1) / block_size_K,                                   \
      "w1_scale: unexpected number of K blocks.");                                               \
  TORCH_CHECK(                                                                                   \
      w2s.size(DIM0) == (K + block_size_N - 1) / block_size_N,                                   \
      "w2_scale: unexpected number of N blocks.");                                               \
  TORCH_CHECK(                                                                                   \
      w2s.size(DIM1) == (N + block_size_K - 1) / block_size_K,                                   \
      "w2_scale: unexpected number of K blocks.")
|
||||
|
||||
// hidden_states: [M, K]
|
||||
// w1: [E, 2N, K]
|
||||
// w2: [E, K, N]
|
||||
@@ -946,8 +980,10 @@ at::Tensor fused_experts_cpu(
|
||||
at::Tensor& topk_ids,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni) {
|
||||
@@ -990,12 +1026,8 @@ at::Tensor fused_experts_cpu(
|
||||
CHECK_EQ(packed_w1.size(2), packed_K);
|
||||
CHECK_EQ(packed_w2.size(2), packed_N);
|
||||
|
||||
if (use_int8_w8a8) {
|
||||
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
|
||||
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
|
||||
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
|
||||
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
|
||||
}
|
||||
// check scales
|
||||
check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
|
||||
|
||||
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
|
||||
|
||||
@@ -1047,6 +1079,9 @@ at::Tensor fused_experts_cpu(
|
||||
// 5. Aq_tmp : [M, K] or [M * topk, N]
|
||||
// 6. As_tmp : [M * topk]
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 7. intermediate_cache0 : [M * topk, 2N]
|
||||
//
|
||||
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
|
||||
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
|
||||
num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
|
||||
@@ -1054,6 +1089,9 @@ at::Tensor fused_experts_cpu(
|
||||
if (use_int8_w8a8) {
|
||||
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2;
|
||||
}
|
||||
|
||||
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
|
||||
@@ -1095,6 +1133,35 @@ at::Tensor fused_experts_cpu(
|
||||
E,
|
||||
topk,
|
||||
num_tokens_post_pad);
|
||||
} else if (use_fp8_w8a16) {
|
||||
// here we just ignore C_tmp as it is not used
|
||||
scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
|
||||
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
|
||||
|
||||
CHECK_MOE_SCALES_FP8(1, 2);
|
||||
fused_experts_fp8_kernel_impl(
|
||||
out_hidden_states.data_ptr<scalar_t>(),
|
||||
intermediate_cache0,
|
||||
intermediate_cache1,
|
||||
intermediate_cache2,
|
||||
A_tmp,
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
packed_w1.data_ptr<at::Float8_e4m3fn>(),
|
||||
packed_w2.data_ptr<at::Float8_e4m3fn>(),
|
||||
w1s.data_ptr<float>(),
|
||||
w2s.data_ptr<float>(),
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
topk_weights.data_ptr<float>(),
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
offsets,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
E,
|
||||
topk,
|
||||
num_tokens_post_pad);
|
||||
} else {
|
||||
scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K;
|
||||
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
|
||||
@@ -1176,17 +1243,8 @@ at::Tensor shared_expert_cpu(
|
||||
CHECK_EQ(packed_w1.size(1), packed_K);
|
||||
CHECK_EQ(packed_w2.size(1), packed_N);
|
||||
|
||||
if (use_int8_w8a8) {
|
||||
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
|
||||
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
|
||||
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
|
||||
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
|
||||
}
|
||||
// check scales
|
||||
check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
|
||||
|
||||
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
|
||||
|
||||
@@ -1244,17 +1302,7 @@ at::Tensor shared_expert_cpu(
|
||||
} else if (use_fp8_w8a16) {
|
||||
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
|
||||
|
||||
auto w1s = w1_scale.value();
|
||||
auto w2s = w2_scale.value();
|
||||
auto block_size_val = block_size.value();
|
||||
TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
|
||||
int64_t block_size_N = block_size_val[0];
|
||||
int64_t block_size_K = block_size_val[1];
|
||||
TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
|
||||
TORCH_CHECK(w1s.size(1) == K / block_size_K);
|
||||
TORCH_CHECK(w2s.size(0) == K / block_size_N);
|
||||
TORCH_CHECK(w2s.size(1) == N / block_size_K);
|
||||
|
||||
CHECK_MOE_SCALES_FP8(0, 1);
|
||||
shared_expert_fp8_kernel_impl<scalar_t>(
|
||||
out_hidden_states.data_ptr<scalar_t>(),
|
||||
intermediate_cache0,
|
||||
|
||||
@@ -4,6 +4,76 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
x0 = x0 * weight_vec;
|
||||
x1 = x1 * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
@@ -65,6 +135,215 @@ inline void silu_and_mul_stub(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_N = div_up(2 * N, block_size_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs =
|
||||
w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
}
|
||||
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
scale_size_N = div_up(K, block_size_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
const int64_t stride_e2 = OC * IC;
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs =
|
||||
w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Explicit instantiation of fused_experts_fp8_kernel_impl for one activation type.
// The parameter list must stay in sync with the template definition above; the macro is
// expanded below for at::BFloat16 and at::Half.
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, \
|
||||
TYPE* __restrict__ A_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const float* __restrict__ topk_weights, \
|
||||
const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, \
|
||||
const int32_t* __restrict__ offsets, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t E, \
|
||||
int64_t topk, \
|
||||
int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
@@ -100,8 +379,8 @@ void shared_expert_fp8_kernel_impl(
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t nb_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
@@ -110,11 +389,11 @@ void shared_expert_fp8_kernel_impl(
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
@@ -149,8 +428,8 @@ void shared_expert_fp8_kernel_impl(
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t nb_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
@@ -160,11 +439,11 @@ void shared_expert_fp8_kernel_impl(
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ nb_size,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
@@ -172,8 +451,8 @@ void shared_expert_fp8_kernel_impl(
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < mb_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, nb_size);
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -130,8 +130,10 @@ at::Tensor fused_experts_cpu(
|
||||
at::Tensor& topk_ids,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
@@ -260,7 +262,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// moe
|
||||
m.def(
|
||||
"fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
|
||||
"inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
|
||||
"inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? "
|
||||
"a1_scale, Tensor? a2_scale, bool "
|
||||
"is_vnni) -> Tensor");
|
||||
m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools.command.build_py import build_py
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension
|
||||
|
||||
# Directory containing this script and the lower-cased machine architecture of the host.
root = Path(__file__).parent.resolve()
arch = platform.machine().lower()

# Pick the manylinux2014 wheel tag for the host: start from the generic
# "manylinux2014_<arch>" spelling and override it for the known architecture aliases.
plat_name = f"manylinux2014_{arch}"
if arch in ("x86_64", "amd64"):
    plat_name = "manylinux2014_x86_64"
elif arch in ("aarch64", "arm64"):
    plat_name = "manylinux2014_aarch64"
elif arch.startswith("ppc"):
    plat_name = "manylinux2014_ppc64le"

# Only inject the tag for wheel builds, and only when the caller did not pass one.
building_wheel = "bdist_wheel" in sys.argv
if building_wheel and "--plat-name" not in sys.argv:
    sys.argv.extend(["--plat-name", plat_name])
|
||||
|
||||
|
||||
def _get_version():
    """Return the version string declared in ``pyproject.toml`` next to this script.

    The file is scanned line by line for a ``version = "..."`` assignment; the key must be
    exactly ``version`` and surrounding whitespace and double quotes are stripped from the value.

    Raises:
        RuntimeError: if the file contains no ``version`` assignment, so that a missing
            version fails the build instead of silently becoming ``None``.
    """
    pyproject = root / "pyproject.toml"
    with open(pyproject) as f:
        for line in f:
            key, sep, value = line.partition("=")
            if sep and key.strip() == "version":
                return value.strip().strip('"')
    raise RuntimeError(f"no version entry found in {pyproject}")
|
||||
|
||||
|
||||
# SGLANG_CPU_FP8_CVT_FTZ defaults to "1"; when enabled the matching -D define is added to
# the C++ compile flags below.
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"

operator_namespace = "sgl_kernel"
include_dirs = ["../../include"]

# C++ translation units of the CPU extension, all under csrc/cpu/.
sources = [
    f"csrc/cpu/{stem}.cpp"
    for stem in (
        "activation",
        "bmm",
        "decode",
        "extend",
        "gemm",
        "gemm_fp8",
        "gemm_int8",
        "moe",
        "moe_fp8",
        "moe_int8",
        "norm",
        "qkv_proj",
        "topk",
        "interface",
        "shm",
        "rope",
        "torch_extension_cpu",
    )
]

extra_compile_args = {
    "cxx": ["-O3", "-Wno-unknown-pragmas", "-march=native", "-fopenmp"]
    + (["-DSGLANG_CPU_FP8_CVT_FTZ"] if cpu_fp8_ftz else [])
}

libraries = ["c10", "torch", "torch_python"]
cmdclass = {"build_ext": BuildExtension.with_options(use_ninja=True)}
Extension = CppExtension

extra_link_args = [
    "-Wl,-rpath,$ORIGIN/../../torch/lib",
    f"-L/usr/lib/{arch}-linux-gnu",
]

ext_modules = [
    Extension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args,
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=False,
    )
]

setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(where="python"),
    package_dir={"": "python"},
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
|
||||
259
test/srt/cpu/test_moe.py
Normal file
259
test/srt/cpu/test_moe.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
kernel = torch.ops.sgl_kernel
|
||||
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
factor_for_scale,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
native_fp8_fused_moe,
|
||||
precision,
|
||||
scaled_weight,
|
||||
torch_naive_fused_moe,
|
||||
torch_w8a8_per_column_fused_moe,
|
||||
)
|
||||
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    """Run the unquantized CPU fused-MoE path through the sgl_kernel ops.

    Routing comes from ``grouped_topk_cpu`` with a single expert group; the weights are
    converted with ``convert_weight_packed`` when ``prepack`` is true. The experts kernel is
    invoked with ``inplace=True``, so ``a`` is handed to the kernel as the in-place buffer.

    Returns whatever ``fused_experts_cpu`` returns.
    """
    G = 1
    topk_group = 1

    # The op returns freshly computed routing tensors; nothing needs to be pre-allocated.
    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group
    )

    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2

    inplace = True
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,  # use_int8_w8a8
        False,  # use_fp8_w8a16
        None,  # w1_scale
        None,  # w2_scale
        None,  # block_size
        None,  # a1_scale
        None,  # a2_scale
        prepack,  # is_vnni
    )
|
||||
|
||||
|
||||
class TestFusedExperts(CustomTestCase):
|
||||
M = [2, 114]
|
||||
N = [32]
|
||||
K = [32]
|
||||
E = [4]
|
||||
topk = [2]
|
||||
renormalize = [False, True]
|
||||
|
||||
M_int8 = [1, 39]
|
||||
N_int8 = [128]
|
||||
K_int8 = [256]
|
||||
E_int8 = [8]
|
||||
topk_int8 = [3]
|
||||
|
||||
M_fp8 = [2, 121]
|
||||
N_fp8 = [512]
|
||||
K_fp8 = [256]
|
||||
E_fp8 = [8]
|
||||
topk_fp8 = [4]
|
||||
|
||||
def _bf16_moe(self, m, n, k, e, topk, renormalize):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cpu", dtype=dtype)
|
||||
|
||||
torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
|
||||
fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
|
||||
|
||||
atol = rtol = precision[torch_output.dtype]
|
||||
self.assertTrue(
|
||||
torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol)
|
||||
)
|
||||
|
||||
def test_bf16_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M,
|
||||
self.N,
|
||||
self.K,
|
||||
self.E,
|
||||
self.topk,
|
||||
self.renormalize,
|
||||
):
|
||||
with self.subTest(
|
||||
m=params[0],
|
||||
n=params[1],
|
||||
k=params[2],
|
||||
e=params[3],
|
||||
topk=params[4],
|
||||
renormalize=params[5],
|
||||
):
|
||||
self._bf16_moe(*params)
|
||||
|
||||
def _int8_moe(self, M, N, K, E, topk):
|
||||
dtype = torch.bfloat16
|
||||
prepack = True
|
||||
|
||||
# Initialize int8 quantization parameters
|
||||
int8_factor_for_scale = 1e-2
|
||||
int8_max = 127
|
||||
int8_min = -128
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale
|
||||
|
||||
# Calculate routing
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
|
||||
ref_out = torch_w8a8_per_column_fused_moe(
|
||||
a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
|
||||
)
|
||||
|
||||
inplace = True
|
||||
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
|
||||
packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
|
||||
out = kernel.fused_experts_cpu(
|
||||
a,
|
||||
packed_w1,
|
||||
packed_w2,
|
||||
topk_weight,
|
||||
topk_ids.to(torch.int32),
|
||||
inplace,
|
||||
True,
|
||||
False,
|
||||
w1_s,
|
||||
w2_s,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
prepack,
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
# Increase the tolerance for large input shapes
|
||||
if M > 35:
|
||||
atol = rtol = 0.02
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
def test_int8_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M_int8,
|
||||
self.N_int8,
|
||||
self.K_int8,
|
||||
self.E_int8,
|
||||
self.topk_int8,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
E=params[3],
|
||||
topk=params[4],
|
||||
):
|
||||
self._int8_moe(*params)
|
||||
|
||||
def _fp8_moe(self, M, N, K, E, topk):
    """Check ``kernel.fused_experts_cpu`` in fp8 mode against the torch reference.

    Random bfloat16 activations of shape (M, K) and float8_e4m3fn expert
    weights of shape (E, 2 * N, K) and (E, K, N) are generated together with
    per-block scales. The kernel receives the packed fp8 weights plus the
    scales, while ``native_fp8_fused_moe`` receives the weights produced by
    ``scaled_weight``. The two outputs must agree within
    ``precision[torch.bfloat16]`` (used for both atol and rtol).
    """
    act_dtype = torch.bfloat16

    def quantize_fp8(weight_fp32):
        # Multiply by fp8_max, clamp to [fp8_min, fp8_max], cast to float8_e4m3fn.
        stretched = weight_fp32 * fp8_max
        return stretched.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    hidden_states = torch.randn(M, K, dtype=act_dtype) / math.sqrt(K)

    w1_fp8 = quantize_fp8(torch.randn(E, 2 * N, K))
    w2_fp8 = quantize_fp8(torch.randn(E, K, N))

    # One scale per (BLOCK_N x BLOCK_K) tile of each expert weight.
    w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
    w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale

    w1_scaled = scaled_weight(w1_fp8, w1s)
    w2_scaled = scaled_weight(w2_fp8, w2s)

    # Routing: softmax over random logits, keep the top-k experts per token.
    router_logits = torch.randn((M, E), dtype=act_dtype)
    router_probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(router_probs, topk)

    packed_w1 = kernel.convert_weight_packed(w1_fp8)
    packed_w2 = kernel.convert_weight_packed(w2_fp8)

    ref_out = native_fp8_fused_moe(
        hidden_states, w1_scaled, w2_scaled, topk_weight, topk_ids, topk
    )

    # Positional call into the op. The flag meanings are presumed from the
    # int8 test's call and the C++ scale checks; confirm against the op schema.
    kernel_args = (
        hidden_states,
        packed_w1,
        packed_w2,
        topk_weight,
        topk_ids.to(torch.int32),
        False,  # presumably inplace
        False,  # presumably use_int8_w8a8
        True,  # presumably use_fp8_w8a16
        w1s,
        w2s,
        [BLOCK_N, BLOCK_K],  # block size used for the scales above
        None,  # presumably a1_scale
        None,  # presumably a2_scale
        True,  # presumably "weights are prepacked"
    )
    out = kernel.fused_experts_cpu(*kernel_args)

    tolerance = precision[act_dtype]
    self.assertTrue(
        torch.allclose(ref_out.bfloat16(), out, atol=tolerance, rtol=tolerance)
    )
|
||||
|
||||
def test_fp8_moe(self):
|
||||
for params in itertools.product(
|
||||
self.M_fp8,
|
||||
self.N_fp8,
|
||||
self.K_fp8,
|
||||
self.E_fp8,
|
||||
self.topk_fp8,
|
||||
):
|
||||
with self.subTest(
|
||||
M=params[0],
|
||||
N=params[1],
|
||||
K=params[2],
|
||||
E=params[3],
|
||||
topk=params[4],
|
||||
):
|
||||
self._fp8_moe(*params)
|
||||
|
||||
|
||||
# Run unittest's command-line entry point when this file is executed as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -148,3 +148,99 @@ def scaled_weight(weight, scales):
|
||||
.contiguous()
|
||||
.view(E, N, K)
|
||||
)
|
||||
|
||||
|
||||
def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
    """Unfused mixture-of-experts forward pass written in plain torch.

    Args:
        a: 2-D activations, one row per token.
        w1: per-expert first-layer weights, indexed by expert on dim 0.
        w2: per-expert second-layer weights, indexed by expert on dim 0;
            ``w2.shape[1]`` is the output width.
        score: router logits; softmax is taken over the last dim in float32.
        topk: number of experts selected per token.
        renormalize: when true, the selected weights are divided by their
            per-token sum.

    Returns:
        Tensor with one row per token and ``w2.shape[1]`` columns, in
        ``a``'s dtype: the routing-weighted sum of the selected experts'
        outputs.
    """
    num_tokens, hidden = a.shape
    out_width = w2.shape[1]

    # One copy of every token per selected expert (token t fills rows
    # t * topk .. t * topk + topk - 1).
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    expert_out = torch.zeros(
        num_tokens * topk, out_width, dtype=tokens.dtype, device=tokens.device
    )

    probs = torch.softmax(score, dim=-1, dtype=torch.float32)
    gate, chosen = torch.topk(probs, topk)
    if renormalize:
        gate = gate / gate.sum(dim=-1, keepdim=True)
    gate = gate.view(-1)
    chosen = chosen.view(-1)

    for expert in range(w1.shape[0]):
        rows = chosen == expert
        if not rows.sum():
            # No token slot was routed to this expert.
            continue
        activated = SiluAndMul(torch.matmul(tokens[rows], w1[expert].transpose(0, 1)))
        expert_out[rows] = torch.matmul(activated, w2[expert].transpose(0, 1))

    weighted = expert_out.view(num_tokens, -1, out_width) * gate.view(
        num_tokens, -1, 1
    ).to(expert_out.dtype)
    return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
    """Reference fused MoE with int8 weights and per-token int8 activations.

    Activations are quantized per token with ``per_token_quant_int8``. Each
    expert then runs two ``native_w8a8_per_token_matmul`` calls with
    ``SiluAndMul`` in between, re-quantizing the intermediate activation per
    token. Expert outputs are accumulated in float32, weighted by
    ``topk_weight``, summed over the ``topk`` slots and cast to ``a``'s dtype.

    Args:
        a: 2-D activations, one row per token.
        w1, w2: per-expert weights, indexed by expert on dim 0.
        w1_s, w2_s: per-expert weight scales, indexed by expert on dim 0.
        topk_weight: routing weights, ``topk`` per token.
        topk_ids: expert index for each routing slot.
        topk: number of experts selected per token.
    """
    num_tokens, hidden = a.shape
    out_width = w2.shape[1]

    def quantized_matmul(x_q, w_q, x_scale, w_scale):
        # Both layers use the same no-bias, float32-output configuration.
        return native_w8a8_per_token_matmul(
            x_q, w_q, x_scale, w_scale, bias=None, output_dtype=torch.float32
        )

    # Quantize once per token, then replicate values and scales for each of
    # the topk routing slots; the scale column ends up as [num_tokens * topk, 1].
    tokens_q, tokens_scale = per_token_quant_int8(a)
    tokens_q = tokens_q.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    tokens_scale = tokens_scale.view(num_tokens, -1, 1).repeat(1, topk, 1).reshape(-1, 1)

    expert_out = torch.zeros(
        num_tokens * topk, out_width, dtype=torch.float32, device=a.device
    )

    gate = topk_weight.view(-1)
    chosen = topk_ids.view(-1)

    for expert in range(w1.shape[0]):
        rows = chosen == expert
        if not rows.sum():
            # No routing slot selected this expert.
            continue

        # First layer, using the per-token activation scales.
        hidden_out = quantized_matmul(
            tokens_q[rows], w1[expert], tokens_scale[rows], w1_s[expert]
        )
        activated = SiluAndMul(hidden_out)

        # Re-quantize the activation per token before the second layer.
        activated_q, activated_scale = per_token_quant_int8(activated)
        expert_out[rows] = quantized_matmul(
            activated_q, w2[expert], activated_scale, w2_s[expert]
        )

    # Apply the routing weights and reduce over the topk slots.
    weighted = expert_out.view(num_tokens, -1, out_width) * gate.view(
        num_tokens, -1, 1
    ).to(expert_out.dtype)
    return weighted.sum(dim=1).to(a.dtype)
|
||||
|
||||
|
||||
def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
    """Reference MoE forward pass computed in float32 with plain torch matmuls.

    The weights are used exactly as given; no scaling is applied here, so the
    caller is expected to pass weights that are already usable in a float
    matmul.

    Args:
        a: 2-D activations, one row per token.
        w1, w2: per-expert weights, indexed by expert on dim 0;
            ``w2.shape[1]`` is the output width.
        topk_weight: routing weights, ``topk`` per token.
        topk_ids: expert index for each routing slot.
        topk: number of experts selected per token.

    Returns:
        Tensor with one row per token and ``w2.shape[1]`` columns, cast back
        to the dtype ``a`` had on entry.
    """
    B, D = a.shape
    # Capture the caller's dtype before `a` is rebound to its float32 copy.
    # Otherwise the final cast would target float32 and never restore the
    # input dtype.
    out_dtype = a.dtype
    # One float32 copy of every token per selected expert.
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float()
    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)

    # Flatten routing so it lines up with the rows of the expanded `a`.
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1))
            ic1 = SiluAndMul(ic0)
            out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1))

    # Apply the routing weights, reduce over the topk slots, restore the dtype.
    return (
        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
        .sum(dim=1)
        .to(out_dtype)
    )
|
||||
|
||||
Reference in New Issue
Block a user