[sgl-kernel] support hadamard (#11663)
This commit is contained in:
@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
|
||||
"()");
|
||||
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
|
||||
|
||||
/*
|
||||
* From hadamard-transform
|
||||
*/
|
||||
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
|
||||
|
||||
m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N);
|
||||
|
||||
m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N);
|
||||
|
||||
m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N);
|
||||
|
||||
m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
|
||||
}
|
||||
|
||||
// NOTE(review): REGISTER_EXTENSION is a project macro defined elsewhere —
// presumably it emits the Python extension module entry point for
// "common_ops"; confirm against the macro's definition.
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user