diff --git a/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py b/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py index 1cc06d4b5..7591c5dd1 100644 --- a/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py +++ b/sgl-kernel/benchmark/bench_es_fp8_blockwise_grouped_gemm.py @@ -133,6 +133,7 @@ def bench_es( d_strides = torch.full( (num_groups,), c_out.stride(0), device=device, dtype=torch.int64 ) + workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8) def run_cutlass(): es_fp8_blockwise_scaled_grouped_mm( @@ -146,6 +147,7 @@ def bench_es( d_strides, problem_sizes, expert_offsets[:-1], + workspace, ) run_cutlass() diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 8fce6a276..39ed19fb8 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -537,7 +537,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { */ m.def( "es_fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " - "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets) -> ()"); + "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> " + "()"); m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm); } diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu index a5d2904fc..f05209b37 100644 --- a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu +++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu @@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm( const torch::Tensor& stride_b, const torch::Tensor& stride_d, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets) { + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace) { #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)"); @@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm( lm_problem_sizes, mm_problem_sizes, hm_problem_sizes, + workspace, is_h20_device, stream); } else if (output.dtype() == torch::kFloat16) { @@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm( lm_problem_sizes, mm_problem_sizes, hm_problem_sizes, + workspace, is_h20_device, stream); } else { diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh index 19621b24f..f6ab33eb0 100644 --- a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh +++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh @@ -98,6 +98,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm( const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, const torch::Tensor& problem_sizes, + const torch::Tensor& workspace, cudaStream_t stream) { using ElementA = typename GemmTraits::ElementA; using StrideA = typename GemmTraits::StrideA; @@ -143,10 +144,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm( auto can_implement_status = gemm_op.can_implement(args); TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM"); - torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device()); - size_t workspace_size = gemm_op.get_workspace_size(args); - torch::Tensor workspace = torch::empty(workspace_size, options_uint8); - auto status = gemm_op.initialize(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); @@ -169,6 +166,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( const torch::Tensor& lm_problem_sizes, const torch::Tensor& mm_problem_sizes, const torch::Tensor& hm_problem_sizes, + const torch::Tensor& workspace, bool is_h20_device, cudaStream_t stream) { using LowMGemmH20Traits = @@ -199,6 +197,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( layout_sfb, layout_sfa, lm_problem_sizes, + workspace, stream); } else { launch_sm90_fp8_blockwise_scaled_group_mm( @@ -213,6 +212,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( layout_sfb, layout_sfa, lm_problem_sizes, + workspace, stream); } @@ -229,6 +229,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( layout_sfb, layout_sfa, mm_problem_sizes, + workspace, stream); } else { launch_sm90_fp8_blockwise_scaled_group_mm( @@ -243,6 +244,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( layout_sfa, layout_sfb, mm_problem_sizes, + workspace, stream); } @@ -259,6 +261,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( layout_sfa, layout_sfb, hm_problem_sizes, + workspace, stream); } else { launch_sm90_fp8_blockwise_scaled_group_mm( @@ -273,6 +276,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype( layout_sfa, layout_sfb, hm_problem_sizes, + workspace, stream); } } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 255c0a311..3a0e7a28e 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -835,4 +835,5 @@ void es_fp8_blockwise_scaled_grouped_mm( const torch::Tensor& stride_b, const torch::Tensor& stride_d, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets); + const torch::Tensor& expert_offsets, + const torch::Tensor& workspace); diff --git a/sgl-kernel/python/sgl_kernel/expert_specialization.py b/sgl-kernel/python/sgl_kernel/expert_specialization.py index 81b411cf5..8b76a682a 100644 --- a/sgl-kernel/python/sgl_kernel/expert_specialization.py +++ b/sgl-kernel/python/sgl_kernel/expert_specialization.py @@ -12,6 +12,7 @@ def es_fp8_blockwise_scaled_grouped_mm( stride_d, problem_sizes, expert_offsets, + workspace, ): torch.ops.sgl_kernel.es_fp8_blockwise_scaled_grouped_mm.default( output, @@ -24,4 +25,5 @@ def es_fp8_blockwise_scaled_grouped_mm( stride_d, problem_sizes, expert_offsets, + workspace, ) diff --git a/sgl-kernel/tests/test_es_fp8_blockwise_moe.py b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py index 9118facaa..3cb456b14 100644 --- a/sgl-kernel/tests/test_es_fp8_blockwise_moe.py +++ b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py @@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype): ].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major. b_scale_stack = b_scale_stack.transpose(1, 2) - + workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8) c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) a_strides = torch.full( (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64 @@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype): d_strides, problem_sizes, expert_offsets[:-1], + workspace, ) for g in range(num_experts):