diff --git a/scripts/ci/ci_install_dependency.sh b/scripts/ci/ci_install_dependency.sh index 517b2ca45..e033ec901 100755 --- a/scripts/ci/ci_install_dependency.sh +++ b/scripts/ci/ci_install_dependency.sh @@ -62,7 +62,7 @@ fi $PIP_CMD list # Install additional dependencies -$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX +$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy scipy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX if [ "$IS_BLACKWELL" != "1" ]; then # For lmms_evals evaluating MMMU diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 7133ad652..7c4b61171 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -60,6 +60,7 @@ FetchContent_Declare( ) FetchContent_Populate(repo-deepgemm) +# fmt FetchContent_Declare( repo-fmt GIT_REPOSITORY https://github.com/fmtlib/fmt @@ -113,6 +114,15 @@ FetchContent_Declare( ) FetchContent_Populate(repo-mscclpp) +# fast-hadamard-transform +FetchContent_Declare( + repo-fast-hadamard-transform + GIT_REPOSITORY https://github.com/sgl-project/fast-hadamard-transform.git + GIT_TAG 48f3c13764dc2ec662ade842a4696a90a137f1bc + GIT_SHALLOW OFF +) +FetchContent_Populate(repo-fast-hadamard-transform) + # ccache option option(ENABLE_CCACHE "Whether to use ccache" ON) find_program(CCACHE_FOUND ccache) @@ -138,6 +148,7 @@ include_directories( ${repo-flashinfer_SOURCE_DIR}/include ${repo-flashinfer_SOURCE_DIR}/csrc ${repo-mscclpp_SOURCE_DIR}/include + ${repo-fast-hadamard-transform_SOURCE_DIR}/csrc ) set(SGL_KERNEL_CUDA_FLAGS @@ -329,6 +340,9 @@ set(SOURCES "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" + "${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform_cuda.cu" + "${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform.cpp" + 
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu" "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 39ed19fb8..125ed29dc 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> " "()"); m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm); + + /* + * From hadamard-transform + */ + m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor"); + m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform); + + m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor"); + m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N); + + m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor"); + m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N); + + m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor"); + m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N); + + m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor"); + m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 3a0e7a28e..1b4b5c91e 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -837,3 +837,11 @@ void es_fp8_blockwise_scaled_grouped_mm( const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, const torch::Tensor& workspace); 
+/* + * From fast-hadamard-transform + */ +torch::Tensor fast_hadamard_transform(torch::Tensor& x, double scale); +torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale); +torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale); +torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale); +torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 209f81434..907177d6b 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -270,6 +270,13 @@ from sgl_kernel.gemm import ( silu_and_mul_scaled_fp4_grouped_quant, ) from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda +from sgl_kernel.hadamard import ( + hadamard_transform, + hadamard_transform_12n, + hadamard_transform_20n, + hadamard_transform_28n, + hadamard_transform_40n, +) from sgl_kernel.kvcacheio import ( transfer_kv_all_layer, transfer_kv_all_layer_mla, diff --git a/sgl-kernel/python/sgl_kernel/hadamard.py b/sgl-kernel/python/sgl_kernel/hadamard.py new file mode 100644 index 000000000..102c540f9 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/hadamard.py @@ -0,0 +1,21 @@ +import torch + + +def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + return torch.ops.sgl_kernel.fast_hadamard_transform.default(x, scale) + + +def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + return torch.ops.sgl_kernel.fast_hadamard_transform_12N.default(x, scale) + + +def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + return torch.ops.sgl_kernel.fast_hadamard_transform_20N.default(x, scale) + + +def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + return torch.ops.sgl_kernel.fast_hadamard_transform_28N.default(x, scale) + + +def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> 
torch.Tensor: + return torch.ops.sgl_kernel.fast_hadamard_transform_40N.default(x, scale) diff --git a/sgl-kernel/tests/test_hadamard.py b/sgl-kernel/tests/test_hadamard.py new file mode 100644 index 000000000..5d1cd40e2 --- /dev/null +++ b/sgl-kernel/tests/test_hadamard.py @@ -0,0 +1,78 @@ +import math + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from scipy.linalg import hadamard +from sgl_kernel import hadamard_transform + + +def hadamard_transform_ref(x, scale=1.0): + """ + x: (..., dim) + out: (..., dim) + """ + if hadamard is None: + raise ImportError("Please install scipy") + x_shape = x.shape + dim = x.shape[-1] + x = x.reshape(-1, dim) + log_dim = math.ceil(math.log2(dim)) + dim_padded = 2**log_dim + if dim != dim_padded: + x = F.pad(x, (0, dim_padded - dim)) + out = F.linear( + x, + torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device), + ) + out = out * scale + return out[..., :dim].reshape(*x_shape) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "dim", + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 137, 1024, 2048, 4096, 8192, 16384, 32768], +) +def test_fast_hadamard_transform(dim, dtype): + device = "cuda" + + if dtype == torch.float32: + rtol, atol = 3e-4, 3e-3 + elif dtype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + else: # float16 + rtol, atol = 3e-3, 5e-3 + + torch.random.manual_seed(0) + batch_size = 15 + + x = torch.randn(batch_size, dim, device=device, dtype=dtype) + x_ref = x.detach().clone().to(torch.float32) + x_pt = x.detach().clone() + + scale = 1 / math.sqrt(dim) + + out = hadamard_transform(x, scale=scale) + out_ref = hadamard_transform_ref(x_ref, scale=scale) + out_pt = hadamard_transform_ref(x_pt, scale=scale) + + torch.testing.assert_close( + out_pt.float(), + out_ref, + rtol=rtol, + atol=atol, + msg="Reference implementations mismatch", + ) + torch.testing.assert_close( + out.float(), + 
out_ref, + rtol=rtol, + atol=atol, + msg="fast_hadamard_transform output mismatch", + ) + + +if __name__ == "__main__": + pytest.main([__file__])