Rename files in sgl kernel to avoid nested folder structure (#4213)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Lianmin Zheng
2025-03-08 22:54:51 -08:00
committed by GitHub
parent ee132a4515
commit 8abf74e3c9
47 changed files with 184 additions and 199 deletions

View File

@@ -0,0 +1,48 @@
import ctypes
import os
import torch
# Preload the system CUDA 12 runtime with RTLD_GLOBAL, presumably so the
# compiled sgl_kernel extension resolves cudart symbols from this library
# rather than a different copy — TODO confirm intent with the build setup.
# NOTE(review): the path is specific to x86_64-Linux CUDA installs; other
# layouts silently skip the preload.
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
    ctypes.CDLL(
        "/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
        mode=ctypes.RTLD_GLOBAL,
    )
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
from sgl_kernel.attention import lightning_attention_decode
from sgl_kernel.elementwise import (
apply_rope_with_cos_sin_cache_inplace,
fused_add_rmsnorm,
gelu_and_mul,
gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
silu_and_mul,
)
from sgl_kernel.gemm import (
bmm_fp8,
cublas_grouped_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
int8_scaled_mm,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_fp8,
sgl_per_token_quant_fp8,
)
from sgl_kernel.moe import moe_align_block_size
from sgl_kernel.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
build_tree_kernel,
build_tree_kernel_efficient,
tree_speculative_sampling_target_only,
)
from sgl_kernel.version import __version__

View File

@@ -0,0 +1,77 @@
from typing import List, Tuple
import torch
# Backend dispatch at import time: ``torch.version.hip`` is non-None only on
# ROCm builds of PyTorch.  The two branches define *different* public names
# with different signatures, so which API this module exports depends on the
# build — callers must match the backend they run on.
if torch.version.hip is not None:
    # ROCM custom allreduce

    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        # Returns an opaque int handle ("fa") passed to the other wrappers.
        return torch.ops.sgl_kernel.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        # NOTE(review): "reg" presumably means pre-registered buffers —
        # confirm against the kernel implementation.
        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        # Release the context created by init_custom_ar.
        torch.ops.sgl_kernel.dispose(fa)

    def meta_size() -> int:
        return torch.ops.sgl_kernel.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        # NOTE(review): annotated ``-> None`` but propagates the op's return
        # value — harmless, but the annotation and the `return` disagree.
        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return torch.ops.sgl_kernel.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)

else:
    # TRTLLM custom allreduce

    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        # Same underlying op name as the ROCm branch, but the CUDA build's
        # schema takes barrier/tmp buffers instead of IPC handles.
        return torch.ops.sgl_kernel.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        torch.ops.sgl_kernel.dispose(fa)

    def custom_reduce(fa, inp, out):
        torch.ops.sgl_kernel.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)

View File

@@ -0,0 +1,7 @@
import torch
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Invoke the lightning-attention decode kernel.

    Thin pass-through to ``torch.ops.sgl_kernel.lightning_attention_decode``;
    ``output`` and ``new_kv`` are presumably written in place — confirm
    against the kernel implementation.
    """
    decode_op = torch.ops.sgl_kernel.lightning_attention_decode
    decode_op(q, k, v, past_kv, slope, output, new_kv)

View File

@@ -0,0 +1,152 @@
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """RMS-normalize ``input`` with ``weight`` via the sgl_kernel op.

    Allocates a fresh result tensor when ``out`` is not supplied; otherwise
    the kernel writes into ``out``.  Returns the result tensor.
    """
    result = torch.empty_like(input) if out is None else out
    torch.ops.sgl_kernel.rmsnorm(result, input, weight, eps, get_cuda_stream())
    return result
def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Fused residual-add + RMSNorm, applied in place by the kernel.

    NOTE(review): unlike ``gemma_fused_add_rmsnorm`` this op takes no stream
    argument — confirm that matches the registered op schema.
    """
    fused_op = torch.ops.sgl_kernel.fused_add_rmsnorm
    fused_op(input, residual, weight, eps)
def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Gemma-style RMSNorm via the sgl_kernel op.

    A new result tensor is allocated when ``out`` is omitted.
    """
    result = torch.empty_like(input) if out is None else out
    torch.ops.sgl_kernel.gemma_rmsnorm(result, input, weight, eps, get_cuda_stream())
    return result
def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Gemma-style fused residual-add + RMSNorm, applied in place by the kernel."""
    stream = get_cuda_stream()
    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(input, residual, weight, eps, stream)
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
assert (
input.shape[:-1] == output.shape[:-1]
), f"{input.shape[:-1]} != {output.shape[:-1]}"
assert (
input.shape[-1] == 2 * output.shape[-1]
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
return out
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
return out
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
raise ValueError("The pointers must be multiple of 16 bytes.")
if out is not None:
_check_shape(input, out)
else:
out = torch.empty(
input.shape[:-1] + (input.shape[-1] // 2,),
device=input.device,
dtype=input.dtype,
)
torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
return out
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""Apply rotary embedding to ``query``/``key`` in place from a cos/sin cache.

    Compatible with the SGL/vLLM rotary-embedding convention.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape ``(nnz,)``; cast to int32 before the kernel call.
    query : torch.Tensor
        Query tensor, shape ``(nnz, num_q_heads * head_size)``; rotated in place.
    key : torch.Tensor
        Key tensor, shape ``(nnz, num_k_heads * head_size)``; rotated in place.
    head_size : int
        Per-head dimension used to reshape ``query``/``key`` for the kernel.
    cos_sin_cache : torch.Tensor
        Shape ``(max_seq_len, rotary_dim)``: cosines in the first half of
        ``rotary_dim``, sines in the second half.  Must be float32.
    is_neox : bool
        ``True`` (default): Neox style — rotate ``[..., :head_dim//2]``
        against ``[..., head_dim//2:]`` (non-interleaved).
        ``False``: GPT-J style — rotate even dims ``[..., ::2]`` against odd
        dims ``[..., 1::2]`` (interleaved).

    Note
    ----
    The rotary dimension is determined by the width of ``cos_sin_cache``.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")
    pos_ids = positions.int()
    # The kernel expects (tokens, heads, head_size) views; because q_rope/k_rope
    # alias q/k, the rotation lands in the original tensors.
    q_view = query.view(query.shape[0], -1, head_size)
    k_view = key.view(key.shape[0], -1, head_size)
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
        q=q_view,
        k=k_view,
        q_rope=q_view,
        k_rope=k_view,
        cos_sin_cache=cos_sin_cache,
        pos_ids=pos_ids,
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )

View File

@@ -0,0 +1,127 @@
from typing import List, Optional
import torch
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Scaled int8 matmul: thin wrapper over ``torch.ops.sgl_kernel.int8_scaled_mm``.

    Result dtype is ``out_dtype``; ``bias`` is optional.
    """
    return torch.ops.sgl_kernel.int8_scaled_mm(
        mat_a, mat_b, scales_a, scales_b, out_dtype, bias
    )
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """Blockwise-scaled fp8 matmul: wrapper over the sgl_kernel op."""
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
        mat_a, mat_b, scales_a, scales_b, out_dtype
    )
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Scaled fp8 matmul: wrapper over ``torch.ops.sgl_kernel.fp8_scaled_mm``.

    Result dtype is ``out_dtype``; ``bias`` is optional.
    """
    return torch.ops.sgl_kernel.fp8_scaled_mm(
        mat_a, mat_b, scales_a, scales_b, out_dtype, bias
    )
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke the fp8 batched-matmul kernel, writing the result into ``D``."""
    handle = torch.cuda.current_blas_handle()
    stream = get_cuda_stream()
    torch.ops.sgl_kernel.bmm_fp8(
        A, B, D, A_scale, B_scale, workspace_buffer, handle, stream
    )
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched fp8 matmul ``A @ B`` with per-operand scales.

    Allocates ``out`` of shape ``(A.shape[0], A.shape[1], B.shape[2])`` in
    ``dtype`` when not provided, and reuses a cached 32 MiB workspace per
    device.  Returns ``out``.
    """
    if out is None:
        batch, rows = A.shape[0], A.shape[1]
        out = torch.empty((batch, rows, B.shape[2]), device=A.device, dtype=dtype)
    workspace = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace, A, B, out, A_scale, B_scale)
    return out
def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    """Per-token-group fp8 quantization; ``output_q``/``output_s`` are filled by the kernel."""
    quant_op = torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8
    quant_op(input, output_q, output_s, group_size, eps, fp8_min, fp8_max)
def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    """Per-tensor fp8 quantization; outputs are filled in place by the kernel."""
    quant_op = torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8
    quant_op(input, output_q, output_s, is_static)
def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    """Grouped GEMM over parallel lists of matrices via cuBLAS.

    Results are written into ``outputs``.  Rejects empty input lists before
    touching any CUDA state.
    """
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.cublas_grouped_gemm(
        inputs, weights, outputs, out_dtype, handle, get_cuda_stream()
    )
def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Per-token fp8 quantization; outputs are filled in place by the kernel."""
    quant_op = torch.ops.sgl_kernel.sgl_per_token_quant_fp8
    quant_op(input, output_q, output_s)

View File

@@ -0,0 +1,23 @@
import torch
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    """Align MoE token routing to ``block_size`` blocks.

    Thin wrapper over ``torch.ops.sgl_kernel.moe_align_block_size``; the
    output buffers (``sorted_token_ids`` etc.) are presumably filled in place
    by the kernel — confirm against the CUDA implementation.
    """
    torch.ops.sgl_kernel.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )

View File

@@ -0,0 +1,210 @@
from typing import Optional, Tuple, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Call the top-k renorm kernel; returns the renormalized probabilities.

    Exactly one of ``maybe_top_k_arr`` (per-row tensor) / ``top_k_val``
    (scalar) carries the effective k, as produced by ``_to_tensor_scalar_tuple``.
    """
    probs = probs.float()
    if maybe_top_k_arr is not None:
        maybe_top_k_arr = maybe_top_k_arr.int()
    renormed = torch.empty_like(probs)
    torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
        probs,
        renormed,
        maybe_top_k_arr,
        top_k_val,
        get_cuda_stream(),
    )
    return renormed
def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    """Renormalize ``probs`` restricted to the top-k entries per row.

    ``top_k`` may be a scalar or a per-row tensor.
    """
    tensor_arg, scalar_arg = _to_tensor_scalar_tuple(top_k)
    return _top_k_renorm_probs_internal(probs, tensor_arg, scalar_arg)


# Backward-compatible singular alias.
top_k_renorm_prob = top_k_renorm_probs
def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    """Call the top-p renorm kernel; returns the renormalized probabilities."""
    probs = probs.float()
    if maybe_top_p_arr is not None:
        maybe_top_p_arr = maybe_top_p_arr.float()
    renormed = torch.empty_like(probs)
    torch.ops.sgl_kernel.top_p_renorm_probs(
        probs,
        renormed,
        maybe_top_p_arr,
        top_p_val,
        get_cuda_stream(),
    )
    return renormed
def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Renormalize ``probs`` restricted to the top-p (nucleus) mass per row.

    ``top_p`` may be a scalar or a per-row tensor.
    """
    tensor_arg, scalar_arg = _to_tensor_scalar_tuple(top_p)
    return _top_p_renorm_probs_internal(probs, tensor_arg, scalar_arg)


# Backward-compatible singular alias.
top_p_renorm_prob = top_p_renorm_probs
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the top-p sampling kernel; returns ``(samples, success)`` per row."""
    # Entering the device context so the allocations below land on probs' device.
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        if maybe_top_p_arr is not None:
            maybe_top_p_arr = maybe_top_p_arr.float()
        num_rows = probs.size(0)
        samples = torch.empty(num_rows, dtype=torch.int32, device=device)
        success = torch.empty(num_rows, dtype=torch.bool, device=device)
        torch.ops.sgl_kernel.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids from ``probs`` under a top-p (nucleus) constraint.

    Returns ``(samples, success)``.  Raises ValueError when ``check_nan`` is
    set and ``probs`` contains NaN.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    tensor_arg, scalar_arg = _to_tensor_scalar_tuple(top_p)
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, tensor_arg, scalar_arg, deterministic
    )
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the joint top-k/top-p sampling kernel; returns ``(samples, success)``."""
    # Entering the device context so the allocations below land on probs' device.
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        if maybe_top_k_arr is not None:
            maybe_top_k_arr = maybe_top_k_arr.int()
        if maybe_top_p_arr is not None:
            maybe_top_p_arr = maybe_top_p_arr.float()
        num_rows = probs.size(0)
        samples = torch.empty(num_rows, dtype=torch.int32, device=device)
        success = torch.empty(num_rows, dtype=torch.bool, device=device)
        torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success
def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids under combined top-k and top-p constraints.

    ``filter_apply_order`` selects the strategy:

    * ``"top_k_first"`` — renormalize by top-k, then top-p sample.
    * ``"joint"`` — apply both filters inside a single kernel.

    Returns ``(samples, success)``; raises ValueError on an unknown order or
    (with ``check_nan``) NaN probabilities.
    """
    if filter_apply_order == "top_k_first":
        renormed = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renormed, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    if filter_apply_order == "joint":
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    """Run the min-p sampling kernel; returns the sampled token ids."""
    # Entering the device context so the allocation below lands on probs' device.
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        if maybe_min_p_arr is not None:
            maybe_min_p_arr = maybe_min_p_arr.float()
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernel.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples
def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    """Sample token ids from ``probs`` under a min-p constraint.

    A 2-D ``uniform_samples`` is reduced to its first row (one sampling
    round).  Raises ValueError when ``check_nan`` is set and ``probs``
    contains NaN.
    """
    if uniform_samples.dim() == 2:
        # Take the first row (round) of uniform_samples
        uniform_samples = uniform_samples[0]
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    tensor_arg, scalar_arg = _to_tensor_scalar_tuple(min_p)
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, tensor_arg, scalar_arg, deterministic
    )

View File

@@ -0,0 +1,83 @@
import torch
from sgl_kernel.utils import get_cuda_stream
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Tree-based speculative sampling using target-model probabilities.

    Thin wrapper over the sgl_kernel op; the first three tensors are marked
    mutable and are written by the kernel.
    """
    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
        predicts, accept_index, accept_token_num, candidates,
        retrive_index, retrive_next_token, retrive_next_sibling,
        uniform_samples, target_probs, draft_probs, deterministic,
        get_cuda_stream(),
    )
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the speculative-decoding token tree (efficient variant).

    Thin wrapper over ``torch.ops.sgl_kernel.build_tree_kernel_efficient``;
    the tensor arguments are presumably output buffers filled by the kernel —
    confirm against the CUDA implementation.
    """
    torch.ops.sgl_kernel.build_tree_kernel_efficient(
        parent_list, selected_index, verified_seq_len, tree_mask,
        positions, retrive_index, retrive_next_token, retrive_next_sibling,
        topk, depth, draft_token_num,
    )
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the speculative-decoding token tree.

    Thin wrapper over ``torch.ops.sgl_kernel.build_tree_kernel``; unlike the
    efficient variant, it takes no next-token/sibling retrieval buffers.
    """
    torch.ops.sgl_kernel.build_tree_kernel(
        parent_list, selected_index, verified_seq_len, tree_mask,
        positions, retrive_index, topk, depth, draft_token_num,
    )

View File

@@ -0,0 +1,41 @@
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Dict, Tuple
import torch
def get_cuda_stream() -> int:
    """Return the raw stream pointer of the current CUDA stream as an int."""
    current = torch.cuda.current_stream()
    return current.cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
key = (name, device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
return (x, 0)
else:
return (None, x)

View File

@@ -0,0 +1 @@
# Package version string; bumped on each sgl-kernel release.
__version__ = "0.0.3.post7"