[sgl-kernel][2/N]Support Expert Specialization Grouped GEMM (#11534)
This commit is contained in:
@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
|
||||
torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
|
||||
torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
problem_sizes,
|
||||
expert_offsets);
|
||||
|
||||
const std::string H20_device_type_str("NVIDIA H20");
|
||||
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
|
||||
at::cuda::CUDAGuard device_guard{(char)a.get_device()};
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
if (output.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else if (output.dtype() == torch::kFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
output,
|
||||
a,
|
||||
b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
problem_sizes,
|
||||
expert_offsets,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
|
||||
if (output.dtype() == torch::kBFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
|
||||
out_ptrs,
|
||||
@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes);
|
||||
hm_problem_sizes,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else if (output.dtype() == torch::kFloat16) {
|
||||
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
|
||||
out_ptrs,
|
||||
@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
layout_sfb,
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes);
|
||||
hm_problem_sizes,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user