[Refactor] Rename n_share_experts_fusion as num_fused_shared_experts (#6735)

Author: Cheng Wan
Date: 2025-06-03 17:48:24 -07:00
Committed by: GitHub
Parent: b6d0ce9f78
Commit: 8a5480528d
14 changed files with 82 additions and 93 deletions


@@ -57,7 +57,7 @@ __device__ void moe_fused_gate_impl(
     int64_t num_rows,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor,
     Params params) {
   int tidx = threadIdx.x;
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
   }
   // Calculate topk_excluding_share_expert_fusion from topk
-  int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
+  int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
   // Cast pointers to type T:
   auto* input_ptr = reinterpret_cast<T*>(input);
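Note: this hunk is where the renamed count affects behavior rather than just a signature. Whenever num_fused_shared_experts is nonzero, one of the topk output slots is reserved for a shared expert and only the remainder go through routed top-k selection. A minimal sketch of that slot arithmetic (the concrete values are hypothetical, not from the diff):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical per-launch values, for illustration only.
  int64_t topk = 9;                      // total output slots per row
  int64_t num_fused_shared_experts = 2;  // shared-expert replicas appended to the expert table

  // Same expression as the kernel: reserve one slot iff fusion is enabled.
  int64_t routed_topk = topk - (num_fused_shared_experts > 0 ? 1 : 0);

  printf("routed slots per row: %lld\n", (long long)routed_topk);           // 8
  printf("shared-expert slots:  %lld\n", (long long)(topk - routed_topk));  // 1
  return 0;
}
```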
@@ -222,11 +222,11 @@ __device__ void moe_fused_gate_impl(
     __syncthreads();
   }
-  if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
+  if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
     int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
     // Use round-robin to select expert
-    int64_t expert_offset = thread_row % n_share_experts_fusion;
+    int64_t expert_offset = thread_row % num_fused_shared_experts;
     indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
     // Set the weight to the sum of all weights divided by routed_scaling_factor
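The second change in this hunk is the round-robin rule: shared-expert replicas occupy indices NUM_EXPERTS .. NUM_EXPERTS + num_fused_shared_experts - 1, and each row picks one via thread_row % num_fused_shared_experts, so consecutive rows spread load across the replicas. A host-side sketch of the same rule (the constants are hypothetical):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t NUM_EXPERTS = 256;             // routed experts; replicas sit after them
  const int64_t num_fused_shared_experts = 2;  // hypothetical replica count

  for (int64_t thread_row = 0; thread_row < 4; ++thread_row) {
    int64_t expert_offset = thread_row % num_fused_shared_experts;  // round-robin by row
    int64_t fused_index = NUM_EXPERTS + expert_offset;              // value written into indices_ptr
    printf("row %lld -> shared-expert replica index %lld\n",
           (long long)thread_row, (long long)fused_index);          // 256, 257, 256, 257
  }
  return 0;
}
```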
@@ -273,7 +273,7 @@ __global__ void moe_fused_gate_kernel(
     int64_t num_rows,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor) {
   KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
   moe_fused_gate_impl<T>(
@@ -284,7 +284,7 @@ __global__ void moe_fused_gate_kernel(
       num_rows,
       topk_group,
       topk,
-      n_share_experts_fusion,
+      num_fused_shared_experts,
       routed_scaling_factor,
       params);
 }
@@ -305,7 +305,7 @@ __global__ void moe_fused_gate_kernel(
         num_rows,                 \
         topk_group,               \
         topk,                     \
-        n_share_experts_fusion,   \
+        num_fused_shared_experts, \
         routed_scaling_factor);   \
     dispatched = true;            \
   } while (0)
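The launch macro above ends in the standard do { ... } while (0) wrapper, so a multi-statement expansion still behaves as a single statement after the caller's semicolon. A generic, self-contained illustration of why that matters:

```cpp
#include <cstdio>

void step_one(int x) { printf("step_one(%d)\n", x); }
void step_two(int x) { printf("step_two(%d)\n", x); }

// do { ... } while (0) makes both calls expand as one statement, so the macro
// composes safely with if/else; without it, only step_one(x) would be guarded
// by the `if` below and the trailing `else` would not even parse.
#define LAUNCH_STEP(x) \
  do {                 \
    step_one(x);       \
    step_two(x);       \
  } while (0)

int main() {
  bool ready = true;
  if (ready)
    LAUNCH_STEP(42);
  else
    printf("skipped\n");
  return 0;
}
```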
@@ -333,7 +333,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
     int64_t num_expert_group,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor) {
   KernelParamsDynamic params;
   params.NUM_EXPERTS = num_experts;  // e.g., for deepseek v3, this is 256
@@ -351,7 +351,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
       num_rows,
       topk_group,
       topk,
-      n_share_experts_fusion,
+      num_fused_shared_experts,
       routed_scaling_factor,
       params);
 }
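For context: moe_fused_gate_kernel receives its sizes as compile-time template parameters (via KernelParams<...> above), while moe_fused_gate_kernel_dynamic carries the same sizes in runtime fields such as params.NUM_EXPERTS. The struct definitions themselves are not shown in this diff; a rough sketch of the two shapes, under that assumption:

```cpp
#include <cstdint>

// Sketch only; the real KernelParams / KernelParamsDynamic live elsewhere in
// the file. Static variant: sizes are compile-time constants, so loops over
// experts can be unrolled and bounds folded away.
template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_,
          int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
struct KernelParamsSketch {
  static constexpr int64_t NUM_EXPERTS = NUM_EXPERTS_;
  static constexpr int64_t VPT = VPT_;
  // ...remaining constants follow the same pattern
};

// Dynamic variant: the same sizes become runtime fields, filled in on the
// host (as in `params.NUM_EXPERTS = num_experts` above) for shapes that were
// not specialized at compile time.
struct KernelParamsDynamicSketch {
  int64_t NUM_EXPERTS;
  int64_t VPT;
  // ...remaining sizes
};
```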
@@ -365,7 +365,7 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t num_expert_group,
     int64_t topk_group,
     int64_t topk,
-    int64_t n_share_experts_fusion,
+    int64_t num_fused_shared_experts,
     double routed_scaling_factor) {
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
@@ -464,7 +464,7 @@ std::vector<at::Tensor> moe_fused_gate(
         num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor);
   } else if (input.scalar_type() == at::kHalf) {
     moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
@@ -477,7 +477,7 @@ std::vector<at::Tensor> moe_fused_gate(
        num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor);
   } else if (input.scalar_type() == at::kFloat) {
     moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
@@ -490,7 +490,7 @@ std::vector<at::Tensor> moe_fused_gate(
        num_expert_group,
         topk_group,
         topk,
-        n_share_experts_fusion,
+        num_fused_shared_experts,
         routed_scaling_factor);
   } else {
     TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
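The tail of moe_fused_gate fans out by input dtype through explicit if/else branches, so each branch can name the matching kernel element type (float16_t, float32_t, ...), with TORCH_CHECK rejecting anything else. A self-contained CPU mock of that dispatch shape; the tag names and stand-in types below are hypothetical, and the real branches launch CUDA kernels instead of printing:

```cpp
#include <cstdio>
#include <stdexcept>

// Stand-ins for the kernel's bfloat16_t / float16_t / float32_t aliases.
struct bf16 {};
struct fp16 {};
using fp32 = float;

enum class ScalarType { BFloat16, Half, Float, Double };

template <typename T>
void launch_moe_fused_gate_dynamic(const char* dtype_name) {
  // In the real code this is a <<<num_blocks, block_dim, 0, stream>>> launch.
  printf("would launch kernel for %s\n", dtype_name);
}

// Same shape as the diff's tail: branch once on the runtime dtype tag,
// instantiate the matching template, and fail loudly on unsupported dtypes.
void dispatch(ScalarType t) {
  if (t == ScalarType::BFloat16) {
    launch_moe_fused_gate_dynamic<bf16>("bfloat16");
  } else if (t == ScalarType::Half) {
    launch_moe_fused_gate_dynamic<fp16>("float16");
  } else if (t == ScalarType::Float) {
    launch_moe_fused_gate_dynamic<fp32>("float32");
  } else {
    throw std::runtime_error("Unsupported data type for moe_fused_gate");
  }
}

int main() {
  dispatch(ScalarType::Half);
  return 0;
}
```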