[sgl-kernel][3/N] Support Expert Specialization Grouped GEMM (#11674)

This commit is contained in:
Qi Yuhang
2025-10-16 04:39:31 +08:00
committed by GitHub
parent f226d3da2a
commit 6c01844f45
7 changed files with 22 additions and 8 deletions

View File

@@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
b_scale_stack = b_scale_stack.transpose(1, 2)
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
a_strides = torch.full(
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
@@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
d_strides,
problem_sizes,
expert_offsets[:-1],
workspace,
)
for g in range(num_experts):