Reorganize python source files in sgl-kernel with multiple files (#4027)
This commit is contained in:
@@ -9,105 +9,37 @@ if os.path.exists("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12"):
|
|||||||
mode=ctypes.RTLD_GLOBAL,
|
mode=ctypes.RTLD_GLOBAL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sgl_kernel.ops.activation import (
|
||||||
|
apply_rope_with_cos_sin_cache_inplace,
|
||||||
|
fused_add_rmsnorm,
|
||||||
|
gelu_and_mul,
|
||||||
|
gelu_tanh_and_mul,
|
||||||
|
gemma_fused_add_rmsnorm,
|
||||||
|
gemma_rmsnorm,
|
||||||
|
rmsnorm,
|
||||||
|
silu_and_mul,
|
||||||
|
)
|
||||||
|
from sgl_kernel.ops.allreduce import *
|
||||||
|
from sgl_kernel.ops.attention import lightning_attention_decode
|
||||||
|
from sgl_kernel.ops.gemm import (
|
||||||
|
bmm_fp8,
|
||||||
|
cublas_grouped_gemm,
|
||||||
|
fp8_blockwise_scaled_mm,
|
||||||
|
fp8_scaled_mm,
|
||||||
|
int8_scaled_mm,
|
||||||
|
sgl_per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
|
from sgl_kernel.ops.moe import moe_align_block_size
|
||||||
|
from sgl_kernel.ops.sampling import (
|
||||||
|
min_p_sampling_from_probs,
|
||||||
|
top_k_renorm_prob,
|
||||||
|
top_k_top_p_sampling_from_probs,
|
||||||
|
top_p_renorm_prob,
|
||||||
|
top_p_sampling_from_probs,
|
||||||
|
)
|
||||||
|
from sgl_kernel.ops.speculative import (
|
||||||
|
build_tree_kernel,
|
||||||
|
build_tree_kernel_efficient,
|
||||||
|
tree_speculative_sampling_target_only,
|
||||||
|
)
|
||||||
from sgl_kernel.version import __version__
|
from sgl_kernel.version import __version__
|
||||||
|
|
||||||
if torch.version.cuda:
|
|
||||||
from sgl_kernel.ops import (
|
|
||||||
apply_rope_with_cos_sin_cache_inplace,
|
|
||||||
bmm_fp8,
|
|
||||||
build_tree_kernel,
|
|
||||||
build_tree_kernel_efficient,
|
|
||||||
cublas_grouped_gemm,
|
|
||||||
custom_dispose,
|
|
||||||
custom_reduce,
|
|
||||||
fp8_blockwise_scaled_mm,
|
|
||||||
fp8_scaled_mm,
|
|
||||||
fused_add_rmsnorm,
|
|
||||||
gelu_and_mul,
|
|
||||||
gelu_tanh_and_mul,
|
|
||||||
gemma_fused_add_rmsnorm,
|
|
||||||
gemma_rmsnorm,
|
|
||||||
get_graph_buffer_ipc_meta,
|
|
||||||
init_custom_reduce,
|
|
||||||
int8_scaled_mm,
|
|
||||||
lightning_attention_decode,
|
|
||||||
min_p_sampling_from_probs,
|
|
||||||
moe_align_block_size,
|
|
||||||
register_graph_buffers,
|
|
||||||
rmsnorm,
|
|
||||||
sampling_scaling_penalties,
|
|
||||||
sgl_per_token_group_quant_fp8,
|
|
||||||
silu_and_mul,
|
|
||||||
top_k_renorm_prob,
|
|
||||||
top_k_top_p_sampling_from_probs,
|
|
||||||
top_p_renorm_prob,
|
|
||||||
tree_speculative_sampling_target_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert torch.version.hip
|
|
||||||
|
|
||||||
from sgl_kernel.ops import (
|
|
||||||
all_reduce_reg,
|
|
||||||
all_reduce_unreg,
|
|
||||||
allocate_meta_buffer,
|
|
||||||
apply_rope_with_cos_sin_cache_inplace,
|
|
||||||
bmm_fp8,
|
|
||||||
dispose,
|
|
||||||
fp8_scaled_mm,
|
|
||||||
fused_add_rmsnorm,
|
|
||||||
gelu_and_mul,
|
|
||||||
gelu_tanh_and_mul,
|
|
||||||
gemma_fused_add_rmsnorm,
|
|
||||||
gemma_rmsnorm,
|
|
||||||
get_graph_buffer_ipc_meta,
|
|
||||||
get_meta_buffer_ipc_handle,
|
|
||||||
init_custom_ar,
|
|
||||||
int8_scaled_mm,
|
|
||||||
lightning_attention_decode,
|
|
||||||
meta_size,
|
|
||||||
min_p_sampling_from_probs,
|
|
||||||
moe_align_block_size,
|
|
||||||
register_buffer,
|
|
||||||
register_graph_buffers,
|
|
||||||
rmsnorm,
|
|
||||||
sampling_scaling_penalties,
|
|
||||||
silu_and_mul,
|
|
||||||
top_k_renorm_prob,
|
|
||||||
top_k_top_p_sampling_from_probs,
|
|
||||||
top_p_renorm_prob,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__version__",
|
|
||||||
"apply_rope_with_cos_sin_cache_inplace",
|
|
||||||
"bmm_fp8",
|
|
||||||
"cublas_grouped_gemm",
|
|
||||||
"custom_dispose",
|
|
||||||
"custom_reduce",
|
|
||||||
"build_tree_kernel_efficient",
|
|
||||||
"build_tree_kernel",
|
|
||||||
"fp8_blockwise_scaled_mm",
|
|
||||||
"fp8_scaled_mm",
|
|
||||||
"fused_add_rmsnorm",
|
|
||||||
"gelu_and_mul",
|
|
||||||
"gelu_tanh_and_mul",
|
|
||||||
"gemma_fused_add_rmsnorm",
|
|
||||||
"gemma_rmsnorm",
|
|
||||||
"get_graph_buffer_ipc_meta",
|
|
||||||
"init_custom_reduce",
|
|
||||||
"int8_scaled_mm",
|
|
||||||
"lightning_attention_decode",
|
|
||||||
"min_p_sampling_from_probs",
|
|
||||||
"moe_align_block_size",
|
|
||||||
"register_graph_buffers",
|
|
||||||
"rmsnorm",
|
|
||||||
"sampling_scaling_penalties",
|
|
||||||
"sgl_per_token_group_quant_fp8",
|
|
||||||
"silu_and_mul",
|
|
||||||
"top_k_renorm_prob",
|
|
||||||
"top_k_top_p_sampling_from_probs",
|
|
||||||
"top_p_renorm_prob",
|
|
||||||
"tree_speculative_sampling_target_only",
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,677 +0,0 @@
|
|||||||
import os
|
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import sgl_kernel.ops._kernels
|
|
||||||
import torch
|
|
||||||
from sgl_kernel.ops.utils import (
|
|
||||||
_get_cache_buf,
|
|
||||||
_get_cuda_stream,
|
|
||||||
_to_tensor_scalar_tuple,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rope_with_cos_sin_cache_inplace(
|
|
||||||
positions: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
head_size: int,
|
|
||||||
cos_sin_cache: torch.Tensor,
|
|
||||||
is_neox: bool = True,
|
|
||||||
) -> None:
|
|
||||||
r"""
|
|
||||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
|
||||||
This is designed to be compatible with the SGL/vLLM implementation.
|
|
||||||
The result is inplace applied to the input tensors.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
positions : torch.Tensor
|
|
||||||
Position indices, shape: ``(nnz)``.
|
|
||||||
query : torch.Tensor
|
|
||||||
Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
|
|
||||||
key : torch.Tensor
|
|
||||||
Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
|
|
||||||
cos_sin_cache : torch.Tensor
|
|
||||||
Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
|
|
||||||
Cosine is the first half and Sine is the second half on rotary_dim.
|
|
||||||
is_neox : bool
|
|
||||||
Whether to use Neox style RoPE, default: ``True``.
|
|
||||||
|
|
||||||
* If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
|
|
||||||
we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half
|
|
||||||
dimensions ``([..., head_dim//2:])``.
|
|
||||||
|
|
||||||
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
|
|
||||||
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
|
|
||||||
Note
|
|
||||||
----
|
|
||||||
The rotary dimension is determined by the cosine cache and sine cache.
|
|
||||||
"""
|
|
||||||
if cos_sin_cache.dtype != torch.float32:
|
|
||||||
raise ValueError("cos_sin_cache should be float32")
|
|
||||||
|
|
||||||
with query.device as device:
|
|
||||||
positions = positions.int()
|
|
||||||
torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
|
|
||||||
q=query.view(query.shape[0], -1, head_size),
|
|
||||||
k=key.view(key.shape[0], -1, head_size),
|
|
||||||
q_rope=query.view(query.shape[0], -1, head_size),
|
|
||||||
k_rope=key.view(key.shape[0], -1, head_size),
|
|
||||||
cos_sin_cache=cos_sin_cache,
|
|
||||||
pos_ids=positions,
|
|
||||||
interleave=(not is_neox),
|
|
||||||
cuda_stream=_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if torch.version.hip is not None:
|
|
||||||
|
|
||||||
def init_custom_ar(
|
|
||||||
meta: torch.Tensor,
|
|
||||||
rank_data: torch.Tensor,
|
|
||||||
handles: List[str],
|
|
||||||
offsets: List[int],
|
|
||||||
rank: int,
|
|
||||||
full_nvlink: bool,
|
|
||||||
) -> int:
|
|
||||||
return torch.ops.sgl_kernels.init_custom_ar(
|
|
||||||
meta, rank_data, handles, offsets, rank, full_nvlink
|
|
||||||
)
|
|
||||||
|
|
||||||
def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
|
|
||||||
torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)
|
|
||||||
|
|
||||||
def all_reduce_unreg(
|
|
||||||
fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
|
|
||||||
) -> None:
|
|
||||||
torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)
|
|
||||||
|
|
||||||
def dispose(fa: int) -> None:
|
|
||||||
torch.ops.sgl_kernels.dispose(fa)
|
|
||||||
|
|
||||||
def meta_size() -> int:
|
|
||||||
return torch.ops.sgl_kernels.meta_size()
|
|
||||||
|
|
||||||
def register_buffer(
|
|
||||||
fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
|
|
||||||
) -> None:
|
|
||||||
return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)
|
|
||||||
|
|
||||||
def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
|
|
||||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
|
||||||
|
|
||||||
def register_graph_buffers(
|
|
||||||
fa: int, handles: List[str], offsets: List[List[int]]
|
|
||||||
) -> None:
|
|
||||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
|
||||||
|
|
||||||
def allocate_meta_buffer(size: int) -> torch.Tensor:
|
|
||||||
return torch.ops.sgl_kernels.allocate_meta_buffer(size)
|
|
||||||
|
|
||||||
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
|
||||||
return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# trt_reduce
|
|
||||||
def init_custom_reduce(
|
|
||||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
|
||||||
):
|
|
||||||
return torch.ops.sgl_kernels.init_custom_ar(
|
|
||||||
rank_id,
|
|
||||||
num_devices,
|
|
||||||
rank_data,
|
|
||||||
buffers,
|
|
||||||
tmp_buffers,
|
|
||||||
barrier_in,
|
|
||||||
barrier_out,
|
|
||||||
)
|
|
||||||
|
|
||||||
def custom_dispose(fa):
|
|
||||||
torch.ops.sgl_kernels.dispose(fa)
|
|
||||||
|
|
||||||
def custom_reduce(fa, inp, out):
|
|
||||||
torch.ops.sgl_kernels.all_reduce(fa, inp, out)
|
|
||||||
|
|
||||||
def get_graph_buffer_ipc_meta(fa):
|
|
||||||
return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)
|
|
||||||
|
|
||||||
def register_graph_buffers(fa, handles, offsets):
|
|
||||||
torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_token_ids,
|
|
||||||
experts_ids,
|
|
||||||
num_tokens_post_pad,
|
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
|
||||||
):
|
|
||||||
torch.ops.sgl_kernels.moe_align_block_size(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_token_ids,
|
|
||||||
experts_ids,
|
|
||||||
num_tokens_post_pad,
|
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sampling_scaling_penalties(logits, scaling_penalties):
|
|
||||||
return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties)
|
|
||||||
|
|
||||||
|
|
||||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
|
||||||
return torch.ops.sgl_kernels.int8_scaled_mm(
|
|
||||||
mat_a,
|
|
||||||
mat_b,
|
|
||||||
scales_a,
|
|
||||||
scales_b,
|
|
||||||
out_dtype,
|
|
||||||
bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
|
|
||||||
return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
|
|
||||||
mat_a,
|
|
||||||
mat_b,
|
|
||||||
scales_a,
|
|
||||||
scales_b,
|
|
||||||
out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
|
||||||
return torch.ops.sgl_kernels.fp8_scaled_mm(
|
|
||||||
mat_a,
|
|
||||||
mat_b,
|
|
||||||
scales_a,
|
|
||||||
scales_b,
|
|
||||||
out_dtype,
|
|
||||||
bias,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
|
||||||
torch.ops.sgl_kernels.lightning_attention_decode(
|
|
||||||
q, k, v, past_kv, slope, output, new_kv
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
|
||||||
# Kudos to @yzh119
|
|
||||||
def rmsnorm(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
out: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
with input.device as device:
|
|
||||||
if out is None:
|
|
||||||
out = torch.empty_like(input)
|
|
||||||
torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def fused_add_rmsnorm(
|
|
||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
|
||||||
) -> None:
|
|
||||||
with input.device as device:
|
|
||||||
torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
|
|
||||||
|
|
||||||
|
|
||||||
def gemma_rmsnorm(
|
|
||||||
input: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
out: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
with input.device as device:
|
|
||||||
if out is None:
|
|
||||||
out = torch.empty_like(input)
|
|
||||||
torch.ops.sgl_kernels.gemma_rmsnorm(
|
|
||||||
out, input, weight, eps, _get_cuda_stream(device)
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def gemma_fused_add_rmsnorm(
|
|
||||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
|
||||||
) -> None:
|
|
||||||
with input.device as device:
|
|
||||||
torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
|
|
||||||
input, residual, weight, eps, _get_cuda_stream(device)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
|
||||||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
|
||||||
assert (
|
|
||||||
input.shape[:-1] == output.shape[:-1]
|
|
||||||
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
|
||||||
assert (
|
|
||||||
input.shape[-1] == 2 * output.shape[-1]
|
|
||||||
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
|
||||||
|
|
||||||
|
|
||||||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|
||||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
|
||||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
|
||||||
if out is not None:
|
|
||||||
_check_shape(input, out)
|
|
||||||
else:
|
|
||||||
out = torch.empty(
|
|
||||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
|
||||||
device=input.device,
|
|
||||||
dtype=input.dtype,
|
|
||||||
)
|
|
||||||
with input.device as device:
|
|
||||||
torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device))
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|
||||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
|
||||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
|
||||||
if out is not None:
|
|
||||||
_check_shape(input, out)
|
|
||||||
else:
|
|
||||||
out = torch.empty(
|
|
||||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
|
||||||
device=input.device,
|
|
||||||
dtype=input.dtype,
|
|
||||||
)
|
|
||||||
with input.device as device:
|
|
||||||
torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
|
||||||
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
|
||||||
raise ValueError("The pointers must be multiple of 16 bytes.")
|
|
||||||
if out is not None:
|
|
||||||
_check_shape(input, out)
|
|
||||||
else:
|
|
||||||
out = torch.empty(
|
|
||||||
input.shape[:-1] + (input.shape[-1] // 2,),
|
|
||||||
device=input.device,
|
|
||||||
dtype=input.dtype,
|
|
||||||
)
|
|
||||||
with input.device as device:
|
|
||||||
torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device))
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _bmm_fp8_internal(
|
|
||||||
workspace_buffer: torch.Tensor,
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
D: torch.Tensor,
|
|
||||||
A_scale: torch.Tensor,
|
|
||||||
B_scale: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
with A.device as device:
|
|
||||||
cublas_handle = torch.cuda.current_blas_handle()
|
|
||||||
torch.ops.sgl_kernels.bmm_fp8(
|
|
||||||
A,
|
|
||||||
B,
|
|
||||||
D,
|
|
||||||
A_scale,
|
|
||||||
B_scale,
|
|
||||||
workspace_buffer,
|
|
||||||
cublas_handle,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def bmm_fp8(
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
A_scale: torch.Tensor,
|
|
||||||
B_scale: torch.Tensor,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
out: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if out is None:
|
|
||||||
out = torch.empty(
|
|
||||||
(A.shape[0], A.shape[1], B.shape[2]),
|
|
||||||
device=A.device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
|
|
||||||
_bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _top_k_renorm_probs_internal(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
maybe_top_k_arr: Optional[torch.Tensor],
|
|
||||||
top_k_val: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
with probs.device as device:
|
|
||||||
probs = probs.float()
|
|
||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
|
||||||
renorm_probs = torch.empty_like(probs)
|
|
||||||
torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
|
|
||||||
probs,
|
|
||||||
renorm_probs,
|
|
||||||
maybe_top_k_arr,
|
|
||||||
top_k_val,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
return renorm_probs
|
|
||||||
|
|
||||||
|
|
||||||
def top_k_renorm_probs(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
top_k: Union[torch.Tensor, int],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
|
|
||||||
|
|
||||||
|
|
||||||
top_k_renorm_prob = top_k_renorm_probs
|
|
||||||
|
|
||||||
|
|
||||||
def _top_p_renorm_probs_internal(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
maybe_top_p_arr: Optional[torch.Tensor],
|
|
||||||
top_p_val: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
with probs.device as device:
|
|
||||||
probs = probs.float()
|
|
||||||
maybe_top_p_arr = (
|
|
||||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
|
||||||
)
|
|
||||||
renorm_probs = torch.empty_like(probs)
|
|
||||||
torch.ops.sgl_kernels.top_p_renorm_probs(
|
|
||||||
probs,
|
|
||||||
renorm_probs,
|
|
||||||
maybe_top_p_arr,
|
|
||||||
top_p_val,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
return renorm_probs
|
|
||||||
|
|
||||||
|
|
||||||
def top_p_renorm_probs(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
top_p: Union[torch.Tensor, float],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
|
|
||||||
|
|
||||||
|
|
||||||
top_p_renorm_prob = top_p_renorm_probs
|
|
||||||
|
|
||||||
|
|
||||||
def _top_p_sampling_from_probs_internal(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
maybe_top_p_arr: Optional[torch.Tensor],
|
|
||||||
top_p_val: float,
|
|
||||||
deterministic: bool,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
with probs.device as device:
|
|
||||||
probs = probs.float()
|
|
||||||
uniform_samples = uniform_samples.float()
|
|
||||||
maybe_top_p_arr = (
|
|
||||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
|
||||||
)
|
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
|
||||||
torch.ops.sgl_kernels.top_p_sampling_from_probs(
|
|
||||||
probs,
|
|
||||||
uniform_samples,
|
|
||||||
samples,
|
|
||||||
success,
|
|
||||||
maybe_top_p_arr,
|
|
||||||
top_p_val,
|
|
||||||
deterministic,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
return samples, success
|
|
||||||
|
|
||||||
|
|
||||||
def top_p_sampling_from_probs(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
top_p: Union[torch.Tensor, float],
|
|
||||||
deterministic: bool = True,
|
|
||||||
check_nan: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
if check_nan:
|
|
||||||
if torch.any(torch.isnan(probs)):
|
|
||||||
raise ValueError("Input probs contains NaN.")
|
|
||||||
return _top_p_sampling_from_probs_internal(
|
|
||||||
probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _top_k_top_p_sampling_from_probs_internal(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
maybe_top_k_arr: Optional[torch.Tensor],
|
|
||||||
top_k_val: int,
|
|
||||||
maybe_top_p_arr: Optional[torch.Tensor],
|
|
||||||
top_p_val: float,
|
|
||||||
deterministic: bool,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
with probs.device as device:
|
|
||||||
probs = probs.float()
|
|
||||||
uniform_samples = uniform_samples.float()
|
|
||||||
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
|
|
||||||
maybe_top_p_arr = (
|
|
||||||
maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
|
|
||||||
)
|
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
|
||||||
success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
|
|
||||||
torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
|
|
||||||
probs,
|
|
||||||
uniform_samples,
|
|
||||||
samples,
|
|
||||||
success,
|
|
||||||
maybe_top_k_arr,
|
|
||||||
top_k_val,
|
|
||||||
maybe_top_p_arr,
|
|
||||||
top_p_val,
|
|
||||||
deterministic,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
return samples, success
|
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_sampling_from_probs(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
top_k: Union[torch.Tensor, int],
|
|
||||||
top_p: Union[torch.Tensor, float],
|
|
||||||
filter_apply_order: str = "top_k_first",
|
|
||||||
deterministic: bool = True,
|
|
||||||
check_nan: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
if filter_apply_order == "top_k_first":
|
|
||||||
renorm_probs = top_k_renorm_probs(probs, top_k)
|
|
||||||
return top_p_sampling_from_probs(
|
|
||||||
renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan
|
|
||||||
)
|
|
||||||
elif filter_apply_order == "joint":
|
|
||||||
if check_nan:
|
|
||||||
if torch.any(torch.isnan(probs)):
|
|
||||||
raise ValueError("Input probs contains NaN.")
|
|
||||||
return _top_k_top_p_sampling_from_probs_internal(
|
|
||||||
probs,
|
|
||||||
uniform_samples,
|
|
||||||
*_to_tensor_scalar_tuple(top_k),
|
|
||||||
*_to_tensor_scalar_tuple(top_p),
|
|
||||||
deterministic,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
|
||||||
|
|
||||||
|
|
||||||
def _min_p_sampling_from_probs_internal(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
maybe_min_p_arr: Optional[torch.Tensor],
|
|
||||||
min_p_val: float,
|
|
||||||
deterministic: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
with probs.device as device:
|
|
||||||
probs = probs.float()
|
|
||||||
uniform_samples = uniform_samples.float()
|
|
||||||
maybe_min_p_arr = (
|
|
||||||
maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
|
|
||||||
)
|
|
||||||
samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
|
|
||||||
torch.ops.sgl_kernels.min_p_sampling_from_probs(
|
|
||||||
probs,
|
|
||||||
uniform_samples,
|
|
||||||
samples,
|
|
||||||
maybe_min_p_arr,
|
|
||||||
min_p_val,
|
|
||||||
deterministic,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def min_p_sampling_from_probs(
|
|
||||||
probs: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
min_p: Union[torch.Tensor, float],
|
|
||||||
deterministic: bool = True,
|
|
||||||
check_nan: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if uniform_samples.dim() == 2:
|
|
||||||
# Take the first row (round) of uniform_samples
|
|
||||||
uniform_samples = uniform_samples[0]
|
|
||||||
|
|
||||||
if check_nan:
|
|
||||||
if torch.any(torch.isnan(probs)):
|
|
||||||
raise ValueError("Input probs contains NaN.")
|
|
||||||
return _min_p_sampling_from_probs_internal(
|
|
||||||
probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def tree_speculative_sampling_target_only(
|
|
||||||
predicts: torch.Tensor, # mutable
|
|
||||||
accept_index: torch.Tensor, # mutable
|
|
||||||
accept_token_num: torch.Tensor, # mutable
|
|
||||||
candidates: torch.Tensor,
|
|
||||||
retrive_index: torch.Tensor,
|
|
||||||
retrive_next_token: torch.Tensor,
|
|
||||||
retrive_next_sibling: torch.Tensor,
|
|
||||||
uniform_samples: torch.Tensor,
|
|
||||||
target_probs: torch.Tensor,
|
|
||||||
draft_probs: torch.Tensor,
|
|
||||||
deterministic: bool = True,
|
|
||||||
) -> None:
|
|
||||||
with predicts.device as device:
|
|
||||||
torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
|
|
||||||
predicts,
|
|
||||||
accept_index,
|
|
||||||
accept_token_num,
|
|
||||||
candidates,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
uniform_samples,
|
|
||||||
target_probs,
|
|
||||||
draft_probs,
|
|
||||||
deterministic,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient(
|
|
||||||
parent_list: torch.Tensor,
|
|
||||||
selected_index: torch.Tensor,
|
|
||||||
verified_seq_len: torch.Tensor,
|
|
||||||
tree_mask: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
retrive_index: torch.Tensor,
|
|
||||||
retrive_next_token: torch.Tensor,
|
|
||||||
retrive_next_sibling: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
depth: int,
|
|
||||||
draft_token_num: int,
|
|
||||||
) -> None:
|
|
||||||
with parent_list.device as device:
|
|
||||||
torch.ops.sgl_kernels.build_tree_kernel_efficient(
|
|
||||||
parent_list,
|
|
||||||
selected_index,
|
|
||||||
verified_seq_len,
|
|
||||||
tree_mask,
|
|
||||||
positions,
|
|
||||||
retrive_index,
|
|
||||||
retrive_next_token,
|
|
||||||
retrive_next_sibling,
|
|
||||||
topk,
|
|
||||||
depth,
|
|
||||||
draft_token_num,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel(
|
|
||||||
parent_list: torch.Tensor,
|
|
||||||
selected_index: torch.Tensor,
|
|
||||||
verified_seq_len: torch.Tensor,
|
|
||||||
tree_mask: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
retrive_index: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
depth: int,
|
|
||||||
draft_token_num: int,
|
|
||||||
) -> None:
|
|
||||||
with parent_list.device as device:
|
|
||||||
torch.ops.sgl_kernels.build_tree_kernel(
|
|
||||||
parent_list,
|
|
||||||
selected_index,
|
|
||||||
verified_seq_len,
|
|
||||||
tree_mask,
|
|
||||||
positions,
|
|
||||||
retrive_index,
|
|
||||||
topk,
|
|
||||||
depth,
|
|
||||||
draft_token_num,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sgl_per_token_group_quant_fp8(
|
|
||||||
input: torch.Tensor,
|
|
||||||
output_q: torch.Tensor,
|
|
||||||
output_s: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
eps: float,
|
|
||||||
fp8_min: float,
|
|
||||||
fp8_max: float,
|
|
||||||
) -> None:
|
|
||||||
torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
|
|
||||||
input, output_q, output_s, group_size, eps, fp8_min, fp8_max
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cublas_grouped_gemm(
|
|
||||||
inputs: List[torch.Tensor],
|
|
||||||
weights: List[torch.Tensor],
|
|
||||||
outputs: List[torch.Tensor],
|
|
||||||
out_dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
with inputs[0].device as device:
|
|
||||||
assert (
|
|
||||||
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
|
||||||
), "Inputs/weights/outputs should not be empty!"
|
|
||||||
cublas_handle = torch.cuda.current_blas_handle()
|
|
||||||
torch.ops.sgl_kernels.cublas_grouped_gemm(
|
|
||||||
inputs,
|
|
||||||
weights,
|
|
||||||
outputs,
|
|
||||||
out_dtype,
|
|
||||||
cublas_handle,
|
|
||||||
_get_cuda_stream(device),
|
|
||||||
)
|
|
||||||
153
sgl-kernel/src/sgl-kernel/ops/activation.py
Normal file
153
sgl-kernel/src/sgl-kernel/ops/activation.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import sgl_kernel.ops._kernels
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.ops.utils import get_cuda_stream
|
||||||
|
|
||||||
|
|
||||||
|
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||||
|
# Kudos to @yzh119
|
||||||
|
def rmsnorm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if out is None:
|
||||||
|
out = torch.empty_like(input)
|
||||||
|
torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def fused_add_rmsnorm(
|
||||||
|
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||||
|
) -> None:
|
||||||
|
torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_rmsnorm(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if out is None:
|
||||||
|
out = torch.empty_like(input)
|
||||||
|
torch.ops.sgl_kernels.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_fused_add_rmsnorm(
|
||||||
|
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||||
|
) -> None:
|
||||||
|
torch.ops.sgl_kernels.gemma_fused_add_rmsnorm(
|
||||||
|
input, residual, weight, eps, get_cuda_stream()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||||
|
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
||||||
|
assert (
|
||||||
|
input.shape[:-1] == output.shape[:-1]
|
||||||
|
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
||||||
|
assert (
|
||||||
|
input.shape[-1] == 2 * output.shape[-1]
|
||||||
|
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
||||||
|
|
||||||
|
|
||||||
|
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||||
|
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||||
|
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||||
|
if out is not None:
|
||||||
|
_check_shape(input, out)
|
||||||
|
else:
|
||||||
|
out = torch.empty(
|
||||||
|
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
torch.ops.sgl_kernels.silu_and_mul(out, input, get_cuda_stream())
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||||
|
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||||
|
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||||
|
if out is not None:
|
||||||
|
_check_shape(input, out)
|
||||||
|
else:
|
||||||
|
out = torch.empty(
|
||||||
|
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, get_cuda_stream())
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
|
||||||
|
if input.shape[-1] * input.dtype.itemsize % 16 != 0:
|
||||||
|
raise ValueError("The pointers must be multiple of 16 bytes.")
|
||||||
|
if out is not None:
|
||||||
|
_check_shape(input, out)
|
||||||
|
else:
|
||||||
|
out = torch.empty(
|
||||||
|
input.shape[:-1] + (input.shape[-1] // 2,),
|
||||||
|
device=input.device,
|
||||||
|
dtype=input.dtype,
|
||||||
|
)
|
||||||
|
torch.ops.sgl_kernels.gelu_and_mul(out, input, get_cuda_stream())
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> None:
    r"""Apply rotary position embedding to ``query`` and ``key`` in place.

    Compatible with the SGL/vLLM rotary-embedding convention: the rotation
    uses precomputed cosine/sine values and writes results back into the
    input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape ``(nnz,)``.
    query : torch.Tensor
        Query tensor, shape ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape ``(nnz, num_k_heads * head_size)``.
    head_size : int
        Size of each attention head; used to reshape query/key per head.
    cos_sin_cache : torch.Tensor
        Shape ``(max_seq_len, rotary_dim)``; cosine values occupy the first
        half of ``rotary_dim`` and sine values the second half.  Must be
        float32.  The rotary dimension is inferred from this cache.
    is_neox : bool
        ``True`` (default): Neox-style, rotate the first and second halves
        of each head (``[..., :head_dim//2]`` with ``[..., head_dim//2:]``).
        ``False``: interleaved style, rotate even (``[..., ::2]``) with odd
        (``[..., 1::2]``) dimensions.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    pos_ids = positions.int()
    # The kernel writes in place, so the rotated views alias the inputs.
    q_view = query.view(query.shape[0], -1, head_size)
    k_view = key.view(key.shape[0], -1, head_size)
    torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache(
        q=q_view,
        k=k_view,
        q_rope=q_view,
        k_rope=k_view,
        cos_sin_cache=cos_sin_cache,
        pos_ids=pos_ids,
        interleave=(not is_neox),
        cuda_stream=get_cuda_stream(),
    )
|
||||||
78
sgl-kernel/src/sgl-kernel/ops/allreduce.py
Normal file
78
sgl-kernel/src/sgl-kernel/ops/allreduce.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# Custom all-reduce wrappers.  Two mutually exclusive APIs are exported
# depending on the build: a ROCm (HIP) flavor and a TensorRT-LLM-style CUDA
# flavor.  All functions are thin shims over torch.ops.sgl_kernels.*.
from typing import List, Tuple

import sgl_kernel.ops._kernels
import torch

if torch.version.hip is not None:
    # ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        # Returns an opaque handle (int) identifying the allreduce context;
        # pass it as `fa` to the other functions below.
        return torch.ops.sgl_kernels.init_custom_ar(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        # All-reduce using a pre-registered input buffer.
        torch.ops.sgl_kernels.all_reduce_reg(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        # All-reduce that stages `inp` through `reg_buffer` (not pre-registered).
        torch.ops.sgl_kernels.all_reduce_unreg(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        # Release the context created by init_custom_ar.
        torch.ops.sgl_kernels.dispose(fa)

    def meta_size() -> int:
        # Size in bytes required for the metadata buffer.
        return torch.ops.sgl_kernels.meta_size()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return torch.ops.sgl_kernels.register_buffer(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return torch.ops.sgl_kernels.allocate_meta_buffer(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernels.get_meta_buffer_ipc_handle(inp)

else:
    # TRTLLM custom allreduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
        # Returns an opaque handle; pass it as `fa` to custom_reduce/custom_dispose.
        return torch.ops.sgl_kernels.init_custom_ar(
            rank_id,
            num_devices,
            rank_data,
            buffers,
            tmp_buffers,
            barrier_in,
            barrier_out,
        )

    def custom_dispose(fa):
        # Release the context created by init_custom_reduce.
        torch.ops.sgl_kernels.dispose(fa)

    def custom_reduce(fa, inp, out):
        # All-reduce `inp` into `out` using the context `fa`.
        torch.ops.sgl_kernels.all_reduce(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
        return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa, handles, offsets):
        torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets)
|
||||||
8
sgl-kernel/src/sgl-kernel/ops/attention.py
Normal file
8
sgl-kernel/src/sgl-kernel/ops/attention.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import sgl_kernel.ops._kernels
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Single-step lightning-attention decode.

    Thin wrapper over the sgl kernel; results are written into ``output`` and
    ``new_kv`` (mutated in place), nothing is returned.  ``slope`` is
    presumably the per-head decay slope used by lightning attention — TODO
    confirm against the kernel implementation.
    """
    torch.ops.sgl_kernels.lightning_attention_decode(
        q, k, v, past_kv, slope, output, new_kv
    )
|
||||||
111
sgl-kernel/src/sgl-kernel/ops/gemm.py
Normal file
111
sgl-kernel/src/sgl-kernel/ops/gemm.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import sgl_kernel.ops._kernels
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.ops.utils import _get_cache_buf, get_cuda_stream
|
||||||
|
|
||||||
|
|
||||||
|
def int8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """INT8 matmul with per-tensor/per-channel dequant scales.

    Returns the kernel's output tensor in ``out_dtype``; ``bias`` (optional)
    is added by the kernel.
    """
    return torch.ops.sgl_kernels.int8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_blockwise_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """FP8 matmul with blockwise dequant scales; returns the output in ``out_dtype``."""
    return torch.ops.sgl_kernels.fp8_blockwise_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def fp8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """FP8 matmul with dequant scales and optional bias; returns output in ``out_dtype``."""
    return torch.ops.sgl_kernels.fp8_scaled_mm(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Low-level FP8 batched matmul: writes the scaled product of A and B into D.

    Uses the caller-provided cuBLAS workspace and the current cuBLAS handle /
    CUDA stream.  ``D`` is mutated in place; nothing is returned.
    """
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.bmm_fp8(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """FP8 batched matmul of A (b, m, k) and B (b, k, n) with dequant scales.

    Allocates ``out`` of shape ``(b, m, n)`` in ``dtype`` when not supplied,
    runs the kernel through a cached 32 MiB cuBLAS workspace, and returns the
    output tensor.
    """
    if out is None:
        result_shape = (A.shape[0], A.shape[1], B.shape[2])
        out = torch.empty(result_shape, device=A.device, dtype=dtype)
    workspace = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace, A, B, out, A_scale, B_scale)
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    """Per-token-group FP8 quantization.

    Quantizes ``input`` in groups of ``group_size`` elements, writing the
    quantized values into ``output_q`` and the per-group scales into
    ``output_s`` (both mutated in place).  ``eps`` guards against zero scales;
    values are clamped to ``[fp8_min, fp8_max]``.
    """
    torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )
|
||||||
|
|
||||||
|
|
||||||
|
def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    """Grouped GEMM via cuBLAS: outputs[i] = inputs[i] @ weights[i].

    All three lists must be non-empty; results are written into ``outputs``
    in place.  Nothing is returned.
    """
    assert (
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernels.cublas_grouped_gemm(
        inputs, weights, outputs, out_dtype, handle, get_cuda_stream()
    )
|
||||||
24
sgl-kernel/src/sgl-kernel/ops/moe.py
Normal file
24
sgl-kernel/src/sgl-kernel/ops/moe.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import sgl_kernel.ops._kernels
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
):
    """Align MoE token-to-expert assignments to ``block_size`` boundaries.

    Thin wrapper over the sgl kernel.  ``sorted_token_ids``, ``experts_ids``
    and ``num_tokens_post_pad`` are output buffers filled in place;
    ``token_cnts_buffer`` and ``cumsum_buffer`` are caller-provided scratch
    space.  Nothing is returned.
    """
    torch.ops.sgl_kernels.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_token_ids,
        experts_ids,
        num_tokens_post_pad,
        token_cnts_buffer,
        cumsum_buffer,
    )
|
||||||
211
sgl-kernel/src/sgl-kernel/ops/sampling.py
Normal file
211
sgl-kernel/src/sgl-kernel/ops/sampling.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import sgl_kernel.ops._kernels
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream
|
||||||
|
|
||||||
|
|
||||||
|
def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Kernel shim for top-k renormalization.

    Exactly one of ``maybe_top_k_arr`` (per-row k values) or the scalar
    ``top_k_val`` is meaningful — see ``_to_tensor_scalar_tuple``.  Inputs
    are cast to the dtypes the kernel expects (float probs, int k array).
    Returns a new tensor of renormalized probabilities.
    """
    probs = probs.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
        probs,
        renorm_probs,
        maybe_top_k_arr,
        top_k_val,
        get_cuda_stream(),
    )
    return renorm_probs
|
||||||
|
|
||||||
|
|
||||||
|
def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    """Keep the top-k entries of each probability row and renormalize to sum 1.

    ``top_k`` may be a scalar applied to every row or a per-row tensor.
    Returns a new tensor; ``probs`` is not modified.
    """
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))


# Backward-compatible singular-name alias.
top_k_renorm_prob = top_k_renorm_probs
|
||||||
|
|
||||||
|
|
||||||
|
def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    """Kernel shim for top-p (nucleus) renormalization.

    Exactly one of ``maybe_top_p_arr`` (per-row p values) or the scalar
    ``top_p_val`` is meaningful.  Returns a new tensor of renormalized
    probabilities.
    """
    probs = probs.float()
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
    torch.ops.sgl_kernels.top_p_renorm_probs(
        probs,
        renorm_probs,
        maybe_top_p_arr,
        top_p_val,
        get_cuda_stream(),
    )
    return renorm_probs
|
||||||
|
|
||||||
|
|
||||||
|
def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Keep each row's smallest nucleus with mass >= top_p and renormalize.

    ``top_p`` may be a scalar applied to every row or a per-row tensor.
    Returns a new tensor; ``probs`` is not modified.
    """
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))


# Backward-compatible singular-name alias.
top_p_renorm_prob = top_p_renorm_probs
|
||||||
|
|
||||||
|
|
||||||
|
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Kernel shim for top-p sampling.

    Returns ``(samples, success)``: int32 sampled token ids per row and a
    bool flag per row indicating whether sampling succeeded (the kernel uses
    rejection sampling driven by ``uniform_samples``).
    """
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success
|
||||||
|
|
||||||
|
|
||||||
|
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids from ``probs`` under a top-p (nucleus) constraint.

    ``top_p`` may be a scalar or per-row tensor.  When ``check_nan`` is set,
    NaNs in ``probs`` raise ``ValueError`` before sampling.  Returns
    ``(samples, success)`` from the kernel shim.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Kernel shim for joint top-k + top-p sampling.

    Applies both filters inside a single kernel call.  Returns
    ``(samples, success)``: int32 sampled token ids per row and a bool
    per-row success flag from the rejection-sampling kernel.
    """
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
        torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            success,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples, success
|
||||||
|
|
||||||
|
|
||||||
|
def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample token ids from ``probs`` under combined top-k and top-p filters.

    ``filter_apply_order`` selects the strategy: ``"top_k_first"`` renorms by
    top-k then top-p samples, while ``"joint"`` applies both inside one fused
    kernel.  Any other value raises ``ValueError``.  Returns
    ``(samples, success)``.
    """
    if filter_apply_order == "top_k_first":
        renormed = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renormed, uniform_samples, top_p, deterministic, check_nan=check_nan
        )
    if filter_apply_order == "joint":
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            uniform_samples,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
        )
    raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||||
|
|
||||||
|
|
||||||
|
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
) -> torch.Tensor:
    """Kernel shim for min-p sampling.

    Unlike the top-p/top-k shims, min-p sampling needs no success flag;
    returns only the int32 sampled token ids per row.
    """
    with probs.device as device:
        probs = probs.float()
        uniform_samples = uniform_samples.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernels.min_p_sampling_from_probs(
            probs,
            uniform_samples,
            samples,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            get_cuda_stream(),
        )
        return samples
|
||||||
|
|
||||||
|
|
||||||
|
def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> torch.Tensor:
    """Sample token ids from ``probs`` under a min-p constraint.

    ``min_p`` may be a scalar or per-row tensor.  Only one sampling round is
    needed, so a 2-D ``uniform_samples`` is reduced to its first row.  When
    ``check_nan`` is set, NaNs in ``probs`` raise ``ValueError``.  Returns the
    sampled token ids.
    """
    if uniform_samples.dim() == 2:
        # Only the first round of uniform samples is consumed.
        uniform_samples = uniform_samples[0]
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )
|
||||||
84
sgl-kernel/src/sgl-kernel/ops/speculative.py
Normal file
84
sgl-kernel/src/sgl-kernel/ops/speculative.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import sgl_kernel.ops._kernels
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.ops.utils import get_cuda_stream
|
||||||
|
|
||||||
|
|
||||||
|
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    deterministic: bool = True,
) -> None:
    """Verify speculative-decoding draft tokens against target-model probs.

    Walks the draft token tree (described by ``retrive_*`` tensors) and
    accepts/rejects candidates using the target probabilities.  Results are
    written in place into ``predicts``, ``accept_index`` and
    ``accept_token_num``; nothing is returned.
    """
    torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        target_probs,
        draft_probs,
        deterministic,
        get_cuda_stream(),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the speculative draft tree (efficient variant).

    Fills ``tree_mask``, ``positions`` and the ``retrive_*`` navigation
    tensors in place from the selected draft tokens; nothing is returned.
    The efficient variant additionally emits next-token / next-sibling links
    consumed by :func:`tree_speculative_sampling_target_only`.
    """
    torch.ops.sgl_kernels.build_tree_kernel_efficient(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_tree_kernel(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
) -> None:
    """Build the speculative draft tree (original variant).

    Like :func:`build_tree_kernel_efficient` but without the next-token /
    next-sibling link outputs.  Fills ``tree_mask``, ``positions`` and
    ``retrive_index`` in place; nothing is returned.
    """
    torch.ops.sgl_kernels.build_tree_kernel(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        topk,
        depth,
        draft_token_num,
    )
|
||||||
@@ -18,8 +18,8 @@ from typing import Dict, Tuple
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _get_cuda_stream(device: torch.device) -> int:
|
def get_cuda_stream() -> int:
|
||||||
return torch.cuda.current_stream(device).cuda_stream
|
return torch.cuda.current_stream().cuda_stream
|
||||||
|
|
||||||
|
|
||||||
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
|
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import unittest
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
|
import sgl_kernel.ops.allreduce as custom_ops
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from sgl_kernel import ops as custom_ops
|
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from vllm import _custom_ops as vllm_ops
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user