[sgl-kernel] support hadamard (#11663)
This commit is contained in:
@@ -62,7 +62,7 @@ fi
|
||||
$PIP_CMD list
|
||||
|
||||
# Install additional dependencies
|
||||
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
||||
$PIP_CMD install mooncake-transfer-engine==0.3.6.post1 nvidia-cuda-nvrtc-cu12 py-spy scipy huggingface_hub[hf_xet] $PIP_INSTALL_SUFFIX
|
||||
|
||||
if [ "$IS_BLACKWELL" != "1" ]; then
|
||||
# For lmms_evals evaluating MMMU
|
||||
|
||||
@@ -60,6 +60,7 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-deepgemm)
|
||||
|
||||
# fmt
|
||||
FetchContent_Declare(
|
||||
repo-fmt
|
||||
GIT_REPOSITORY https://github.com/fmtlib/fmt
|
||||
@@ -113,6 +114,15 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-mscclpp)
|
||||
|
||||
# fast-hadamard-transform
|
||||
FetchContent_Declare(
|
||||
repo-fast-hadamard-transform
|
||||
GIT_REPOSITORY https://github.com/sgl-project/fast-hadamard-transform.git
|
||||
GIT_TAG 48f3c13764dc2ec662ade842a4696a90a137f1bc
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-fast-hadamard-transform)
|
||||
|
||||
# ccache option
|
||||
option(ENABLE_CCACHE "Whether to use ccache" ON)
|
||||
find_program(CCACHE_FOUND ccache)
|
||||
@@ -138,6 +148,7 @@ include_directories(
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
${repo-fast-hadamard-transform}/csrc
|
||||
)
|
||||
|
||||
set(SGL_KERNEL_CUDA_FLAGS
|
||||
@@ -329,6 +340,9 @@ set(SOURCES
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
|
||||
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
|
||||
|
||||
"${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform_cuda.cu"
|
||||
"${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform.cpp"
|
||||
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
|
||||
|
||||
@@ -540,6 +540,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"stride_a, Tensor stride_b, Tensor stride_d, Tensor problem_sizes, Tensor expert_offsets, Tensor workspace) -> "
|
||||
"()");
|
||||
m.impl("es_fp8_blockwise_scaled_grouped_mm", &es_fp8_blockwise_scaled_grouped_mm);
|
||||
|
||||
/*
|
||||
* From hadamard-transform
|
||||
*/
|
||||
m.def("fast_hadamard_transform(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
|
||||
|
||||
m.def("fast_hadamard_transform_12N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_12N", torch::kCUDA, &fast_hadamard_transform_12N);
|
||||
|
||||
m.def("fast_hadamard_transform_20N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_20N", torch::kCUDA, &fast_hadamard_transform_20N);
|
||||
|
||||
m.def("fast_hadamard_transform_28N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_28N", torch::kCUDA, &fast_hadamard_transform_28N);
|
||||
|
||||
m.def("fast_hadamard_transform_40N(Tensor x, float scale) -> Tensor");
|
||||
m.impl("fast_hadamard_transform_40N", torch::kCUDA, &fast_hadamard_transform_40N);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
@@ -837,3 +837,11 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace);
|
||||
/*
|
||||
* From fast-hadamard-transform
|
||||
*/
|
||||
torch::Tensor fast_hadamard_transform(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
|
||||
|
||||
@@ -270,6 +270,13 @@ from sgl_kernel.gemm import (
|
||||
silu_and_mul_scaled_fp4_grouped_quant,
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.hadamard import (
|
||||
hadamard_transform,
|
||||
hadamard_transform_12n,
|
||||
hadamard_transform_20n,
|
||||
hadamard_transform_28n,
|
||||
hadamard_transform_40n,
|
||||
)
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_mla,
|
||||
|
||||
21
sgl-kernel/python/sgl_kernel/hadamard.py
Normal file
21
sgl-kernel/python/sgl_kernel/hadamard.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
|
||||
|
||||
def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.fast_hadamard_transform.default(x, scale)
|
||||
|
||||
|
||||
def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.fast_hadamard_transform_12N.default(x, scale)
|
||||
|
||||
|
||||
def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.fast_hadamard_transform_20N.default(x, scale)
|
||||
|
||||
|
||||
def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.fast_hadamard_transform_28N.default(x, scale)
|
||||
|
||||
|
||||
def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.fast_hadamard_transform_40N.default(x, scale)
|
||||
78
sgl-kernel/tests/test_hadamard.py
Normal file
78
sgl-kernel/tests/test_hadamard.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from scipy.linalg import hadamard
|
||||
from sgl_kernel import hadamard_transform
|
||||
|
||||
|
||||
def hadamard_transform_ref(x, scale=1.0):
|
||||
"""
|
||||
x: (..., dim)
|
||||
out: (..., dim)
|
||||
"""
|
||||
if hadamard is None:
|
||||
raise ImportError("Please install scipy")
|
||||
x_shape = x.shape
|
||||
dim = x.shape[-1]
|
||||
x = x.reshape(-1, dim)
|
||||
log_dim = math.ceil(math.log2(dim))
|
||||
dim_padded = 2**log_dim
|
||||
if dim != dim_padded:
|
||||
x = F.pad(x, (0, dim_padded - dim))
|
||||
out = F.linear(
|
||||
x,
|
||||
torch.tensor(hadamard(dim_padded, dtype=float), dtype=x.dtype, device=x.device),
|
||||
)
|
||||
out = out * scale
|
||||
return out[..., :dim].reshape(*x_shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize(
|
||||
"dim",
|
||||
[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 137, 1024, 2048, 4096, 8192, 16384, 32768],
|
||||
)
|
||||
def test_fast_hadamard_transform(dim, dtype):
|
||||
device = "cuda"
|
||||
|
||||
if dtype == torch.float32:
|
||||
rtol, atol = 3e-4, 3e-3
|
||||
elif dtype == torch.bfloat16:
|
||||
rtol, atol = 1e-2, 5e-2
|
||||
else: # float16
|
||||
rtol, atol = 3e-3, 5e-3
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 15
|
||||
|
||||
x = torch.randn(batch_size, dim, device=device, dtype=dtype)
|
||||
x_ref = x.detach().clone().to(torch.float32)
|
||||
x_pt = x.detach().clone()
|
||||
|
||||
scale = 1 / math.sqrt(dim)
|
||||
|
||||
out = hadamard_transform(x, scale=scale)
|
||||
out_ref = hadamard_transform_ref(x_ref, scale=scale)
|
||||
out_pt = hadamard_transform_ref(x_pt, scale=scale)
|
||||
|
||||
torch.testing.assert_close(
|
||||
out_pt.float(),
|
||||
out_ref,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg="Reference implementations mismatch",
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
out.float(),
|
||||
out_ref,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg="fast_hadamard_transform output mismatch",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user