[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)
This commit is contained in:
@@ -133,6 +133,7 @@ def bench_es(
|
|||||||
d_strides = torch.full(
|
d_strides = torch.full(
|
||||||
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
|
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
|
||||||
)
|
)
|
||||||
|
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
|
||||||
|
|
||||||
def run_cutlass():
|
def run_cutlass():
|
||||||
es_fp8_blockwise_scaled_grouped_mm(
|
es_fp8_blockwise_scaled_grouped_mm(
|
||||||
@@ -146,6 +147,7 @@ def bench_es(
|
|||||||
d_strides,
|
d_strides,
|
||||||
problem_sizes,
|
problem_sizes,
|
||||||
expert_offsets[:-1],
|
expert_offsets[:-1],
|
||||||
|
workspace,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_cutlass()
|
run_cutlass()
|
||||||
|
|||||||
@@ -537,7 +537,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
"es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||||
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets) -> ()");
|
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
|
||||||
|
"()");
|
||||||
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
|
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
const torch::Tensor& stride_b,
|
const torch::Tensor& stride_b,
|
||||||
const torch::Tensor& stride_d,
|
const torch::Tensor& stride_d,
|
||||||
const torch::Tensor& problem_sizes,
|
const torch::Tensor& problem_sizes,
|
||||||
const torch::Tensor& expert_offsets) {
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& workspace) {
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||||
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
|
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
|
||||||
@@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
lm_problem_sizes,
|
lm_problem_sizes,
|
||||||
mm_problem_sizes,
|
mm_problem_sizes,
|
||||||
hm_problem_sizes,
|
hm_problem_sizes,
|
||||||
|
workspace,
|
||||||
is_h20_device,
|
is_h20_device,
|
||||||
stream);
|
stream);
|
||||||
} else if (output.dtype() == torch::kFloat16) {
|
} else if (output.dtype() == torch::kFloat16) {
|
||||||
@@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
lm_problem_sizes,
|
lm_problem_sizes,
|
||||||
mm_problem_sizes,
|
mm_problem_sizes,
|
||||||
hm_problem_sizes,
|
hm_problem_sizes,
|
||||||
|
workspace,
|
||||||
is_h20_device,
|
is_h20_device,
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
|||||||
const torch::Tensor& layout_sfa,
|
const torch::Tensor& layout_sfa,
|
||||||
const torch::Tensor& layout_sfb,
|
const torch::Tensor& layout_sfb,
|
||||||
const torch::Tensor& problem_sizes,
|
const torch::Tensor& problem_sizes,
|
||||||
|
const torch::Tensor& workspace,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
using ElementA = typename GemmTraits::ElementA;
|
using ElementA = typename GemmTraits::ElementA;
|
||||||
using StrideA = typename GemmTraits::StrideA;
|
using StrideA = typename GemmTraits::StrideA;
|
||||||
@@ -143,10 +144,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
|
|||||||
auto can_implement_status = gemm_op.can_implement(args);
|
auto can_implement_status = gemm_op.can_implement(args);
|
||||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
|
||||||
|
|
||||||
torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device());
|
|
||||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
|
||||||
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
|
|
||||||
|
|
||||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||||
|
|
||||||
@@ -169,6 +166,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
const torch::Tensor& lm_problem_sizes,
|
const torch::Tensor& lm_problem_sizes,
|
||||||
const torch::Tensor& mm_problem_sizes,
|
const torch::Tensor& mm_problem_sizes,
|
||||||
const torch::Tensor& hm_problem_sizes,
|
const torch::Tensor& hm_problem_sizes,
|
||||||
|
const torch::Tensor& workspace,
|
||||||
bool is_h20_device,
|
bool is_h20_device,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
using LowMGemmH20Traits =
|
using LowMGemmH20Traits =
|
||||||
@@ -199,6 +197,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
layout_sfb,
|
layout_sfb,
|
||||||
layout_sfa,
|
layout_sfa,
|
||||||
lm_problem_sizes,
|
lm_problem_sizes,
|
||||||
|
workspace,
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
|
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
|
||||||
@@ -213,6 +212,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
layout_sfb,
|
layout_sfb,
|
||||||
layout_sfa,
|
layout_sfa,
|
||||||
lm_problem_sizes,
|
lm_problem_sizes,
|
||||||
|
workspace,
|
||||||
stream);
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,6 +229,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
layout_sfb,
|
layout_sfb,
|
||||||
layout_sfa,
|
layout_sfa,
|
||||||
mm_problem_sizes,
|
mm_problem_sizes,
|
||||||
|
workspace,
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
|
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
|
||||||
@@ -243,6 +244,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
layout_sfa,
|
layout_sfa,
|
||||||
layout_sfb,
|
layout_sfb,
|
||||||
mm_problem_sizes,
|
mm_problem_sizes,
|
||||||
|
workspace,
|
||||||
stream);
|
stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,6 +261,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
layout_sfa,
|
layout_sfa,
|
||||||
layout_sfb,
|
layout_sfb,
|
||||||
hm_problem_sizes,
|
hm_problem_sizes,
|
||||||
|
workspace,
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
|
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
|
||||||
@@ -273,6 +276,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
|
|||||||
layout_sfa,
|
layout_sfa,
|
||||||
layout_sfb,
|
layout_sfb,
|
||||||
hm_problem_sizes,
|
hm_problem_sizes,
|
||||||
|
workspace,
|
||||||
stream);
|
stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -835,4 +835,5 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
const torch::Tensor& stride_b,
|
const torch::Tensor& stride_b,
|
||||||
const torch::Tensor& stride_d,
|
const torch::Tensor& stride_d,
|
||||||
const torch::Tensor& problem_sizes,
|
const torch::Tensor& problem_sizes,
|
||||||
const torch::Tensor& expert_offsets);
|
const torch::Tensor& expert_offsets,
|
||||||
|
const torch::Tensor& workspace);
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ def es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
stride_d,
|
stride_d,
|
||||||
problem_sizes,
|
problem_sizes,
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
|
workspace,
|
||||||
):
|
):
|
||||||
torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
|
torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default(
|
||||||
output,
|
output,
|
||||||
@@ -24,4 +25,5 @@ def es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
stride_d,
|
stride_d,
|
||||||
problem_sizes,
|
problem_sizes,
|
||||||
expert_offsets,
|
expert_offsets,
|
||||||
|
workspace,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||||
b_scale_stack = b_scale_stack.transpose(1, 2)
|
b_scale_stack = b_scale_stack.transpose(1, 2)
|
||||||
|
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
|
||||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||||
a_strides = torch.full(
|
a_strides = torch.full(
|
||||||
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
|
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
|
||||||
@@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
|||||||
d_strides,
|
d_strides,
|
||||||
problem_sizes,
|
problem_sizes,
|
||||||
expert_offsets[:-1],
|
expert_offsets[:-1],
|
||||||
|
workspace,
|
||||||
)
|
)
|
||||||
|
|
||||||
for g in range(num_experts):
|
for g in range(num_experts):
|
||||||
|
|||||||
Reference in New Issue
Block a user