[Refactor] Rename n_share_experts_fusion as num_fused_shared_experts (#6735)
This commit is contained in:
@@ -57,7 +57,7 @@ __device__ void moe_fused_gate_impl(
|
||||
int64_t num_rows,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor,
|
||||
Params params) {
|
||||
int tidx = threadIdx.x;
|
||||
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
|
||||
}
|
||||
|
||||
// Calculate topk_excluding_share_expert_fusion from topk
|
||||
int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
|
||||
int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
|
||||
|
||||
// Cast pointers to type T:
|
||||
auto* input_ptr = reinterpret_cast<T*>(input);
|
||||
@@ -222,11 +222,11 @@ __device__ void moe_fused_gate_impl(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
|
||||
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
|
||||
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
|
||||
|
||||
// Use round-robin to select expert
|
||||
int64_t expert_offset = thread_row % n_share_experts_fusion;
|
||||
int64_t expert_offset = thread_row % num_fused_shared_experts;
|
||||
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||
|
||||
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||
@@ -273,7 +273,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
int64_t num_rows,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
|
||||
moe_fused_gate_impl<T>(
|
||||
@@ -284,7 +284,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
num_rows,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
params);
|
||||
}
|
||||
@@ -305,7 +305,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
num_rows, \
|
||||
topk_group, \
|
||||
topk, \
|
||||
n_share_experts_fusion, \
|
||||
num_fused_shared_experts, \
|
||||
routed_scaling_factor); \
|
||||
dispatched = true; \
|
||||
} while (0)
|
||||
@@ -333,7 +333,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
KernelParamsDynamic params;
|
||||
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
|
||||
@@ -351,7 +351,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
||||
num_rows,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
params);
|
||||
}
|
||||
@@ -365,7 +365,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
int64_t num_rows = input.size(0);
|
||||
int32_t num_experts = input.size(1);
|
||||
@@ -464,7 +464,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
@@ -477,7 +477,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
@@ -490,7 +490,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
|
||||
|
||||
Reference in New Issue
Block a user