2025-02-07 19:32:45 +08:00
|
|
|
import ctypes
|
|
|
|
|
import os
|
2025-04-30 00:06:16 +02:00
|
|
|
import platform
|
2025-09-17 08:01:45 +02:00
|
|
|
import shutil
|
|
|
|
|
from pathlib import Path
|
2025-02-07 19:32:45 +08:00
|
|
|
|
2025-03-02 15:19:06 -08:00
|
|
|
import torch
|
|
|
|
|
|
2025-04-30 00:06:16 +02:00
|
|
|
|
2025-09-17 08:01:45 +02:00
|
|
|
# Copied and adapted from torch/utils/cpp_extension.py (_find_cuda_home).
|
|
|
|
|
def _find_cuda_home():
|
|
|
|
|
"""Find the CUDA install path."""
|
|
|
|
|
# Guess #1
|
|
|
|
|
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
|
|
|
|
|
if cuda_home is None:
|
|
|
|
|
# Guess #2
|
|
|
|
|
nvcc_path = shutil.which("nvcc")
|
|
|
|
|
if nvcc_path is not None:
|
|
|
|
|
cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
|
|
|
|
|
else:
|
|
|
|
|
# Guess #3
|
|
|
|
|
cuda_home = "/usr/local/cuda"
|
|
|
|
|
return cuda_home
|
|
|
|
|
|
|
|
|
|
|
2025-09-19 02:38:02 +08:00
|
|
|
# When torch is a CUDA build, dlopen the CUDA runtime library (libcudart) with
# RTLD_GLOBAL before the compiled extension (sgl_kernel.common_ops) is imported
# below -- presumably so the extension can resolve cudart symbols; TODO confirm.
if torch.version.cuda is not None:
    # Match the CUDA major version torch was built with (e.g. "12.4" ->
    # "libcudart.so.12") instead of hard-coding version 12.
    _cudart_name = f"libcudart.so.{torch.version.cuda.split('.')[0]}"

    cuda_home = Path(_find_cuda_home())

    _lib_dirs = [d for d in (cuda_home / "lib", cuda_home / "lib64") if d.is_dir()]
    if _lib_dirs:
        # Prefer whichever of lib/ and lib64/ actually contains the runtime.
        # If neither does, keep the first existing one; the existence check
        # below then skips the preload.
        cuda_path = next(
            (d for d in _lib_dirs if (d / _cudart_name).exists()), _lib_dirs[0]
        )
    else:
        # Neither standard library directory exists: search the toolkit tree.
        # This recursive search is deliberately limited to this case, because
        # cuda_home can be a broad prefix (e.g. /usr when nvcc is /usr/bin/nvcc).
        for path in cuda_home.rglob(_cudart_name):
            cuda_path = path.parent
            break
        else:
            raise RuntimeError("Could not find CUDA lib directory.")

    # Despite its name, this is the path to the CUDA runtime shared library,
    # not an include directory. The name is kept because it is a module attribute.
    cuda_include = (cuda_path / _cudart_name).resolve()
    if cuda_include.exists():
        ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
|
2025-02-07 19:32:45 +08:00
|
|
|
|
2025-03-08 22:54:51 -08:00
|
|
|
from sgl_kernel import common_ops
|
|
|
|
|
from sgl_kernel.allreduce import *
|
2025-04-11 22:16:51 -07:00
|
|
|
from sgl_kernel.attention import (
|
|
|
|
|
cutlass_mla_decode,
|
|
|
|
|
cutlass_mla_get_workspace_size,
|
|
|
|
|
lightning_attention_decode,
|
2025-04-12 21:14:04 -07:00
|
|
|
merge_state,
|
2025-04-15 12:28:23 +08:00
|
|
|
merge_state_v2,
|
2025-04-11 22:16:51 -07:00
|
|
|
)
|
2025-07-05 11:50:12 +08:00
|
|
|
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
|
2025-03-08 22:54:51 -08:00
|
|
|
from sgl_kernel.elementwise import (
|
2025-08-12 16:46:40 +08:00
|
|
|
FusedSetKVBufferArg,
|
2025-03-03 06:36:40 -08:00
|
|
|
apply_rope_with_cos_sin_cache_inplace,
|
2025-09-16 02:53:21 +08:00
|
|
|
concat_mla_absorb_q,
|
2025-09-09 00:00:33 +08:00
|
|
|
concat_mla_k,
|
2025-09-05 20:07:19 +08:00
|
|
|
copy_to_gpu_no_ce,
|
2025-08-18 23:45:00 -07:00
|
|
|
downcast_fp8,
|
2025-03-03 06:36:40 -08:00
|
|
|
fused_add_rmsnorm,
|
|
|
|
|
gelu_and_mul,
|
|
|
|
|
gelu_tanh_and_mul,
|
|
|
|
|
gemma_fused_add_rmsnorm,
|
|
|
|
|
gemma_rmsnorm,
|
|
|
|
|
rmsnorm,
|
|
|
|
|
silu_and_mul,
|
|
|
|
|
)
|
2025-08-18 09:38:35 -07:00
|
|
|
from sgl_kernel.fused_moe import fused_marlin_moe
|
2025-03-08 22:54:51 -08:00
|
|
|
from sgl_kernel.gemm import (
|
2025-03-12 00:10:02 -07:00
|
|
|
awq_dequantize,
|
2025-03-03 06:36:40 -08:00
|
|
|
bmm_fp8,
|
2025-03-24 19:50:23 -07:00
|
|
|
cutlass_scaled_fp4_mm,
|
2025-06-29 17:52:24 +08:00
|
|
|
dsv3_fused_a_gemm,
|
2025-06-29 23:31:55 -07:00
|
|
|
dsv3_router_gemm,
|
2025-03-03 06:36:40 -08:00
|
|
|
fp8_blockwise_scaled_mm,
|
|
|
|
|
fp8_scaled_mm,
|
2025-08-14 18:19:03 +08:00
|
|
|
gptq_gemm,
|
|
|
|
|
gptq_marlin_gemm,
|
|
|
|
|
gptq_shuffle,
|
2025-03-03 06:36:40 -08:00
|
|
|
int8_scaled_mm,
|
2025-05-22 10:48:59 +08:00
|
|
|
qserve_w4a8_per_chn_gemm,
|
|
|
|
|
qserve_w4a8_per_group_gemm,
|
2025-06-02 13:48:03 -07:00
|
|
|
scaled_fp4_experts_quant,
|
2025-08-22 12:19:45 -07:00
|
|
|
scaled_fp4_grouped_quant,
|
2025-03-24 19:50:23 -07:00
|
|
|
scaled_fp4_quant,
|
2025-03-07 10:05:43 +08:00
|
|
|
sgl_per_tensor_quant_fp8,
|
2025-09-10 18:24:23 -07:00
|
|
|
sgl_per_token_group_quant_fp8,
|
|
|
|
|
sgl_per_token_group_quant_int8,
|
2025-03-06 20:53:05 -08:00
|
|
|
sgl_per_token_quant_fp8,
|
2025-06-02 13:48:03 -07:00
|
|
|
shuffle_rows,
|
2025-08-22 12:19:45 -07:00
|
|
|
silu_and_mul_scaled_fp4_grouped_quant,
|
2025-03-03 06:36:40 -08:00
|
|
|
)
|
2025-04-23 01:18:30 -07:00
|
|
|
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
2025-06-23 11:58:59 -07:00
|
|
|
from sgl_kernel.kvcacheio import (
|
|
|
|
|
transfer_kv_all_layer,
|
|
|
|
|
transfer_kv_all_layer_mla,
|
|
|
|
|
transfer_kv_per_layer,
|
|
|
|
|
transfer_kv_per_layer_mla,
|
|
|
|
|
)
|
2025-09-12 22:20:21 -07:00
|
|
|
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
|
2025-07-02 14:21:25 +08:00
|
|
|
from sgl_kernel.marlin import (
|
|
|
|
|
awq_marlin_moe_repack,
|
|
|
|
|
awq_marlin_repack,
|
|
|
|
|
gptq_marlin_repack,
|
|
|
|
|
)
|
2025-08-12 16:56:51 -07:00
|
|
|
from sgl_kernel.memory import set_kv_buffer_kernel
|
2025-04-22 22:28:20 -07:00
|
|
|
from sgl_kernel.moe import (
|
2025-06-07 15:24:39 -07:00
|
|
|
apply_shuffle_mul_sum,
|
2025-06-02 13:48:03 -07:00
|
|
|
cutlass_fp4_group_mm,
|
2025-04-22 22:28:20 -07:00
|
|
|
fp8_blockwise_scaled_grouped_mm,
|
|
|
|
|
moe_align_block_size,
|
|
|
|
|
moe_fused_gate,
|
2025-05-16 13:14:07 -07:00
|
|
|
prepare_moe_input,
|
2025-04-22 22:28:20 -07:00
|
|
|
topk_softmax,
|
|
|
|
|
)
|
2025-03-08 22:54:51 -08:00
|
|
|
from sgl_kernel.sampling import (
|
2025-03-03 06:36:40 -08:00
|
|
|
min_p_sampling_from_probs,
|
2025-08-15 01:56:36 +08:00
|
|
|
top_k_mask_logits,
|
2025-03-03 06:36:40 -08:00
|
|
|
top_k_renorm_prob,
|
2025-08-15 01:56:36 +08:00
|
|
|
top_k_top_p_sampling_from_logits,
|
2025-03-03 06:36:40 -08:00
|
|
|
top_k_top_p_sampling_from_probs,
|
|
|
|
|
top_p_renorm_prob,
|
|
|
|
|
top_p_sampling_from_probs,
|
|
|
|
|
)
|
2025-08-18 23:45:00 -07:00
|
|
|
from sgl_kernel.speculative import (
|
|
|
|
|
build_tree_kernel_efficient,
|
|
|
|
|
segment_packbits,
|
|
|
|
|
tree_speculative_sampling_target_only,
|
|
|
|
|
verify_tree_greedy,
|
|
|
|
|
)
|
|
|
|
|
from sgl_kernel.top_k import fast_topk
|
|
|
|
|
from sgl_kernel.version import __version__
|
2025-08-15 21:33:52 +08:00
|
|
|
|
2025-09-12 22:20:21 -07:00
|
|
|
if torch.version.hip is not None:
|
|
|
|
|
from sgl_kernel.elementwise import gelu_quick
|
|
|
|
|
|
2025-08-15 21:33:52 +08:00
|
|
|
|
|
|
|
|
def create_greenctx_stream_by_value(*args, **kwargs):
    """Forward every argument to ``sgl_kernel.spatial.create_greenctx_stream_by_value``.

    ``sgl_kernel.spatial`` is imported on first call rather than at package
    import time. The result of the underlying function is returned unchanged.
    """
    import sgl_kernel.spatial as _spatial

    return _spatial.create_greenctx_stream_by_value(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_sm_available(*args, **kwargs):
    """Forward every argument to ``sgl_kernel.spatial.get_sm_available``.

    ``sgl_kernel.spatial`` is imported on first call rather than at package
    import time. The result of the underlying function is returned unchanged.
    """
    import sgl_kernel.spatial as _spatial

    return _spatial.get_sm_available(*args, **kwargs)
|