[Refactor] Rename n_share_experts_fusion as num_fused_shared_experts (#6735)
This commit is contained in:
@@ -161,7 +161,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def(
|
||||
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
|
||||
"n_share_experts_fusion, float routed_scaling_factor) -> "
|
||||
"num_fused_shared_experts, float routed_scaling_factor) -> "
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
m.def(
|
||||
|
||||
@@ -57,7 +57,7 @@ __device__ void moe_fused_gate_impl(
|
||||
int64_t num_rows,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor,
|
||||
Params params) {
|
||||
int tidx = threadIdx.x;
|
||||
@@ -68,7 +68,7 @@ __device__ void moe_fused_gate_impl(
|
||||
}
|
||||
|
||||
// Calculate topk_excluding_share_expert_fusion from topk
|
||||
int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 1 : 0);
|
||||
int64_t topk_excluding_share_expert_fusion = topk - (num_fused_shared_experts > 0 ? 1 : 0);
|
||||
|
||||
// Cast pointers to type T:
|
||||
auto* input_ptr = reinterpret_cast<T*>(input);
|
||||
@@ -222,11 +222,11 @@ __device__ void moe_fused_gate_impl(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (thread_group_idx == 0 && n_share_experts_fusion > 0) {
|
||||
if (thread_group_idx == 0 && num_fused_shared_experts > 0) {
|
||||
int64_t last_idx = topk * thread_row + topk_excluding_share_expert_fusion;
|
||||
|
||||
// Use round-robin to select expert
|
||||
int64_t expert_offset = thread_row % n_share_experts_fusion;
|
||||
int64_t expert_offset = thread_row % num_fused_shared_experts;
|
||||
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);
|
||||
|
||||
// Set the weight to the sum of all weights divided by routed_scaling_factor
|
||||
@@ -273,7 +273,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
int64_t num_rows,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
|
||||
moe_fused_gate_impl<T>(
|
||||
@@ -284,7 +284,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
num_rows,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
params);
|
||||
}
|
||||
@@ -305,7 +305,7 @@ __global__ void moe_fused_gate_kernel(
|
||||
num_rows, \
|
||||
topk_group, \
|
||||
topk, \
|
||||
n_share_experts_fusion, \
|
||||
num_fused_shared_experts, \
|
||||
routed_scaling_factor); \
|
||||
dispatched = true; \
|
||||
} while (0)
|
||||
@@ -333,7 +333,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
KernelParamsDynamic params;
|
||||
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
|
||||
@@ -351,7 +351,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
|
||||
num_rows,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
params);
|
||||
}
|
||||
@@ -365,7 +365,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor) {
|
||||
int64_t num_rows = input.size(0);
|
||||
int32_t num_experts = input.size(1);
|
||||
@@ -464,7 +464,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
} else if (input.scalar_type() == at::kHalf) {
|
||||
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
@@ -477,7 +477,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
} else if (input.scalar_type() == at::kFloat) {
|
||||
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
|
||||
@@ -490,7 +490,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
|
||||
|
||||
@@ -206,7 +206,7 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t topk,
|
||||
int64_t n_share_experts_fusion,
|
||||
int64_t num_fused_shared_experts,
|
||||
double routed_scaling_factor);
|
||||
|
||||
void fp8_blockwise_scaled_grouped_mm(
|
||||
|
||||
@@ -42,7 +42,7 @@ def moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion=0,
|
||||
num_fused_shared_experts=0,
|
||||
routed_scaling_factor=0,
|
||||
):
|
||||
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
|
||||
@@ -51,7 +51,7 @@ def moe_fused_gate(
|
||||
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
|
||||
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limited for now.
|
||||
# for non-supported case, we suggest to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
|
||||
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
|
||||
# num_fused_shared_experts: if > 0, the last expert will be replaced with a round-robin shared expert
|
||||
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
|
||||
return torch.ops.sgl_kernel.moe_fused_gate.default(
|
||||
input_tensor,
|
||||
@@ -59,7 +59,7 @@ def moe_fused_gate(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,15 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
|
||||
(512, 16, 8, 16),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
|
||||
def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
|
||||
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1])
|
||||
def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_experts):
|
||||
num_experts, num_expert_group, topk_group, topk = params
|
||||
|
||||
torch.manual_seed(seq_length)
|
||||
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
|
||||
scores = tensor.clone()
|
||||
bias = torch.rand(num_experts).to(dtype).cuda()
|
||||
topk = topk + min(1, n_share_experts_fusion)
|
||||
topk = topk + min(1, num_fused_shared_experts)
|
||||
|
||||
output, indices = moe_fused_gate(
|
||||
tensor,
|
||||
@@ -35,7 +35,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
topk=topk,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
routed_scaling_factor=2.5,
|
||||
)
|
||||
ref_output, ref_indices = biased_grouped_topk(
|
||||
@@ -47,12 +47,12 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
compiled=False,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
routed_scaling_factor=2.5,
|
||||
)
|
||||
|
||||
# When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
|
||||
if n_share_experts_fusion > 0:
|
||||
# When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
|
||||
if num_fused_shared_experts > 0:
|
||||
original_indices = indices.clone()
|
||||
original_ref_indices = ref_indices.clone()
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
|
||||
ref_indices = ref_indices[:, :-1]
|
||||
|
||||
valid_min = num_experts
|
||||
valid_max = num_experts + n_share_experts_fusion
|
||||
valid_max = num_experts + num_fused_shared_experts
|
||||
shared_indices = original_indices[:, -1]
|
||||
shared_ref_indices = original_ref_indices[:, -1]
|
||||
if shared_indices is not None:
|
||||
@@ -87,11 +87,11 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
|
||||
|
||||
assert idx_check, (
|
||||
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
|
||||
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
|
||||
f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
|
||||
)
|
||||
assert output_check, (
|
||||
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
|
||||
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
|
||||
f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user