[sgl-kernel] support hadamard (#11663)
This commit is contained in:
@@ -62,7 +62,7 @@ fi
|
|||||||
$PIP_CMD list
|
$PIP_CMD list
|
||||||
|
|
||||||
# Install additional dependencies
|
# Install additional dependencies
|
||||||
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy scipy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
||||||
|
|
||||||
if [ "$IS_BLACKWELL" != "1" ]; then
|
if [ "$IS_BLACKWELL" != "1" ]; then
|
||||||
# For lmms_evals evaluating MMMU
|
# For lmms_evals evaluating MMMU
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ FetchContent_Declare(
|
|||||||
)
|
)
|
||||||
FetchContent_Populate(repo-deepgemm)
|
FetchContent_Populate(repo-deepgemm)
|
||||||
|
|
||||||
|
# fmt
|
||||||
FetchContent_Declare(
|
FetchContent_Declare(
|
||||||
repo-fmt
|
repo-fmt
|
||||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||||
@@ -113,6 +114,15 @@ FetchContent_Declare(
|
|||||||
)
|
)
|
||||||
FetchContent_Populate(repo-mscclpp)
|
FetchContent_Populate(repo-mscclpp)
|
||||||
|
|
||||||
|
# fast-hadamard-transform
|
||||||
|
FetchContent_Declare(
|
||||||
|
repo-fast-hadamard-transform
|
||||||
|
GIT_REPOSITORY https://github.com/sgl-project/fast-hadamard-transform.git
|
||||||
|
GIT_TAG 48f3c13764dc2ec662ade842a4696a90a137f1bc
|
||||||
|
GIT_SHALLOW OFF
|
||||||
|
)
|
||||||
|
FetchContent_Populate(repo-fast-hadamard-transform)
|
||||||
|
|
||||||
# ccache option
|
# ccache option
|
||||||
option(ENABLE_CCACHE "Whether to use ccache" ON)
|
option(ENABLE_CCACHE "Whether to use ccache" ON)
|
||||||
find_program(CCACHE_FOUND ccache)
|
find_program(CCACHE_FOUND ccache)
|
||||||
@@ -138,6 +148,7 @@ include_directories(
|
|||||||
${repo-flashinfer_SOURCE_DIR}/include
|
${repo-flashinfer_SOURCE_DIR}/include
|
||||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||||
${repo-mscclpp_SOURCE_DIR}/include
|
${repo-mscclpp_SOURCE_DIR}/include
|
||||||
|
${repo-fast-hadamard-transform}/csrc
|
||||||
)
|
)
|
||||||
|
|
||||||
set(SGL_KERNEL_CUDA_FLAGS
|
set(SGL_KERNEL_CUDA_FLAGS
|
||||||
@@ -329,6 +340,9 @@ set(SOURCES
|
|||||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||||
|
|
||||||
|
"${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform_cuda.cu"
|
||||||
|
"${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform.cpp"
|
||||||
|
|
||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
||||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
||||||
|
|||||||
@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
|
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
|
||||||
"()");
|
"()");
|
||||||
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
|
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* From hadamard-transform
|
||||||
|
*/
|
||||||
|
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
|
||||||
|
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
|
||||||
|
|
||||||
|
m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor");
|
||||||
|
m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N);
|
||||||
|
|
||||||
|
m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor");
|
||||||
|
m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N);
|
||||||
|
|
||||||
|
m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor");
|
||||||
|
m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N);
|
||||||
|
|
||||||
|
m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
|
||||||
|
m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_EXTENSION(common_ops)
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
@@ -837,3 +837,11 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
|||||||
const torch::Tensor& problem_sizes,
|
const torch::Tensor& problem_sizes,
|
||||||
const torch::Tensor& expert_offsets,
|
const torch::Tensor& expert_offsets,
|
||||||
const torch::Tensor& workspace);
|
const torch::Tensor& workspace);
|
||||||
|
/*
|
||||||
|
* From fast-hadamard-transform
|
||||||
|
*/
|
||||||
|
torch::Tensor fast_hadamard_transform(torch::Tensor& x, double scale);
|
||||||
|
torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
|
||||||
|
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
|
||||||
|
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
|
||||||
|
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
|
||||||
|
|||||||
@@ -270,6 +270,13 @@ from sgl_kernel.gemm import (
|
|||||||
silu_and_mul_scaled_fp4_grouped_quant,
|
silu_and_mul_scaled_fp4_grouped_quant,
|
||||||
)
|
)
|
||||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||||
|
from sgl_kernel.hadamard import (
|
||||||
|
hadamard_transform,
|
||||||
|
hadamard_transform_12n,
|
||||||
|
hadamard_transform_20n,
|
||||||
|
hadamard_transform_28n,
|
||||||
|
hadamard_transform_40n,
|
||||||
|
)
|
||||||
from sgl_kernel.kvcacheio import (
|
from sgl_kernel.kvcacheio import (
|
||||||
transfer_kv_all_layer,
|
transfer_kv_all_layer,
|
||||||
transfer_kv_all_layer_mla,
|
transfer_kv_all_layer_mla,
|
||||||
|
|||||||
21
sgl-kernel/python/sgl_kernel/hadamard.py
Normal file
21
sgl-kernel/python/sgl_kernel/hadamard.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.ops.sgl_kernel.fast_hadamard_transform.default(x, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.ops.sgl_kernel.fast_hadamard_transform_12N.default(x, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.ops.sgl_kernel.fast_hadamard_transform_20N.default(x, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.ops.sgl_kernel.fast_hadamard_transform_28N.default(x, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.ops.sgl_kernel.fast_hadamard_transform_40N.default(x, scale)
|
||||||
78
sgl-kernel/tests/test_hadamard.py
Normal file
78
sgl-kernel/tests/test_hadamard.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from scipy.linalg import hadamard
|
||||||
|
from sgl_kernel import hadamard_transform
|
||||||
|
|
||||||
|
|
||||||
|
def hadamard_transform_ref(x, scale=1.0):
|
||||||
|
"""
|
||||||
|
x: (..., dim)
|
||||||
|
out: (..., dim)
|
||||||
|
"""
|
||||||
|
if hadamard is None:
|
||||||
|
raise ImportError("Please install scipy")
|
||||||
|
x_shape = x.shape
|
||||||
|
dim = x.shape[-1]
|
||||||
|
x = x.reshape(-1, dim)
|
||||||
|
log_dim = math.ceil(math.log2(dim))
|
||||||
|
dim_padded = 2**log_dim
|
||||||
|
if dim != dim_padded:
|
||||||
|
x = F.pad(x, (0, dim_padded - dim))
|
||||||
|
out = F.linear(
|
||||||
|
x,
|
||||||
|
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
|
||||||
|
)
|
||||||
|
out = out * scale
|
||||||
|
return out[..., :dim].reshape(*x_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dim",
|
||||||
|
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 137, 1024, 2048, 4096, 8192, 16384, 32768],
|
||||||
|
)
|
||||||
|
def test_fast_hadamard_transform(dim, dtype):
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
rtol, atol = 3e-4, 3e-3
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
rtol, atol = 1e-2, 5e-2
|
||||||
|
else: # float16
|
||||||
|
rtol, atol = 3e-3, 5e-3
|
||||||
|
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
batch_size = 15
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, dim, device=device, dtype=dtype)
|
||||||
|
x_ref = x.detach().clone().to(torch.float32)
|
||||||
|
x_pt = x.detach().clone()
|
||||||
|
|
||||||
|
scale = 1 / math.sqrt(dim)
|
||||||
|
|
||||||
|
out = hadamard_transform(x, scale=scale)
|
||||||
|
out_ref = hadamard_transform_ref(x_ref, scale=scale)
|
||||||
|
out_pt = hadamard_transform_ref(x_pt, scale=scale)
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
out_pt.float(),
|
||||||
|
out_ref,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol,
|
||||||
|
msg="Reference implementations mismatch",
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
out.float(),
|
||||||
|
out_ref,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol,
|
||||||
|
msg="fast_hadamard_transform output mismatch",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user