[sgl-kernel] support hadamard (#11663)
@@ -270,6 +270,13 @@ from sgl_kernel.gemm import (
     silu_and_mul_scaled_fp4_grouped_quant,
 )
 from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
+from sgl_kernel.hadamard import (
+    hadamard_transform,
+    hadamard_transform_12n,
+    hadamard_transform_20n,
+    hadamard_transform_28n,
+    hadamard_transform_40n,
+)
 from sgl_kernel.kvcacheio import (
     transfer_kv_all_layer,
     transfer_kv_all_layer_mla,
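This hunk re-exports the new wrappers at the package top level, so callers can import them directly from sgl_kernel instead of the sgl_kernel.hadamard submodule. An illustrative snippet (not part of the PR), assuming a built sgl-kernel wheel; both import paths below should resolve to the same functions:

    from sgl_kernel import hadamard_transform
    from sgl_kernel.hadamard import hadamard_transform_12n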
sgl-kernel/python/sgl_kernel/hadamard.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import torch
+
+
+def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    return torch.ops.sgl_kernel.fast_hadamard_transform.default(x, scale)
+
+
+def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    return torch.ops.sgl_kernel.fast_hadamard_transform_12N.default(x, scale)
+
+
+def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    return torch.ops.sgl_kernel.fast_hadamard_transform_20N.default(x, scale)
+
+
+def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    return torch.ops.sgl_kernel.fast_hadamard_transform_28N.default(x, scale)
+
+
+def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+    return torch.ops.sgl_kernel.fast_hadamard_transform_40N.default(x, scale)
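These wrappers are thin Python bindings over custom torch.ops registered by the sgl-kernel CUDA extension. A minimal usage sketch, assuming the ops mirror the semantics of Dao-AILab's fast-hadamard-transform (CUDA tensors only; last dimension must be a power of two for the base kernel, or 12*2^k, 20*2^k, 28*2^k, 40*2^k for the suffixed variants); the shapes and tolerances here are illustrative, not taken from the PR:

    import math

    import torch
    from sgl_kernel import hadamard_transform

    # fp16 input on GPU; last dim 1024 = 2**10 satisfies the power-of-two rule
    # assumed for the base kernel.
    x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")

    # scale = 1/sqrt(dim) makes the transform orthonormal, so applying it twice
    # should recover the input up to floating-point error.
    scale = 1.0 / math.sqrt(x.shape[-1])
    y = hadamard_transform(x, scale=scale)
    x_roundtrip = hadamard_transform(y, scale=scale)
    torch.testing.assert_close(x_roundtrip, x, rtol=1e-2, atol=1e-2)

The size-suffixed variants (hadamard_transform_12n, etc.) take the same (x, scale) arguments and differ only in which last-dimension sizes they accept.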