Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)

This commit is contained in:
Cheng Wan
2025-06-04 15:53:22 -07:00
committed by GitHub
parent f0f84975f4
commit 81964328b7
22 changed files with 381 additions and 45 deletions

View File

@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
}
// Calculate topk_excluding_share_expert_fusion from topk
int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
// Cast pointers to type T:
auto* input_ptr = reinterpret_cast<T*>(input);
@@ -224,13 +224,21 @@ __device__ void moe_fused_gate_impl(
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
// Use round-robin to select expert
int64_t expert_offset = thread_row % num_fused_shared_experts;
int64_t expert_offset = 0;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
if (num_fused_shared_experts > 1) {
for (int i = 1; i < num_fused_shared_experts; ++i) {
++last_idx;
++expert_offset;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
// Set the weight to the sum of all weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
}
}
}
__syncthreads();