Set num_fused_shared_experts to num_shared_experts when shared_experts fusion is not disabled (#6736)
This commit is contained in:
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
|
||||
}
|
||||
|
||||
// Calculate topk_excluding_share_expert_fusion from topk
|
||||
int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
|
||||
int64_t topk_excluding_share_expert_fusion = topk - num_fused_shared_experts;
|
||||
|
||||
// Cast pointers to type T:
|
||||
auto* input_ptr = reinterpret_cast<T*>(input);
|
||||
@@ -224,13 +224,21 @@ __device__ void moe_fused_gate_impl(
|
||||
|
||||
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
|
||||
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
|
||||
|
||||
// Use round-robin to select expert
|
||||
int64_t expert_offset = thread_row % num_fused_shared_experts;
|
||||
int64_t expert_offset = 0;
|
||||
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||
|
||||
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||
output_ptr[last_idx] = output_sum / routed_scaling_factor;
|
||||
|
||||
if (num_fused_shared_experts > 1) {
|
||||
for (int i = 1; i < num_fused_shared_experts; ++i) {
|
||||
++last_idx;
|
||||
++expert_offset;
|
||||
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||
output_ptr[last_idx] = output_sum / routed_scaling_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user