Rename files in sgl kernel to avoid nested folder structure (#4213)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
48
sgl-kernel/python/sgl_kernel/__init__.py
Normal file
48
sgl-kernel/python/sgl_kernel/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||
ctypes.CDLL(
|
||||
"/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12",
|
||||
mode=ctypes.RTLD_GLOBAL,
|
||||
)
|
||||
|
||||
from sgl_kernel import common_ops
|
||||
from sgl_kernel.allreduce import *
|
||||
from sgl_kernel.attention import lightning_attention_decode
|
||||
from sgl_kernel.elementwise import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.gemm import (
|
||||
bmm_fp8,
|
||||
cublas_grouped_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
int8_scaled_mm,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
from sgl_kernel.moe import moe_align_block_size
|
||||
from sgl_kernel.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.speculative import (
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
tree_speculative_sampling_target_only,
|
||||
)
|
||||
from sgl_kernel.version import __version__
|
||||
77
sgl-kernel/python/sgl_kernel/allreduce.py
Normal file
77
sgl-kernel/python/sgl_kernel/allreduce.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
|
||||
# ROCM custom allreduce
|
||||
def init_custom_ar(
|
||||
meta: torch.Tensor,
|
||||
rank_data: torch.Tensor,
|
||||
handles: List[str],
|
||||
offsets: List[int],
|
||||
rank: int,
|
||||
full_nvlink: bool,
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.init_custom_ar(
|
||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
||||
)
|
||||
|
||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
||||
torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
|
||||
|
||||
def all_reduce_unreg(
|
||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
|
||||
|
||||
def dispose(fa: int) -> None:
|
||||
torch.ops.sgl_kernel.dispose(fa)
|
||||
|
||||
def meta_size() -> int:
|
||||
return torch.ops.sgl_kernel.meta_size()
|
||||
|
||||
def register_buffer(
|
||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
||||
) -> None:
|
||||
return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(
|
||||
fa: int, handles: List[str], offsets: List[List[int]]
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.allocate_meta_buffer(size)
|
||||
|
||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
|
||||
|
||||
else:
|
||||
# TRTLLM custom allreduce
|
||||
def init_custom_reduce(
|
||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||
):
|
||||
return torch.ops.sgl_kernel.init_custom_ar(
|
||||
rank_id,
|
||||
num_devices,
|
||||
rank_data,
|
||||
buffers,
|
||||
tmp_buffers,
|
||||
barrier_in,
|
||||
barrier_out,
|
||||
)
|
||||
|
||||
def custom_dispose(fa):
|
||||
torch.ops.sgl_kernel.dispose(fa)
|
||||
|
||||
def custom_reduce(fa, inp, out):
|
||||
torch.ops.sgl_kernel.all_reduce(fa, inp, out)
|
||||
|
||||
def get_graph_buffer_ipc_meta(fa):
|
||||
return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
|
||||
|
||||
def register_graph_buffers(fa, handles, offsets):
|
||||
torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
|
||||
7
sgl-kernel/python/sgl_kernel/attention.py
Normal file
7
sgl-kernel/python/sgl_kernel/attention.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import torch
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
torch.ops.sgl_kernel.lightning_attention_decode(
|
||||
q, k, v, past_kv, slope, output, new_kv
|
||||
)
|
||||
152
sgl-kernel/python/sgl_kernel/elementwise.py
Normal file
152
sgl-kernel/python/sgl_kernel/elementwise.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||
# Kudos to @yzh119
|
||||
def rmsnorm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
|
||||
input, residual, weight, eps, get_cuda_stream()
|
||||
)
|
||||
|
||||
|
||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
||||
assert (
|
||||
input.shape[:-1] == output.shape[:-1]
|
||||
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
||||
assert (
|
||||
input.shape[-1] == 2 * output.shape[-1]
|
||||
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
||||
|
||||
|
||||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||
if out is not None:
|
||||
_check_shape(input, out)
|
||||
else:
|
||||
out = torch.empty(
|
||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||
if out is not None:
|
||||
_check_shape(input, out)
|
||||
else:
|
||||
out = torch.empty(
|
||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||
if out is not None:
|
||||
_check_shape(input, out)
|
||||
else:
|
||||
out = torch.empty(
|
||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
|
||||
return out
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool = True,
|
||||
) -> None:
|
||||
r"""
|
||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
||||
This is designed to be compatible with the SGL/vLLM implementation.
|
||||
The result is inplace applied to the input tensors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positions : torch.Tensor
|
||||
Position indices, shape: ``(nnz)``.
|
||||
query : torch.Tensor
|
||||
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
|
||||
key : torch.Tensor
|
||||
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
|
||||
cos_sin_cache : torch.Tensor
|
||||
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
|
||||
Cosine is the first half and Sine is the second half on rotary_dim.
|
||||
is_neox : bool
|
||||
Whether to use Neox style RoPE, default: ``True``.
|
||||
|
||||
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
|
||||
we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half
|
||||
dimensions ``([..., head_dim//2:])``.
|
||||
|
||||
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
|
||||
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
|
||||
Note
|
||||
----
|
||||
The rotary dimension is determined by the cosine cache and sine cache.
|
||||
"""
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
positions = positions.int()
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
|
||||
q=query.view(query.shape[0], -1, head_size),
|
||||
k=key.view(key.shape[0], -1, head_size),
|
||||
q_rope=query.view(query.shape[0], -1, head_size),
|
||||
k_rope=key.view(key.shape[0], -1, head_size),
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
pos_ids=positions,
|
||||
interleave=(not is_neox),
|
||||
cuda_stream=get_cuda_stream(),
|
||||
)
|
||||
127
sgl-kernel/python/sgl_kernel/gemm.py
Normal file
127
sgl-kernel/python/sgl_kernel/gemm.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernel.int8_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
out_dtype,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
||||
return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
out_dtype,
|
||||
)
|
||||
|
||||
|
||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernel.fp8_scaled_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scales_a,
|
||||
scales_b,
|
||||
out_dtype,
|
||||
bias,
|
||||
)
|
||||
|
||||
|
||||
def _bmm_fp8_internal(
|
||||
workspace_buffer: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
D: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
) -> None:
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
torch.ops.sgl_kernel.bmm_fp8(
|
||||
A,
|
||||
B,
|
||||
D,
|
||||
A_scale,
|
||||
B_scale,
|
||||
workspace_buffer,
|
||||
cublas_handle,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def bmm_fp8(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty(
|
||||
(A.shape[0], A.shape[1], B.shape[2]),
|
||||
device=A.device,
|
||||
dtype=dtype,
|
||||
)
|
||||
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
|
||||
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
|
||||
return out
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float,
|
||||
fp8_min: float,
|
||||
fp8_max: float,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
|
||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_tensor_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
is_static: bool,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
|
||||
|
||||
|
||||
def cublas_grouped_gemm(
|
||||
inputs: List[torch.Tensor],
|
||||
weights: List[torch.Tensor],
|
||||
outputs: List[torch.Tensor],
|
||||
out_dtype: torch.dtype,
|
||||
) -> None:
|
||||
assert (
|
||||
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
||||
), "Inputs/weights/outputs should not be empty!"
|
||||
cublas_handle = torch.cuda.current_blas_handle()
|
||||
torch.ops.sgl_kernel.cublas_grouped_gemm(
|
||||
inputs,
|
||||
weights,
|
||||
outputs,
|
||||
out_dtype,
|
||||
cublas_handle,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def sgl_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
output_q: torch.Tensor,
|
||||
output_s: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
|
||||
23
sgl-kernel/python/sgl_kernel/moe.py
Normal file
23
sgl-kernel/python/sgl_kernel/moe.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_token_ids,
|
||||
experts_ids,
|
||||
num_tokens_post_pad,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
):
|
||||
torch.ops.sgl_kernel.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_token_ids,
|
||||
experts_ids,
|
||||
num_tokens_post_pad,
|
||||
token_cnts_buffer,
|
||||
cumsum_buffer,
|
||||
)
|
||||
210
sgl-kernel/python/sgl_kernel/sampling.py
Normal file
210
sgl-kernel/python/sgl_kernel/sampling.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
||||
|
||||
|
||||
def _top_k_renorm_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
maybe_top_k_arr: Optional[torch.Tensor],
|
||||
top_k_val: int,
|
||||
) -> torch.Tensor:
|
||||
probs = probs.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_k_arr,
|
||||
top_k_val,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
return renorm_probs
|
||||
|
||||
|
||||
def top_k_renorm_probs(
|
||||
probs: torch.Tensor,
|
||||
top_k: Union[torch.Tensor, int],
|
||||
) -> torch.Tensor:
|
||||
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
|
||||
|
||||
|
||||
top_k_renorm_prob = top_k_renorm_probs
|
||||
|
||||
|
||||
def _top_p_renorm_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
maybe_top_p_arr: Optional[torch.Tensor],
|
||||
top_p_val: float,
|
||||
) -> torch.Tensor:
|
||||
probs = probs.float()
|
||||
maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
renorm_probs = torch.empty_like(probs)
|
||||
torch.ops.sgl_kernel.top_p_renorm_probs(
|
||||
probs,
|
||||
renorm_probs,
|
||||
maybe_top_p_arr,
|
||||
top_p_val,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
return renorm_probs
|
||||
|
||||
|
||||
def top_p_renorm_probs(
|
||||
probs: torch.Tensor,
|
||||
top_p: Union[torch.Tensor, float],
|
||||
) -> torch.Tensor:
|
||||
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
|
||||
|
||||
|
||||
top_p_renorm_prob = top_p_renorm_probs
|
||||
|
||||
|
||||
def _top_p_sampling_from_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
maybe_top_p_arr: Optional[torch.Tensor],
|
||||
top_p_val: float,
|
||||
deterministic: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
with probs.device as device:
|
||||
probs = probs.float()
|
||||
uniform_samples = uniform_samples.float()
|
||||
maybe_top_p_arr = (
|
||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||
torch.ops.sgl_kernel.top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
success,
|
||||
maybe_top_p_arr,
|
||||
top_p_val,
|
||||
deterministic,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
return samples, success
|
||||
|
||||
|
||||
def top_p_sampling_from_probs(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
top_p: Union[torch.Tensor, float],
|
||||
deterministic: bool = True,
|
||||
check_nan: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _top_p_sampling_from_probs_internal(
|
||||
probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
|
||||
)
|
||||
|
||||
|
||||
def _top_k_top_p_sampling_from_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
maybe_top_k_arr: Optional[torch.Tensor],
|
||||
top_k_val: int,
|
||||
maybe_top_p_arr: Optional[torch.Tensor],
|
||||
top_p_val: float,
|
||||
deterministic: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
with probs.device as device:
|
||||
probs = probs.float()
|
||||
uniform_samples = uniform_samples.float()
|
||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
||||
maybe_top_p_arr = (
|
||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
||||
torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
success,
|
||||
maybe_top_k_arr,
|
||||
top_k_val,
|
||||
maybe_top_p_arr,
|
||||
top_p_val,
|
||||
deterministic,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
return samples, success
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
top_k: Union[torch.Tensor, int],
|
||||
top_p: Union[torch.Tensor, float],
|
||||
filter_apply_order: str = "top_k_first",
|
||||
deterministic: bool = True,
|
||||
check_nan: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if filter_apply_order == "top_k_first":
|
||||
renorm_probs = top_k_renorm_probs(probs, top_k)
|
||||
return top_p_sampling_from_probs(
|
||||
renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
|
||||
)
|
||||
elif filter_apply_order == "joint":
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _top_k_top_p_sampling_from_probs_internal(
|
||||
probs,
|
||||
uniform_samples,
|
||||
*_to_tensor_scalar_tuple(top_k),
|
||||
*_to_tensor_scalar_tuple(top_p),
|
||||
deterministic,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
|
||||
|
||||
def _min_p_sampling_from_probs_internal(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
maybe_min_p_arr: Optional[torch.Tensor],
|
||||
min_p_val: float,
|
||||
deterministic: bool,
|
||||
) -> torch.Tensor:
|
||||
with probs.device as device:
|
||||
probs = probs.float()
|
||||
uniform_samples = uniform_samples.float()
|
||||
maybe_min_p_arr = (
|
||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
||||
)
|
||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
||||
torch.ops.sgl_kernel.min_p_sampling_from_probs(
|
||||
probs,
|
||||
uniform_samples,
|
||||
samples,
|
||||
maybe_min_p_arr,
|
||||
min_p_val,
|
||||
deterministic,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
return samples
|
||||
|
||||
|
||||
def min_p_sampling_from_probs(
|
||||
probs: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
min_p: Union[torch.Tensor, float],
|
||||
deterministic: bool = True,
|
||||
check_nan: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if uniform_samples.dim() == 2:
|
||||
# Take the first row (round) of uniform_samples
|
||||
uniform_samples = uniform_samples[0]
|
||||
|
||||
if check_nan:
|
||||
if torch.any(torch.isnan(probs)):
|
||||
raise ValueError("Input probs contains NaN.")
|
||||
return _min_p_sampling_from_probs_internal(
|
||||
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
|
||||
)
|
||||
83
sgl-kernel/python/sgl_kernel/speculative.py
Normal file
83
sgl-kernel/python/sgl_kernel/speculative.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
|
||||
predicts: torch.Tensor, # mutable
|
||||
accept_index: torch.Tensor, # mutable
|
||||
accept_token_num: torch.Tensor, # mutable
|
||||
candidates: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
uniform_samples: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
draft_probs: torch.Tensor,
|
||||
deterministic: bool = True,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
|
||||
predicts,
|
||||
accept_index,
|
||||
accept_token_num,
|
||||
candidates,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
uniform_samples,
|
||||
target_probs,
|
||||
draft_probs,
|
||||
deterministic,
|
||||
get_cuda_stream(),
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
|
||||
parent_list: torch.Tensor,
|
||||
selected_index: torch.Tensor,
|
||||
verified_seq_len: torch.Tensor,
|
||||
tree_mask: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
topk: int,
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.build_tree_kernel_efficient(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
retrive_next_token,
|
||||
retrive_next_sibling,
|
||||
topk,
|
||||
depth,
|
||||
draft_token_num,
|
||||
)
|
||||
|
||||
|
||||
def build_tree_kernel(
|
||||
parent_list: torch.Tensor,
|
||||
selected_index: torch.Tensor,
|
||||
verified_seq_len: torch.Tensor,
|
||||
tree_mask: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
topk: int,
|
||||
depth: int,
|
||||
draft_token_num: int,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.build_tree_kernel(
|
||||
parent_list,
|
||||
selected_index,
|
||||
verified_seq_len,
|
||||
tree_mask,
|
||||
positions,
|
||||
retrive_index,
|
||||
topk,
|
||||
depth,
|
||||
draft_token_num,
|
||||
)
|
||||
41
sgl-kernel/python/sgl_kernel/utils.py
Normal file
41
sgl-kernel/python/sgl_kernel/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_cuda_stream() -> int:
|
||||
return torch.cuda.current_stream().cuda_stream
|
||||
|
||||
|
||||
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
|
||||
|
||||
|
||||
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
|
||||
key = (name, device)
|
||||
buf = _cache_buf.get(key)
|
||||
if buf is None:
|
||||
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
|
||||
_cache_buf[key] = buf
|
||||
return buf
|
||||
|
||||
|
||||
def _to_tensor_scalar_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x, 0)
|
||||
else:
|
||||
return (None, x)
|
||||
1
sgl-kernel/python/sgl_kernel/version.py
Normal file
1
sgl-kernel/python/sgl_kernel/version.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.0.3.post7"
|
||||
Reference in New Issue
Block a user