[sgl-kernel] support hadamard (#11663)

This commit is contained in:
Fan Yin
2025-10-16 10:00:44 +08:00
committed by GitHub
parent 868403f642
commit 3289da5b41
7 changed files with 147 additions and 1 deletions

View File

@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
"()");
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
/*
 * From hadamard-transform
 *
 * Op registrations ported from the fast-hadamard-transform project.
 * Each op takes a Tensor `x` and a float `scale` and returns a new
 * Tensor (schema: out-of-place, no mutation of `x`).
 * NOTE(review): the _12N/_20N/_28N/_40N suffixes presumably select
 * kernels for dimensions of the form 12/20/28/40 * 2^k — confirm
 * against the kernel implementations; not visible from here.
 */
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
// All hadamard implementations are registered for the CUDA dispatch
// key only; there is no CPU fallback registered in this fragment.
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N);
m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N);
m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N);
m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
}  // end TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)
// Exposes the ops registered above as the `common_ops` extension module.
REGISTER_EXTENSION(common_ops)