Reorganize python source files in sgl-kernel with multiple files (#4027)
This commit is contained in:
@@ -9,105 +9,37 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
||||
mode=ctypes.RTLD_GLOBAL,
|
||||
)
|
||||
|
||||
from sgl_kernel.ops.activation import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.ops.allreduce import *
|
||||
from sgl_kernel.ops.attention import lightning_attention_decode
|
||||
from sgl_kernel.ops.gemm import (
|
||||
bmm_fp8,
|
||||
cublas_grouped_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
int8_scaled_mm,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
)
|
||||
from sgl_kernel.ops.moe import moe_align_block_size
|
||||
from sgl_kernel.ops.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.ops.speculative import (
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
tree_speculative_sampling_target_only,
|
||||
)
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
if torch.version.cuda:
|
||||
from sgl_kernel.ops import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
bmm_fp8,
|
||||
build_tree_kernel,
|
||||
build_tree_kernel_efficient,
|
||||
cublas_grouped_gemm,
|
||||
custom_dispose,
|
||||
custom_reduce,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_reduce,
|
||||
int8_scaled_mm,
|
||||
lightning_attention_decode,
|
||||
min_p_sampling_from_probs,
|
||||
moe_align_block_size,
|
||||
register_graph_buffers,
|
||||
rmsnorm,
|
||||
sampling_scaling_penalties,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
silu_and_mul,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
tree_speculative_sampling_target_only,
|
||||
)
|
||||
|
||||
else:
|
||||
assert torch.version.hip
|
||||
|
||||
from sgl_kernel.ops import (
|
||||
all_reduce_reg,
|
||||
all_reduce_unreg,
|
||||
allocate_meta_buffer,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
bmm_fp8,
|
||||
dispose,
|
||||
fp8_scaled_mm,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
get_graph_buffer_ipc_meta,
|
||||
get_meta_buffer_ipc_handle,
|
||||
init_custom_ar,
|
||||
int8_scaled_mm,
|
||||
lightning_attention_decode,
|
||||
meta_size,
|
||||
min_p_sampling_from_probs,
|
||||
moe_align_block_size,
|
||||
register_buffer,
|
||||
register_graph_buffers,
|
||||
rmsnorm,
|
||||
sampling_scaling_penalties,
|
||||
silu_and_mul,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
)
|
||||
|
||||
|
||||
# Public API of sgl_kernel. Every name listed here is imported above from the
# ops submodules (plus __version__), so `from sgl_kernel import *` re-exports
# exactly this set.
__all__ = [
    "__version__",
    "apply_rope_with_cos_sin_cache_inplace",
    "bmm_fp8",
    "cublas_grouped_gemm",
    "custom_dispose",
    "custom_reduce",
    "build_tree_kernel_efficient",
    "build_tree_kernel",
    "fp8_blockwise_scaled_mm",
    "fp8_scaled_mm",
    "fused_add_rmsnorm",
    "gelu_and_mul",
    "gelu_tanh_and_mul",
    "gemma_fused_add_rmsnorm",
    "gemma_rmsnorm",
    "get_graph_buffer_ipc_meta",
    "init_custom_reduce",
    "int8_scaled_mm",
    "lightning_attention_decode",
    "min_p_sampling_from_probs",
    "moe_align_block_size",
    "register_graph_buffers",
    "rmsnorm",
    "sampling_scaling_penalties",
    "sgl_per_token_group_quant_fp8",
    "silu_and_mul",
    "top_k_renorm_prob",
    "top_k_top_p_sampling_from_probs",
    "top_p_renorm_prob",
    # Imported from sgl_kernel.ops.sampling above but previously missing here,
    # which silently dropped it from `from sgl_kernel import *`.
    "top_p_sampling_from_probs",
    "tree_speculative_sampling_target_only",
]
|
||||
|
||||
@@ -1,677 +0,0 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import (
|
||||
_get_cache_buf,
|
||||
_get_cuda_stream,
|
||||
_to_tensor_scalar_tuple,
|
||||
)
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is inplace applied to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    # Entering the tensor's device as a context selects it for the launch.
    with query.device as device:
        # The kernel expects int32 position ids.
        positions = positions.int()
        torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
            q=query.view(query.shape[0], -1, head_size),
            k=key.view(key.shape[0], -1, head_size),
            # q_rope/k_rope alias the same views as q/k, so the rotation
            # is written back in place.
            q_rope=query.view(query.shape[0], -1, head_size),
            k_rope=key.view(key.shape[0], -1, head_size),
            cos_sin_cache=cos_sin_cache,
            pos_ids=positions,
            interleave=(not is_neox),
            cuda_stream=_get_cuda_stream(device),
        )
|
||||
|
||||
|
||||
# Custom all-reduce bindings. ROCm builds expose the register/unregister-style
# API; CUDA builds fall through to the trt_reduce-style API below. Both sets
# are thin delegations to the `sgl_kernels` extension ops.
if torch.version.hip is not None:

    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        """Create a custom all-reduce context; returns an opaque int handle."""
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        """All-reduce `inp` into `out` using already-registered buffers."""
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        """All-reduce `inp` into `out`, staging through `reg_buffer`."""
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        """Destroy the all-reduce context `fa`."""
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        """Return the required meta-buffer size in bytes."""
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        """Register tensor `t` (with per-rank handles/offsets) on context `fa`."""
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        """Return IPC metadata for graph-captured buffers of context `fa`."""
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        """Register graph-captured buffers gathered from all ranks."""
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        """Allocate a meta buffer of `size` bytes for the custom all-reduce."""
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        """Return the IPC handle of a previously allocated meta buffer."""
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # trt_reduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        """Create a trt-reduce all-reduce context; returns an opaque handle."""
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        """Destroy the custom all-reduce context `fa`."""
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        """All-reduce `inp` into `out` on context `fa`."""
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        """Return IPC metadata for graph-captured buffers of context `fa`."""
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        """Register graph-captured buffers gathered from all ranks."""
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    """Thin wrapper over the `sgl_kernels.moe_align_block_size` extension op.

    All outputs are written into the caller-allocated tensors
    (`sorted_token_ids`, `experts_ids`, `num_tokens_post_pad`) and the
    scratch buffers (`token_cnts_buffer`, `cumsum_buffer`); nothing is
    returned.
    """
    torch.ops.sgl_kernels.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )
|
||||
|
||||
|
||||
def sampling_scaling_penalties(logits, scaling_penalties):
    """Return `logits` adjusted by `scaling_penalties` via the extension op."""
    return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """INT8 matmul with per-operand scales and optional bias.

    Delegates directly to the `sgl_kernels.int8_scaled_mm` extension op and
    returns its result in `out_dtype`.
    """
    return torch.ops.sgl_kernels.int8_scaled_mm(
        mat_a, mat_b, scales_a, scales_b, out_dtype, bias
    )
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """FP8 matmul with block-wise scales; result is returned in `out_dtype`."""
    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )
|
||||
|
||||
|
||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """FP8 matmul with per-operand scales and optional bias, in `out_dtype`."""
    return torch.ops.sgl_kernels.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Single decode step of lightning attention.

    Results are written into the caller-allocated `output` and `new_kv`
    tensors; nothing is returned.
    """
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """RMS-normalize `input` with `weight`; writes into `out` (allocated
    like `input` when not supplied) and returns it."""
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
        return out
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Fused residual-add + RMSNorm; results are written back into the
    argument tensors (the op returns nothing)."""
    # NOTE(review): `device` is bound but unused — the context only selects
    # the current device. Unlike the other norm wrappers, no CUDA stream is
    # passed to this op; presumably it launches on the default stream.
    with input.device as device:
        torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
|
||||
|
||||
|
||||
def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Gemma-variant RMSNorm; writes into `out` (allocated like `input`
    when not supplied) and returns it."""
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        torch.ops.sgl_kernels.gemma_rmsnorm(
            out, input, weight, eps, _get_cuda_stream(device)
        )
        return out
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Gemma-variant fused residual-add + RMSNorm; results are written back
    into the argument tensors (the op returns nothing)."""
    with input.device as device:
        torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
            input, residual, weight, eps, _get_cuda_stream(device)
        )
|
||||
|
||||
|
||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
||||
assert (
|
||||
input.shape[:-1] == output.shape[:-1]
|
||||
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
||||
assert (
|
||||
input.shape[-1] == 2 * output.shape[-1]
|
||||
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
||||
|
||||
|
||||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Fused SiLU-and-multiply gated activation.

    The output keeps the leading dims of `input` and halves the last one.
    The last dimension must be 16-byte aligned for the kernel's vectorized
    loads; `out`, when given, is shape-checked and written in place.
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device))
    return out
|
||||
|
||||
|
||||
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Fused tanh-approximated-GELU-and-multiply gated activation.

    Same contract as `silu_and_mul`: output halves the last dimension,
    which must be 16-byte aligned; `out` is optional and written in place.
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
    return out
|
||||
|
||||
|
||||
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Fused GELU-and-multiply gated activation.

    Same contract as `silu_and_mul`: output halves the last dimension,
    which must be 16-byte aligned; `out` is optional and written in place.
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    with input.device as device:
        torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device))
    return out
|
||||
|
||||
|
||||
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Launch the FP8 batched-matmul kernel, writing the result into `D`."""
    with A.device as device:
        # The op drives cuBLAS directly and needs the current handle.
        cublas_handle = torch.cuda.current_blas_handle()
        torch.ops.sgl_kernels.bmm_fp8(
            A,
            B,
            D,
            A_scale,
            B_scale,
            workspace_buffer,
            cublas_handle,
            _get_cuda_stream(device),
        )
|
||||
|
||||
|
||||
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched FP8 matmul of `A` and `B` with per-operand scales.

    Returns the product in `dtype`, allocating `out` of shape
    (batch, A_rows, B_cols) when not supplied.
    """
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    # 32 MiB workspace fetched through the shared cache helper (keyed by name
    # and device, so repeated calls reuse the buffer).
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out
|
||||
|
||||
|
||||
def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Kernel driver for top-k renormalization.

    The (`maybe_top_k_arr`, `top_k_val`) pair comes from
    `_to_tensor_scalar_tuple`: either a per-row tensor of k values or a
    single scalar k. Returns a new renormalized probability tensor.
    """
    with probs.device as device:
        # Kernel works on float32 probs and int32 k values.
        probs = probs.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        renorm_probs = torch.empty_like(probs)
        torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
            probs,
            renorm_probs,
            maybe_top_k_arr,
            top_k_val,
            _get_cuda_stream(device),
        )
        return renorm_probs
|
||||
|
||||
|
||||
def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    """Keep the top-k entries of each row of `probs` and renormalize.

    `top_k` may be a scalar applied to all rows or a per-row tensor.
    """
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


# Alias kept for backward compatibility with the singular spelling.
top_k_renorm_prob = top_k_renorm_probs
|
||||
|
||||
|
||||
def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    """Kernel driver for top-p renormalization.

    The (`maybe_top_p_arr`, `top_p_val`) pair comes from
    `_to_tensor_scalar_tuple`. Returns a new renormalized probability tensor.
    """
    with probs.device as device:
        # Kernel works entirely in float32.
        probs = probs.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        renorm_probs = torch.empty_like(probs)
        torch.ops.sgl_kernels.top_p_renorm_probs(
            probs,
            renorm_probs,
            maybe_top_p_arr,
            top_p_val,
            _get_cuda_stream(device),
        )
        return renorm_probs
|
||||
|
||||
|
||||
def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Apply top-p (nucleus) filtering to each row of `probs` and renormalize.

    `top_p` may be a scalar applied to all rows or a per-row tensor.
    """
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


# Alias kept for backward compatibility with the singular spelling.
top_p_renorm_prob = top_p_renorm_probs
|
||||
|
||||
|
||||
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Kernel driver for top-p sampling.

    Returns `(samples, success)`: int32 sampled ids per row and a per-row
    bool flag filled by the kernel.
    """
    with probs.device as device:
        # Kernel works entirely in float32.
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success
|
||||
|
||||
|
||||
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids from `probs` with top-p (nucleus) filtering.

    Returns the `(samples, success)` pair from the kernel driver. Raises
    ValueError when `check_nan` is set and `probs` contains NaN.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
    )
|
||||
|
||||
|
||||
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Kernel driver for fused top-k + top-p sampling.

    Both (arr, val) pairs come from `_to_tensor_scalar_tuple`. Returns
    `(samples, success)`: int32 sampled ids per row and a per-row bool flag.
    """
    with probs.device as device:
        # Kernel works on float32 probs/thresholds and int32 k values.
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples, success
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids with combined top-k and top-p filtering.

    filter_apply_order:
      * "top_k_first" — renormalize by top-k, then top-p sample (two calls).
      * "joint" — a single fused kernel applies both filters.
    Returns `(samples, success)`; raises ValueError on an unknown order or
    (when `check_nan`) on NaN input.
    """
    if filter_apply_order == "top_k_first":
        renorm_probs = top_k_renorm_probs(probs, top_k)
        # NaN checking is delegated to top_p_sampling_from_probs here.
        return top_p_sampling_from_probs(
            renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    elif filter_apply_order == "joint":
        if check_nan:
            if torch.any(torch.isnan(probs)):
                raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
|
||||
|
||||
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    """Kernel driver for min-p sampling; returns int32 sampled ids per row.

    Unlike the top-p/top-k drivers, no `success` flag is produced.
    """
    with probs.device as device:
        # Kernel works entirely in float32.
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernels.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            _get_cuda_stream(device),
        )
        return samples
|
||||
|
||||
|
||||
def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    """Sample token ids from `probs` with min-p filtering.

    Returns the int32 samples tensor; raises ValueError when `check_nan`
    is set and `probs` contains NaN.
    """
    if uniform_samples.dim() == 2:
        # Take the first row (round) of uniform_samples
        uniform_samples = uniform_samples[0]

    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Verify a speculative-decoding token tree against the target model.

    The first three tensors are outputs filled in place by the kernel
    (`predicts`, `accept_index`, `accept_token_num`); the remaining tensors
    describe the candidate tree and the draft/target distributions.
    """
    with predicts.device as device:
        torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
            predicts,
            accept_index,
            accept_token_num,
            candidates,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            uniform_samples,
            target_probs,
            draft_probs,
            deterministic,
            _get_cuda_stream(device),
        )
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the speculative-decoding token tree (efficient variant).

    Output tensors (`tree_mask`, `positions`, `retrive_*`) are filled in
    place by the kernel.
    """
    # NOTE(review): `device` is unused; the context only selects the device
    # for the launch (no explicit stream argument for this op).
    with parent_list.device as device:
        torch.ops.sgl_kernels.build_tree_kernel_efficient(
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            topk,
            depth,
            draft_token_num,
        )
|
||||
|
||||
|
||||
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the speculative-decoding token tree (original variant).

    Output tensors (`tree_mask`, `positions`, `retrive_index`) are filled
    in place by the kernel.
    """
    # NOTE(review): `device` is unused; the context only selects the device.
    with parent_list.device as device:
        torch.ops.sgl_kernels.build_tree_kernel(
            parent_list,
            selected_index,
            verified_seq_len,
            tree_mask,
            positions,
            retrive_index,
            topk,
            depth,
            draft_token_num,
        )
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    """Per-token-group FP8 quantization.

    Quantized values are written into `output_q` and the per-group scales
    into `output_s`; `eps`/`fp8_min`/`fp8_max` bound the scale computation
    and clamping inside the kernel.
    """
    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )
|
||||
|
||||
|
||||
def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    """Run a grouped GEMM over parallel lists of inputs/weights via cuBLAS.

    Results are written into the caller-allocated `outputs` tensors.
    Raises AssertionError if any of the three lists is empty.
    """
    # Validate BEFORE touching inputs[0]: the original evaluated
    # `inputs[0].device` first, so an empty `inputs` list raised IndexError
    # and this assertion message was unreachable.
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    with inputs[0].device as device:
        # The op drives cuBLAS directly and needs the current handle.
        cublas_handle = torch.cuda.current_blas_handle()
        torch.ops.sgl_kernels.cublas_grouped_gemm(
            inputs,
            weights,
            outputs,
            out_dtype,
            cublas_handle,
            _get_cuda_stream(device),
        )
|
||||
153
sgl-kernel/src/sgl-kernel/ops/activation.py
Normal file
153
sgl-kernel/src/sgl-kernel/ops/activation.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from typing import Optional
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import get_cuda_stream
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """RMS-normalize `input` with `weight`; writes into `out` (allocated
    like `input` when not supplied) and returns it."""
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Fused residual-add + RMSNorm; results are written back into the
    argument tensors (the op returns nothing)."""
    # NOTE(review): no stream argument here, unlike the other norm wrappers.
    torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Gemma-variant RMSNorm; writes into `out` (allocated like `input`
    when not supplied) and returns it."""
    if out is None:
        out = torch.empty_like(input)
    torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
    return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Gemma-variant fused residual-add + RMSNorm; results are written back
    into the argument tensors (the op returns nothing)."""
    torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
        input, residual, weight, eps, get_cuda_stream()
    )
|
||||
|
||||
|
||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
||||
assert (
|
||||
input.shape[:-1] == output.shape[:-1]
|
||||
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
||||
assert (
|
||||
input.shape[-1] == 2 * output.shape[-1]
|
||||
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
||||
|
||||
|
||||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Fused SiLU-and-multiply gated activation.

    Output keeps the leading dims of `input` and halves the last one, which
    must be 16-byte aligned; `out` is optional and written in place.
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Fused tanh-approximated-GELU-and-multiply gated activation.

    Same contract as `silu_and_mul`.
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
    return out


def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Fused GELU-and-multiply gated activation.

    Same contract as `silu_and_mul`.
    """
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
    return out
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is inplace applied to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    # The kernel expects int32 position ids.
    positions = positions.int()
    torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        # q_rope/k_rope alias the same views as q/k, so the rotation is
        # written back in place.
        q_rope=query.view(query.shape[0], -1, head_size),
        k_rope=key.view(key.shape[0], -1, head_size),
        cos_sin_cache=cos_sin_cache,
        pos_ids=positions,
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )
|
||||
78
sgl-kernel/src/sgl-kernel/ops/allreduce.py
Normal file
78
sgl-kernel/src/sgl-kernel/ops/allreduce.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
|
||||
if torch.version.hip is not None:
    # ROCM custom allreduce.
    # Thin Python wrappers over the compiled `sgl_kernels` ops. `fa` is the
    # opaque integer handle returned by `init_custom_ar`.
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        """Create a custom-allreduce context and return its opaque handle."""
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        """All-reduce `inp` into `out` (registered-buffer variant)."""
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        """All-reduce `inp` into `out` using `reg_buffer` as staging space."""
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        """Destroy the context identified by `fa`."""
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        """Return the required meta-buffer size (presumably in bytes)."""
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        """Register tensor `t` (with peer IPC handles/offsets) with the context."""
        # NOTE(review): annotated -> None but propagates the op's return value;
        # harmless if the op returns None — confirm against the kernel binding.
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        """Return IPC metadata for graph-captured buffers."""
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        """Register graph-capture buffers gathered from all ranks."""
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        """Allocate a meta buffer of `size` bytes suitable for IPC sharing."""
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        """Return the IPC handle for a previously allocated meta buffer."""
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # TRTLLM custom allreduce (CUDA path). Same op library, different
    # public names than the ROCm branch above.
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        """Create a custom-allreduce context and return its opaque handle."""
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        """Destroy the context identified by `fa`."""
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        """All-reduce `inp` into `out` using context `fa`."""
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        """Return IPC metadata for graph-captured buffers."""
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        """Register graph-capture buffers gathered from all ranks."""
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||
8
sgl-kernel/src/sgl-kernel/ops/attention.py
Normal file
8
sgl-kernel/src/sgl-kernel/ops/attention.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Dispatch one decode step of lightning attention to the compiled kernel.

    `output` and `new_kv` appear to be written in place by the kernel (the
    wrapper returns nothing) — confirm against the kernel binding.
    """
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )
|
||||
111
sgl-kernel/src/sgl-kernel/ops/gemm.py
Normal file
111
sgl-kernel/src/sgl-kernel/ops/gemm.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """INT8 scaled matmul via the compiled `sgl_kernels.int8_scaled_mm` op.

    `scales_a`/`scales_b` dequantize the int8 operands; the result is
    produced in `out_dtype`. `bias` is optional and forwarded as-is.
    """
    return torch.ops.sgl_kernels.int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """FP8 matmul with blockwise scales via the compiled kernel.

    Unlike `fp8_scaled_mm`, scales here are per-block (layout defined by
    the kernel) and no bias is supported.
    """
    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )
|
||||
|
||||
|
||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """FP8 scaled matmul via the compiled `sgl_kernels.fp8_scaled_mm` op.

    `scales_a`/`scales_b` dequantize the fp8 operands; the result is
    produced in `out_dtype`. `bias` is optional and forwarded as-is.
    """
    return torch.ops.sgl_kernels.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||
|
||||
|
||||
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Run the FP8 batched-matmul kernel, writing the result into `D`.

    `workspace_buffer` is scratch space provided by the caller (see
    `bmm_fp8`). The kernel executes on the current cuBLAS handle and the
    current CUDA stream.
    """
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.bmm_fp8(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )
|
||||
|
||||
|
||||
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """FP8 batched matmul of `A` and `B` with per-tensor scales.

    If `out` is not supplied, a `(batch, M, N)` tensor of `dtype` is
    allocated on `A`'s device. Returns the output tensor (either the
    caller's `out` or the fresh allocation).
    """
    result = out
    if result is None:
        batch, rows = A.shape[0], A.shape[1]
        cols = B.shape[2]
        result = torch.empty((batch, rows, cols), device=A.device, dtype=dtype)
    # Reuse one cached 32 MiB workspace per device across calls.
    workspace = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace, A, B, result, A_scale, B_scale)
    return result
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    """Per-token-group FP8 quantization via the compiled kernel.

    `output_q` (quantized values) and `output_s` (group scales) are filled
    in place. `eps` guards against division by zero and `fp8_min`/`fp8_max`
    are the clamp bounds of the target fp8 format.
    """
    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )
|
||||
|
||||
|
||||
def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    """Grouped GEMM over parallel lists of tensors via cuBLAS.

    Each `outputs[i]` is written in place from `inputs[i] @ weights[i]`
    (result dtype `out_dtype`); the kernel runs on the current cuBLAS
    handle and CUDA stream.

    Raises:
        ValueError: if any of the three tensor lists is empty.
    """
    # Validate with an explicit raise rather than `assert`, which is
    # silently stripped under `python -O`.
    if not (inputs and weights and outputs):
        raise ValueError("Inputs/weights/outputs should not be empty!")
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.cublas_grouped_gemm(
        inputs,
        weights,
        outputs,
        out_dtype,
        cublas_handle,
        get_cuda_stream(),
    )
|
||||
24
sgl-kernel/src/sgl-kernel/ops/moe.py
Normal file
24
sgl-kernel/src/sgl-kernel/ops/moe.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
|
||||
|
||||
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    """Dispatch to the compiled `moe_align_block_size` kernel.

    Groups tokens by expert and pads each group to a multiple of
    `block_size`. The output/workspace tensors (`sorted_token_ids`,
    `experts_ids`, `num_tokens_post_pad`, `token_cnts_buffer`,
    `cumsum_buffer`) appear to be filled in place — the wrapper returns
    nothing; confirm exact layouts against the kernel.
    """
    torch.ops.sgl_kernels.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )
|
||||
211
sgl-kernel/src/sgl-kernel/ops/sampling.py
Normal file
211
sgl-kernel/src/sgl-kernel/ops/sampling.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
||||
|
||||
|
||||
def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Dispatch top-k renormalization to the compiled kernel.

    Exactly one of `maybe_top_k_arr` (per-row k values) / `top_k_val`
    (shared scalar k) is meaningful; callers build the pair with
    `_to_tensor_scalar_tuple`. Returns a new renormalized float tensor.
    """
    # Kernel expects float probs and int32 per-row k values.
    probs = probs.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
        probs,
        renorm_probs,
        maybe_top_k_arr,
        top_k_val,
        get_cuda_stream(),
    )
    return renorm_probs
|
||||
|
||||
|
||||
def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    """Zero out all but the top-k entries per row and renormalize.

    `top_k` may be a shared Python int or a per-row tensor of k values.
    """
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


# Backward-compatible alias (singular "prob" spelling).
top_k_renorm_prob = top_k_renorm_probs
|
||||
|
||||
|
||||
def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    """Dispatch top-p renormalization to the compiled kernel.

    Exactly one of `maybe_top_p_arr` (per-row thresholds) / `top_p_val`
    (shared scalar threshold) is meaningful; callers build the pair with
    `_to_tensor_scalar_tuple`. Returns a new renormalized float tensor.
    """
    # Kernel expects float32 inputs throughout.
    probs = probs.float()
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_p_renorm_probs(
        probs,
        renorm_probs,
        maybe_top_p_arr,
        top_p_val,
        get_cuda_stream(),
    )
    return renorm_probs
|
||||
|
||||
|
||||
def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Restrict each row to its top-p nucleus and renormalize.

    `top_p` may be a shared Python float or a per-row tensor of thresholds.
    """
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


# Backward-compatible alias (singular "prob" spelling).
top_p_renorm_prob = top_p_renorm_probs
|
||||
|
||||
|
||||
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch top-p sampling to the compiled kernel.

    Exactly one of `maybe_top_p_arr` / `top_p_val` is meaningful (built via
    `_to_tensor_scalar_tuple`). Returns `(samples, success)`: one int32
    sample and one bool flag per row, as filled by the kernel.
    """
    # Make probs' device current so the allocations below land on it.
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success
|
||||
|
||||
|
||||
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample from each row of `probs` restricted to its top-p nucleus.

    `top_p` may be a shared float or a per-row tensor. With
    `check_nan=True`, raises ValueError if `probs` contains NaN.
    Returns `(samples, success)` from the underlying kernel.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    threshold_args = _to_tensor_scalar_tuple(top_p)
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *threshold_args, deterministic
    )
|
||||
|
||||
|
||||
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch joint top-k + top-p sampling to the compiled kernel.

    For each of k and p, exactly one of the array / scalar arguments is
    meaningful (built via `_to_tensor_scalar_tuple`). Returns
    `(samples, success)`: one int32 sample and one bool flag per row.
    """
    # Make probs' device current so the allocations below land on it.
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample with combined top-k and top-p filtering.

    `filter_apply_order` selects the strategy:
      * ``"top_k_first"`` — renormalize by top-k, then top-p sample.
      * ``"joint"`` — apply both filters inside a single fused kernel.
    Raises ValueError for any other value, or (with `check_nan=True` on
    the joint path) when `probs` contains NaN.
    Returns `(samples, success)`.
    """
    if filter_apply_order == "top_k_first":
        # Two-stage path: the NaN check is delegated to the top-p stage.
        filtered = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            filtered, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    if filter_apply_order == "joint":
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
|
||||
|
||||
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    """Dispatch min-p sampling to the compiled kernel.

    Exactly one of `maybe_min_p_arr` / `min_p_val` is meaningful (built via
    `_to_tensor_scalar_tuple`). Returns one int32 sample per row. Unlike
    the top-p/top-k paths, no per-row success flag is produced.
    """
    # Make probs' device current so the allocation below lands on it.
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernels.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples
|
||||
|
||||
|
||||
def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    """Sample from each row of `probs` after min-p filtering.

    `min_p` may be a shared float or a per-row tensor. With
    `check_nan=True`, raises ValueError if `probs` contains NaN.
    Returns one int32 sample per row.
    """
    # A 2-D uniform_samples tensor carries one row per sampling round;
    # min-p uses a single round, so keep only the first row.
    if uniform_samples.dim() == 2:
        uniform_samples = uniform_samples[0]

    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )
|
||||
84
sgl-kernel/src/sgl-kernel/ops/speculative.py
Normal file
84
sgl-kernel/src/sgl-kernel/ops/speculative.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import sgl_kernel.ops._kernels
|
||||
import torch
|
||||
from sgl_kernel.ops.utils import get_cuda_stream
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Verify a speculative-decoding tree against target-model probabilities.

    The first three tensors are written in place by the kernel (marked
    "mutable"); the rest are read-only inputs describing the candidate
    tree. Note the "retrive" spelling is established API and must not be
    changed. Runs on the current CUDA stream; returns nothing.
    """
    torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        target_probs,
        draft_probs,
        deterministic,
        get_cuda_stream(),
    )
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build speculative-decoding tree metadata (efficient variant).

    Compared to `build_tree_kernel`, this variant additionally fills the
    `retrive_next_token` / `retrive_next_sibling` linkage tensors. Output
    tensors appear to be written in place by the kernel — the wrapper
    returns nothing; confirm against the kernel binding.
    """
    torch.ops.sgl_kernels.build_tree_kernel_efficient(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
    )
|
||||
|
||||
|
||||
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build speculative-decoding tree metadata (basic variant).

    Same as `build_tree_kernel_efficient` but without the next-token /
    next-sibling linkage tensors. Output tensors appear to be written in
    place by the kernel — the wrapper returns nothing.
    """
    torch.ops.sgl_kernels.build_tree_kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token_num,
    )
|
||||
@@ -18,8 +18,8 @@ from typing import Dict, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
def _get_cuda_stream(device: torch.device) -> int:
    """Return the raw CUDA stream pointer of `device`'s current stream."""
    return torch.cuda.current_stream(device).cuda_stream
|
||||
def get_cuda_stream() -> int:
    """Return the raw CUDA stream pointer of the current device's stream."""
    return torch.cuda.current_stream().cuda_stream
|
||||
|
||||
|
||||
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
|
||||
|
||||
@@ -7,9 +7,9 @@ import unittest
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import ray
|
||||
import sgl_kernel.ops.allreduce as custom_ops
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from sgl_kernel import ops as custom_ops
|
||||
from torch.distributed import ProcessGroup
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
|
||||
|
||||
Reference in New Issue
Block a user