sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
117
sgl-kernel/python/sgl_kernel/__init__.py
Executable file
117
sgl-kernel/python/sgl_kernel/__init__.py
Executable file
@@ -0,0 +1,117 @@
|
||||
import ctypes
|
||||
import os
|
||||
import platform
|
||||
|
||||
import torch
|
||||
|
||||
# Host architecture string (e.g. "x86_64" or "aarch64"), used to locate the
# per-target CUDA runtime library below.
SYSTEM_ARCH = platform.machine()

# Preload the CUDA 12 runtime with RTLD_GLOBAL so that the symbols are visible
# to the compiled kernel extension when it is imported below. If the library is
# not at this well-known path (e.g. non-CUDA builds), we silently skip the
# preload and rely on the normal dynamic-linker search path.
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
    ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
|
||||
|
||||
from sgl_kernel import common_ops
|
||||
from sgl_kernel.allreduce import *
|
||||
from sgl_kernel.attention import (
|
||||
cutlass_mla_decode,
|
||||
cutlass_mla_get_workspace_size,
|
||||
lightning_attention_decode,
|
||||
merge_state,
|
||||
merge_state_v2,
|
||||
)
|
||||
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
|
||||
from sgl_kernel.elementwise import (
|
||||
FusedSetKVBufferArg,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
concat_mla_k,
|
||||
copy_to_gpu_no_ce,
|
||||
downcast_fp8,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
gelu_tanh_and_mul,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
|
||||
|
||||
if torch.version.hip is not None:
|
||||
from sgl_kernel.elementwise import gelu_quick
|
||||
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
cutlass_scaled_fp4_mm,
|
||||
dsv3_fused_a_gemm,
|
||||
dsv3_router_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
gptq_gemm,
|
||||
gptq_marlin_gemm,
|
||||
gptq_shuffle,
|
||||
int8_scaled_mm,
|
||||
qserve_w4a8_per_chn_gemm,
|
||||
qserve_w4a8_per_group_gemm,
|
||||
scaled_fp4_experts_quant,
|
||||
scaled_fp4_grouped_quant,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_group_quant_int8,
|
||||
sgl_per_token_quant_fp8,
|
||||
shuffle_rows,
|
||||
silu_and_mul_scaled_fp4_grouped_quant,
|
||||
)
|
||||
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_mla,
|
||||
)
|
||||
from sgl_kernel.marlin import (
|
||||
awq_marlin_moe_repack,
|
||||
awq_marlin_repack,
|
||||
gptq_marlin_repack,
|
||||
)
|
||||
from sgl_kernel.memory import set_kv_buffer_kernel
|
||||
from sgl_kernel.moe import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
fp8_blockwise_scaled_grouped_mm,
|
||||
moe_align_block_size,
|
||||
moe_fused_gate,
|
||||
prepare_moe_input,
|
||||
topk_softmax,
|
||||
)
|
||||
from sgl_kernel.sampling import (
|
||||
min_p_sampling_from_probs,
|
||||
top_k_mask_logits,
|
||||
top_k_renorm_prob,
|
||||
top_k_top_p_sampling_from_logits,
|
||||
top_k_top_p_sampling_from_probs,
|
||||
top_p_renorm_prob,
|
||||
top_p_sampling_from_probs,
|
||||
)
|
||||
from sgl_kernel.speculative import (
|
||||
build_tree_kernel_efficient,
|
||||
segment_packbits,
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
|
||||
def create_greenctx_stream_by_value(*args, **kwargs):
    """Lazily forward to ``sgl_kernel.spatial.create_greenctx_stream_by_value``.

    The ``spatial`` module is imported on first call so that importing
    ``sgl_kernel`` itself does not require the spatial extension to be present.
    """
    import sgl_kernel.spatial as _spatial

    return _spatial.create_greenctx_stream_by_value(*args, **kwargs)
|
||||
|
||||
|
||||
def get_sm_available(*args, **kwargs):
    """Lazily forward to ``sgl_kernel.spatial.get_sm_available``.

    Deferring the import keeps ``import sgl_kernel`` working even when the
    spatial extension is unavailable.
    """
    import sgl_kernel.spatial as _spatial

    return _spatial.get_sm_available(*args, **kwargs)
|
||||
376
sgl-kernel/python/sgl_kernel/_fa4_interface.py
Normal file
376
sgl-kernel/python/sgl_kernel/_fa4_interface.py
Normal file
@@ -0,0 +1,376 @@
|
||||
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py
|
||||
|
||||
# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import torch
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90
|
||||
from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
|
||||
|
||||
|
||||
def maybe_contiguous(x):
    """Return ``x`` unchanged when it is ``None`` or already unit-stride in its
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()
|
||||
|
||||
|
||||
# Map from torch dtypes to the corresponding Cute-DSL (cutlass) dtypes
# accepted by the flash-attention kernels in this module. Only fp16/bf16 are
# valid attention input dtypes; Float32 is included for the fp32 LSE tensor.
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
}
|
||||
|
||||
|
||||
def _flash_attn_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    softcap: Optional[float] = None,
    window_size_left: Optional[int] = None,
    window_size_right: Optional[int] = None,
    learnable_sink: Optional[torch.Tensor] = None,
    # m_block_size: int = 128,
    # n_block_size: int = 64,
    # num_threads: int = 128,
    m_block_size: int = 128,
    n_block_size: int = 128,
    num_threads: int = 384,
    pack_gqa: Optional[bool] = None,
    _compute_capability: Optional[int] = None,
    return_softmax_lse: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Flash-attention forward pass via the Cute-DSL kernels (SM90/SM100).

    Validates input layouts, converts the torch tensors to Cute tensors via
    dlpack, JIT-compiles a kernel specialized for the current configuration
    (cached in ``_flash_attn_fwd.compile_cache``), and launches it.

    Layouts (fixed, ``q`` non-varlen: (batch, seqlen_q, num_head, head_dim);
    varlen: (total_q, num_head, head_dim); ``k``/``v`` analogous, or
    (num_pages, page_size, num_head_kv, head_dim) when ``page_table`` is set).

    Returns ``(out, lse)`` where ``lse`` is None unless
    ``return_softmax_lse`` is truthy.
    """
    # Kernels require unit stride in the last dim; copy lazily if needed.
    q, k, v = [maybe_contiguous(t) for t in (q, k, v)]
    num_head, head_dim = q.shape[-2:]
    if cu_seqlens_q is None:
        batch_size, seqlen_q = q.shape[:2]
        total_q = batch_size * seqlen_q
    else:
        # Varlen path: q is packed as (total_q, num_head, head_dim).
        batch_size = cu_seqlens_q.shape[0] - 1
        seqlen_q = None
        total_q = q.shape[0]
    if page_table is not None:
        assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k"
        assert page_table.dtype == torch.int32, "page_table must be int32"
        assert (
            page_table.stride(-1) == 1
        ), "page_table must be contiguous in the last dimension"
        max_num_pages_per_seq = page_table.shape[1]
        assert page_table.shape == (batch_size, max_num_pages_per_seq)
        num_pages, page_size = k.shape[:2]
        # Upper bound on K length implied by the paged cache capacity.
        seqlen_k = num_pages * page_size
    else:
        num_pages, page_size = None, None
        seqlen_k = k.shape[-3]
    num_head_kv = k.shape[-2]
    head_dim_v = v.shape[-1]
    # Shape checks for the three supported K/V layouts: dense batched,
    # paged, and varlen-packed.
    if cu_seqlens_k is None:
        if page_table is None:
            assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim)
            assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v)
        else:
            assert k.shape == (num_pages, page_size, num_head_kv, head_dim)
            assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v)
    else:
        assert k.shape == (seqlen_k, num_head_kv, head_dim)
        assert v.shape == (seqlen_k, num_head_kv, head_dim_v)
        assert cu_seqlens_k.shape == (
            batch_size + 1,
        ), "cu_seqlens_k must have shape (batch_size + 1,)"
    if cu_seqlens_q is not None:
        assert cu_seqlens_q.shape == (
            batch_size + 1,
        ), "cu_seqlens_q must have shape (batch_size + 1,)"
    assert seqused_q is None or seqused_q.shape == (
        batch_size,
    ), "seqused_q must have shape (batch_size,)"
    assert seqused_k is None or seqused_k.shape == (
        batch_size,
    ), "seqused_k must have shape (batch_size,)"
    assert q.dtype in [
        torch.float16,
        torch.bfloat16,
    ], "inputs must be float16 or bfloat16"
    assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype"
    for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]:
        if t is not None:
            assert (
                t.dtype == torch.int32
            ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32"
            assert (
                t.stride(0) == 1
            ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous"
    if learnable_sink is not None:
        assert learnable_sink.shape == (num_head,)
        assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16"
    assert all(
        t is None or t.is_cuda
        for t in (
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            page_table,
            learnable_sink,
        )
    ), "inputs must be on CUDA device"
    assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
    assert head_dim <= 256, "head_dim must be less than or equal to 256"
    # 16-byte alignment requirement expressed in elements of q's dtype.
    alignment = 16 // q.element_size()
    assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
    assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)
    # softcap == 0.0 means "disabled"; normalize to None for the kernel.
    if softcap == 0.0:
        softcap = None
    qhead_per_kvhead = num_head // num_head_kv
    # Default: pack GQA query heads together whenever there is head sharing.
    if pack_gqa is None:
        pack_gqa = qhead_per_kvhead > 1

    out_torch_dtype = q.dtype
    device = q.device
    q_batch_seqlen_shape = (
        (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,)
    )
    out = torch.empty(
        *q_batch_seqlen_shape,
        num_head,
        head_dim_v,
        dtype=out_torch_dtype,
        device=device,
    )
    # Log-sum-exp of softmax rows, allocated only when requested.
    lse_shape = (
        (batch_size, num_head, seqlen_q)
        if cu_seqlens_q is None
        else (num_head, total_q)
    )
    lse = (
        torch.empty(lse_shape, dtype=torch.float32, device=device)
        if return_softmax_lse
        else None
    )

    # --- Convert torch tensors to Cute-DSL tensors (zero-copy via dlpack). ---
    dtype = torch2cute_dtype_map[q.dtype]
    q_tensor, k_tensor, v_tensor, o_tensor = [
        from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(
            leading_dim=t.ndim - 1
        )
        for t in (q, k, v, out)
    ]
    lse_tensor = (
        from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=lse.ndim - 1
        )
        if lse is not None
        else None
    )
    (
        cu_seqlens_q_tensor,
        cu_seqlens_k_tensor,
        seqused_q_tensor,
        seqused_k_tensor,
        learnable_sink_tensor,
    ) = [
        (
            from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
            if t is not None
            else None
        )
        for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink)
    ]
    page_table_tensor = (
        from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic(
            leading_dim=1
        )
        if page_table is not None
        else None
    )
    # Reconcile causal masking with local (sliding-window) attention:
    # causal is expressed as a right window of 0; a pure right-window-0 local
    # spec collapses back to causal.
    if causal:
        window_size_right = 0
    local = window_size_left is not None or window_size_right is not None
    if window_size_left is not None or window_size_right is not None:
        if window_size_left is None and window_size_right == 0:
            causal, local = True, False
        else:
            causal, local = False, True
    # _compute_capability is a test hook to force an SM major version.
    compute_capability = (
        torch.cuda.get_device_capability()[0]
        if _compute_capability is None
        else _compute_capability
    )
    assert compute_capability in [
        9,
        10,
    ], "Unsupported compute capability. Supported: 9.x, 10.x"
    current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    if compute_capability == 9:  # TODO: tune block size according to hdim
        if head_dim == head_dim_v == 128 and not causal and not local:
            n_block_size = 192
    if compute_capability == 10:
        # TODO: fix the varlen case
        # NOTE(review): this reads as
        # (pack_gqa and 128 % qhead != 0) or varlen -> disable pack_gqa; the
        # grouping only matters when pack_gqa is already True, so the net
        # effect matches the intended "disable on uneven grouping or varlen".
        if (
            pack_gqa
            and (128 % qhead_per_kvhead != 0)
            or (cu_seqlens_q is not None or seqused_q is not None)
        ):
            pack_gqa = False

    # Everything that influences codegen must be part of the cache key.
    compile_key = (
        dtype,
        head_dim,
        head_dim_v,
        qhead_per_kvhead,
        causal,
        softcap is not None,
        lse is None,
        cu_seqlens_q is None,
        cu_seqlens_k is None,
        seqused_q is None,
        seqused_k is None,
        page_table is not None,
        window_size_left is not None,
        window_size_right is not None,
        learnable_sink is not None,
        m_block_size,
        n_block_size,
        num_threads,
        pack_gqa,
        compute_capability,
    )
    if compile_key not in _flash_attn_fwd.compile_cache:
        if compute_capability == 9:
            assert page_table is None, "paged KV not supported on SM 9.0"
            # fa_fwd = FlashAttentionForwardSm80(
            fa_fwd = FlashAttentionForwardSm90(
                dtype,
                head_dim,
                head_dim_v,
                qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                m_block_size=m_block_size,
                n_block_size=n_block_size,
                # num_stages=1,
                num_stages=2,
                num_threads=num_threads,
                Q_in_regs=False,
            )
        elif compute_capability == 10:
            assert page_size in [
                None,
                128,
            ], "Only page_size=128 is supported for paged KV on SM 10.0"
            fa_fwd = FlashAttentionForwardSm100(
                head_dim,
                head_dim_v,
                qhead_per_kvhead=qhead_per_kvhead,
                is_causal=causal,
                is_local=local,
                pack_gqa=pack_gqa,
                # Persistent scheduling only for dense, full-attention shapes.
                is_persistent=not causal
                and not local
                and cu_seqlens_q is None
                and seqused_q is None,
            )
        else:
            raise ValueError(
                f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x"
            )
        # TODO: check @can_implement
        _flash_attn_fwd.compile_cache[compile_key] = cute.compile(
            fa_fwd,
            q_tensor,
            k_tensor,
            v_tensor,
            o_tensor,
            lse_tensor,
            softmax_scale,
            current_stream,
            cu_seqlens_q_tensor,
            cu_seqlens_k_tensor,
            seqused_q_tensor,
            seqused_k_tensor,
            page_table_tensor,
            softcap,
            window_size_left,
            window_size_right,
            learnable_sink_tensor,
        )
    # Launch the cached, specialized kernel on the current torch stream.
    _flash_attn_fwd.compile_cache[compile_key](
        q_tensor,
        k_tensor,
        v_tensor,
        o_tensor,
        lse_tensor,
        softmax_scale,
        current_stream,
        cu_seqlens_q_tensor,
        cu_seqlens_k_tensor,
        seqused_q_tensor,
        seqused_k_tensor,
        page_table_tensor,
        softcap,
        window_size_left,
        window_size_right,
        learnable_sink_tensor,
    )
    return out, lse


# Process-wide cache of compiled kernels, keyed by compile_key above.
_flash_attn_fwd.compile_cache = {}
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    pack_gqa: Optional[bool] = None,
    return_softmax_lse: Optional[bool] = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Variable-length flash-attention forward pass.

    Thin public wrapper over :func:`_flash_attn_fwd` that unpacks the
    ``window_size`` pair into left/right window arguments. Returns only the
    attention output, or ``(out, lse)`` when ``return_softmax_lse`` is truthy.
    """
    window_left, window_right = window_size
    out, lse = _flash_attn_fwd(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        page_table=page_table,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_left,
        window_size_right=window_right,
        learnable_sink=learnable_sink,
        softcap=softcap,
        pack_gqa=pack_gqa,
        return_softmax_lse=return_softmax_lse,
    )
    if return_softmax_lse:
        return (out, lse)
    return out
|
||||
173
sgl-kernel/python/sgl_kernel/allreduce.py
Normal file
173
sgl-kernel/python/sgl_kernel/allreduce.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
# Platform-dependent Python bindings for the custom all-reduce kernels.
# The HIP (ROCm) build and the CUDA build register ops with different
# signatures under the same names, so each branch defines its own wrappers.
if torch.version.hip is not None:
    # ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
        rank_data: torch.Tensor,
        handles: List[str],
        offsets: List[int],
        rank: int,
        full_nvlink: bool,
    ) -> int:
        # Returns an opaque handle (pointer as int) used by the other calls.
        return torch.ops.sgl_kernel.init_custom_ar.default(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        # All-reduce using a pre-registered input buffer.
        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
        # All-reduce staging through `reg_buffer` for unregistered inputs.
        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
        # Release the resources owned by handle `fa`.
        torch.ops.sgl_kernel.dispose.default(fa)

    def meta_size() -> int:
        # Size in bytes of the metadata buffer expected by init_custom_ar.
        return torch.ops.sgl_kernel.meta_size.default()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

    # ROCM quick allreduce
    def init_custom_qr(
        rank: int, world_size: int, qr_max_size: Optional[int] = None
    ) -> int:
        # NOTE(review): arguments are forwarded as (world_size, rank) — the
        # underlying op apparently takes them in that order; verify against
        # the C++ registration before changing.
        return torch.ops.sgl_kernel.init_custom_qr.default(
            world_size, rank, qr_max_size
        )

    def qr_get_handle(fa: int) -> torch.Tensor:
        return torch.ops.sgl_kernel.qr_get_handle.default(fa)

    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
        torch.ops.sgl_kernel.qr_open_handles.default(fa, handles)

    def qr_all_reduce(
        fa: int,
        profile: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        cast_bf162half: bool,
    ) -> None:
        torch.ops.sgl_kernel.qr_all_reduce.default(
            fa, profile, inp, out, cast_bf162half
        )

    def qr_destroy(fa: int) -> None:
        torch.ops.sgl_kernel.qr_destroy.default(fa)

    def qr_max_size() -> int:
        return torch.ops.sgl_kernel.qr_max_size.default()

    # mscclpp
    # mscclpp-based all-reduce is not implemented on ROCm; these stubs keep
    # the module interface identical across platforms.
    def mscclpp_generate_unique_id() -> bytes:
        raise NotImplementedError()

    def mscclpp_init_context(
        unique_id: bytes,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        raise NotImplementedError()

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        raise NotImplementedError()

else:
    # CUDA build: IPC-pointer based custom all-reduce plus mscclpp bindings.
    def init_custom_ar(
        ipc_tensors: List[int], rank_data: torch.Tensor, rank: int, full_nvlink: bool
    ) -> int:
        # Returns an opaque handle (pointer as int) used by the other calls.
        return torch.ops.sgl_kernel.init_custom_ar.default(
            ipc_tensors, rank_data, rank, full_nvlink
        )

    def dispose(fa: int) -> None:
        # Release the resources owned by handle `fa`.
        torch.ops.sgl_kernel.dispose.default(fa)

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        # `reg_buffer` is a raw device pointer to a pre-registered buffer.
        torch.ops.sgl_kernel.all_reduce.default(
            fa, inp, out, reg_buffer, reg_buffer_sz_bytes
        )

    def get_graph_buffer_ipc_meta(fa) -> Tuple[List[int], List[int]]:
        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

    def register_buffer(fa: int, fake_ipc_ptrs: List[int]) -> None:
        return torch.ops.sgl_kernel.register_buffer.default(fa, fake_ipc_ptrs)

    def register_graph_buffers(
        fa: int, handles: List[List[int]], offsets: List[List[int]]
    ) -> None:
        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

    def meta_size() -> int:
        # Size in bytes of the metadata buffer expected by init_custom_ar.
        return torch.ops.sgl_kernel.meta_size.default()

    def mscclpp_generate_unique_id() -> torch.Tensor:
        return torch.ops.sgl_kernel.mscclpp_generate_unique_id.default()

    def mscclpp_init_context(
        unique_id: torch.Tensor,
        rank: int,
        world_size: int,
        scratch: torch.Tensor,
        put_buffer: torch.Tensor,
        nranks_per_node: int,
        rank_to_node: List[int],
        rank_to_ib: List[int],
        context_selection: int,
    ) -> int:
        # Returns an opaque context handle for mscclpp_allreduce.
        return torch.ops.sgl_kernel.mscclpp_init_context.default(
            unique_id,
            rank,
            world_size,
            scratch,
            put_buffer,
            nranks_per_node,
            rank_to_node,
            rank_to_ib,
            context_selection,
        )

    def mscclpp_allreduce(
        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
    ) -> None:
        torch.ops.sgl_kernel.mscclpp_allreduce.default(
            context, inp, out, nthreads, nblocks
        )
|
||||
138
sgl-kernel/python/sgl_kernel/attention.py
Normal file
138
sgl-kernel/python/sgl_kernel/attention.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    """Single-step lightning-attention decode.

    Thin wrapper over the ``sgl_kernel.lightning_attention_decode`` op.
    Results are written in place into ``output`` and ``new_kv``; ``past_kv``
    is the previous KV state and ``slope`` the per-head decay slopes
    (semantics defined by the underlying kernel — see its registration).
    """
    torch.ops.sgl_kernel.lightning_attention_decode.default(
        q, k, v, past_kv, slope, output, new_kv
    )
|
||||
|
||||
|
||||
def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial attention states ``(v_a, s_a)`` and ``(v_b, s_b)``.

    ``v_*`` are partial outputs and ``s_*`` their softmax statistics (LSE);
    statistics are promoted to fp32 before merging. When ``v_merged`` /
    ``s_merged`` are supplied, results are written into them in place;
    otherwise fresh tensors are allocated. Returns the merged pair.
    """
    # Promote statistics to fp32 first so a freshly-allocated s_merged
    # inherits the fp32 dtype.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    v_out = torch.empty_like(v_a) if v_merged is None else v_merged
    s_out = torch.empty_like(s_a) if s_merged is None else s_merged
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_out, s_out)
    return v_out, s_out
|
||||
|
||||
|
||||
def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial attention states using the v2 kernel.

    Same contract as :func:`merge_state`; softmax statistics are promoted to
    fp32 and output tensors are allocated only when not provided.
    """
    # Promote statistics to fp32 first so a freshly-allocated s_merged
    # inherits the fp32 dtype.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
    # does not support the FP8 data type and non - CUDA devices.
    # It may be necessary to fall back to using the Triton kernel.
    v_out = torch.empty_like(v_a) if v_merged is None else v_merged
    s_out = torch.empty_like(s_a) if s_merged is None else s_merged
    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_out, s_out)
    return v_out, s_out
|
||||
|
||||
|
||||
def cutlass_mla_decode(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    workspace: torch.Tensor,
    sm_scale: float,
    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
    """CUTLASS MLA decode kernel for DeepSeek-style latent attention.

    Args:
        q_nope: (B, H, 512) non-RoPE query part, fp16/bf16.
        q_pe: (B, H, 64) RoPE query part, same dtype as q_nope.
        kv_c_and_k_pe_cache: (num_pages, PAGE_SIZE, 576) paged latent KV cache.
        seq_lens: (B,) int32 sequence lengths.
        page_table: (B, block_num) int32 page indices per sequence.
        workspace: scratch buffer sized via cutlass_mla_get_workspace_size.
        sm_scale: softmax scale.
        num_kv_splits: KV split factor; 1 by default for CUDA-graph safety.

    Returns:
        (B, H, 512) attention output (sliced back to the caller's head count).
    """
    assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
    assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
    assert (
        kv_c_and_k_pe_cache.ndim == 3
    ), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"

    B_q, H, D_q_nope = q_nope.shape
    B_q_2, H_2, D_q_pe = q_pe.shape
    assert (B_q == B_q_2) and (H == H_2)

    _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

    # MLA latent/RoPE dimensions are fixed by the kernel.
    D_latent = 512
    D_rope = 64
    assert D_q_nope == D_latent
    assert D_q_pe == D_rope
    assert D_ckv == D_latent + D_rope

    # The kernel operates on a fixed head count of 128; smaller head counts
    # are zero-extended (garbage in padded heads is sliced off below).
    MAX_HEADS = 128
    assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
    if H < MAX_HEADS:
        q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
        q_nope_padded[:, :H] = q_nope
        q_nope = q_nope_padded

        q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
        q_pe_padded[:, :H] = q_pe
        q_pe = q_pe_padded

    assert len(page_table.shape) == 2
    B_block_table, block_num = page_table.shape
    assert B_block_table == B_q
    assert block_num > 0, f"block num must be greater than 0, got {block_num}"
    # Pages are consumed in groups of 128 tokens; use integer arithmetic
    # instead of the former float `128 / PAGE_SIZE` modulus.
    assert (
        128 % PAGE_SIZE == 0
    ), f"PAGE_SIZE must evenly divide 128, got {PAGE_SIZE}"
    assert block_num % (128 // PAGE_SIZE) == 0

    # TODO(kaixih@nvidia): support fp8
    assert q_nope.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
    assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
    assert (
        seq_lens.dtype == torch.int32
    ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
    assert (
        page_table.dtype == torch.int32
    ), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

    out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))

    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out,
        q_nope,
        q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        sm_scale,
        num_kv_splits,
    )
    # Drop the padded heads before returning to the caller.
    return out[:, :H].contiguous()
|
||||
|
||||
|
||||
def cutlass_mla_get_workspace_size(
    max_seq_len: int,
    num_batches: int,
    sm_count: int = 0,
    num_kv_splits: int = 1,  # Set to 1 to avoid cuda_graph issue by default.
) -> int:
    """Return the workspace size (bytes) required by ``cutlass_mla_decode``.

    ``sm_count=0`` lets the kernel query the device; ``num_kv_splits``
    defaults to 1 for CUDA-graph safety.
    """
    if max_seq_len <= 0:
        raise AssertionError(
            f"max_seq_len must be greater than 0, got {max_seq_len}"
        )
    if num_batches <= 0:
        raise AssertionError(
            f"num_batches must be greater than 0, got {num_batches}"
        )
    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
        max_seq_len, num_batches, sm_count, num_kv_splits
    )
|
||||
112
sgl-kernel/python/sgl_kernel/cutlass_moe.py
Normal file
112
sgl-kernel/python/sgl_kernel/cutlass_moe.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_cutlass_w4a8_moe_mm_data(
    topk_ids: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    input_permutation: torch.Tensor,
    output_permutation: torch.Tensor,
    num_experts: int,
    n: int,
    k: int,
) -> None:
    """
    Prepare data necessary to perform CUTLASS grouped matrix multiplications
    used in CUTLASS-based fused MoE.

    The function takes in topk_ids (token-expert mapping) and uses it to
    compute:
    - expert_offsets: Indices that mark at which token index each expert begins
                      its computation after the input is sorted with
                      input_permutation. The number of tokens computed with
                      expert E is expert_offsets[E + 1] - expert_offsets[E]
    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
                                      multiplication in two grouped MMs used in
                                      the fused MoE operation.
    - input_permutation: Permutation that must be used to shuffle the input
                         before executing the MMs.
    - output_permutation: Permutation that must be used to shuffle the output
                          after executing the MMs.
    """
    # All output tensors are filled in place by the op; nothing is returned.
    torch.ops.sgl_kernel.get_cutlass_w4a8_moe_mm_data.default(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
|
||||
|
||||
|
||||
def cutlass_w4a8_moe_mm(
    d: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    # Fixed: annotations previously used `torch.tensor` (the factory
    # function) instead of the type `torch.Tensor`.
    experts_offsets: torch.Tensor,
    problem_sizes: torch.Tensor,
    a_strides: torch.Tensor,
    b_strides: torch.Tensor,
    d_strides: torch.Tensor,
    s_strides: torch.Tensor,
    chunk_size: int = 128,
    topk: int = 8,
) -> None:
    """
    Perform grouped matrix multiplication between int4 weights and fp8 activations.

    This function executes multiple GEMM operations in parallel, which is useful for
    scenarios like Mixture of Experts (MoE) where different inputs go through different
    experts. The implementation leverages NVIDIA Hopper architecture features for
    optimal performance with quantized weights.

    Args:
        d: Output matrices of shape [total_m, total_n]
        a: Activation matrices in FP8 (float_e4m3_t) format
           Each tensor should be of shape [total_m, K] in row-major layout
        b: Weight matrices in packed int4 format
           Each tensor should be of shape [E, N, K//2] in column-major layout
           where each byte contains two 4-bit integers
        a_scales: Scale factors for the inputs
        b_scales: Scale factors for the quantized weights
                  Each tensor should be of shape [E, K//512, N*8]
        experts_offsets: Tensor containing expert offsets for determining group boundaries
        problem_sizes: with shape [num_experts, 3] (M, N, K for each group) (int32)
        a_strides: Strides information for A matrices
        b_strides: Strides information for B matrices
        d_strides: Strides information for D matrices
        s_strides: Strides information for b_scales matrices
        chunk_size: Number of elements each scale value applies to (K//512), default to 128
        topk: Number of experts selected per token, default to 8

    Requirements:
        - All tensors must be on a CUDA device
        - Requires an NVIDIA Hopper GPU (H100)
        - A tensors must be in float8_e4m3fn format
        - B tensors must contain packed int4 values (stored as int8)

    Note:
        The function computes: D = (A * (B * scales))
        for each group of tensors in parallel
    """
    # Results are written in place into `d` by the op.
    torch.ops.sgl_kernel.cutlass_w4a8_moe_mm.default(
        d,
        a,
        b,
        a_scales,
        b_scales,
        experts_offsets,
        problem_sizes,
        a_strides,
        b_strides,
        d_strides,
        s_strides,
        chunk_size,
        topk,
    )
|
||||
381
sgl-kernel/python/sgl_kernel/elementwise.py
Normal file
381
sgl-kernel/python/sgl_kernel/elementwise.py
Normal file
@@ -0,0 +1,381 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
|
||||
|
||||
|
||||
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
|
||||
# Kudos to @yzh119
|
||||
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Root mean square normalization.

    Computes ``out[i] = (input[i] / RMS(input)) * weight[i]``.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        If given, the kernel writes the result into this tensor inplace;
        otherwise a fresh tensor is allocated.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
        If None, will be automatically enabled on Hopper architecture.

    Returns
    -------
    output: torch.Tensor
        Normalized tensor, shape (batch_size, hidden_size).
    """
    result = torch.empty_like(input) if out is None else out
    use_pdl = is_arch_support_pdl() if enable_pdl is None else enable_pdl
    torch.ops.sgl_kernel.rmsnorm.default(result, input, weight, eps, use_pdl)
    return result
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Fused add root mean square normalization (inplace).

    Performs, in a single kernel launch::

        residual[i] += input[i]                                  # step 1
        input[i] = (residual[i] / RMS(residual)) * weight[i]     # step 2

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Overwritten with the
        normalized result.
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size). Updated inplace.
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
        If None, will be automatically enabled on Hopper architecture.
    """
    use_pdl = is_arch_support_pdl() if enable_pdl is None else enable_pdl
    torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        input, residual, weight, eps, use_pdl
    )
|
||||
|
||||
|
||||
def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Gemma-style root mean square normalization.

    Computes ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)`` —
    the ``+ 1`` on the weight is what distinguishes it from plain ``rmsnorm``.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        If given, the kernel writes the result into this tensor inplace;
        otherwise a fresh tensor is allocated.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
        If None, will be automatically enabled on Hopper architecture.

    Returns
    -------
    output: torch.Tensor
        Gemma Normalized tensor, shape (batch_size, hidden_size).
    """
    result = torch.empty_like(input) if out is None else out
    use_pdl = is_arch_support_pdl() if enable_pdl is None else enable_pdl
    torch.ops.sgl_kernel.gemma_rmsnorm.default(result, input, weight, eps, use_pdl)
    return result
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Gemma-style fused add root mean square normalization (inplace).

    Performs, in a single kernel launch::

        residual[i] += input[i]                                       # step 1
        input[i] = (residual[i] / RMS(residual)) * (weight + 1)       # step 2

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Overwritten with the
        normalized result.
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size). Updated inplace.
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
        If None, will be automatically enabled on Hopper architecture.
    """
    use_pdl = is_arch_support_pdl() if enable_pdl is None else enable_pdl
    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
        input, residual, weight, eps, use_pdl
    )
|
||||
|
||||
|
||||
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
|
||||
assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
|
||||
assert (
|
||||
input.shape[:-1] == output.shape[:-1]
|
||||
), f"{input.shape[:-1]} != {output.shape[:-1]}"
|
||||
assert (
|
||||
input.shape[-1] == 2 * output.shape[-1]
|
||||
), f"{input.shape[-1]} != {2 * output.shape[-1]}"
|
||||
|
||||
|
||||
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """SiLU on the first half of the last dim, multiplied by the second half.

    The kernel uses 16-byte vectorized accesses, so the last dimension's byte
    length must be a multiple of 16. When ``out`` is omitted, a tensor with a
    halved last dimension is allocated on the same device/dtype as ``input``.
    """
    if (input.shape[-1] * input.dtype.itemsize) % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is None:
        out = input.new_empty(input.shape[:-1] + (input.shape[-1] // 2,))
    else:
        _check_shape(input, out)
    torch.ops.sgl_kernel.silu_and_mul.default(out, input)
    return out
|
||||
|
||||
|
||||
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """Tanh-approximated GELU on the first half of the last dim, times the second half.

    The kernel uses 16-byte vectorized accesses, so the last dimension's byte
    length must be a multiple of 16. When ``out`` is omitted, a tensor with a
    halved last dimension is allocated on the same device/dtype as ``input``.
    """
    if (input.shape[-1] * input.dtype.itemsize) % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is None:
        out = input.new_empty(input.shape[:-1] + (input.shape[-1] // 2,))
    else:
        _check_shape(input, out)
    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
    return out
|
||||
|
||||
|
||||
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    """GELU on the first half of the last dim, multiplied by the second half.

    The kernel uses 16-byte vectorized accesses, so the last dimension's byte
    length must be a multiple of 16. When ``out`` is omitted, a tensor with a
    halved last dimension is allocated on the same device/dtype as ``input``.
    """
    if (input.shape[-1] * input.dtype.itemsize) % 16 != 0:
        raise ValueError("The pointers must be multiple of 16 bytes.")
    if out is None:
        out = input.new_empty(input.shape[:-1] + (input.shape[-1] // 2,))
    else:
        _check_shape(input, out)
    torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
    return out
|
||||
|
||||
|
||||
if torch.version.hip is not None:

    def gelu_quick(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        """
        Quick-GELU: y = x * sigmoid(1.702 * x)

        The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
        so the last-dimension byte length must be a multiple of 16 bytes.
        Unlike the *_and_mul kernels, the output has the same shape as the input.
        """
        last_dim_bytes = input.shape[-1] * input.dtype.itemsize
        if last_dim_bytes % 16 != 0:
            raise ValueError(
                f"The last dimension ({input.shape[-1]}) x itemsize "
                f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
            )

        if out is None:
            out = torch.empty_like(input)
        else:
            assert input.shape == out.shape, f"{input.shape} != {out.shape}"

        torch.ops.sgl_kernel.gelu_quick(out, input)
        return out
|
||||
|
||||
|
||||
@dataclass
class FusedSetKVBufferArg:
    """
    Arguments for fusing the KV-cache store into another kernel
    (see ``apply_rope_with_cos_sin_cache_inplace``).

    value : torch.Tensor
        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
    k_buffer : torch.Tensor
        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
    v_buffer : torch.Tensor
        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
    k_scale : Optional[float]
        Scale factor for keys. Must currently be None — scaling is not yet
        supported by the fused rope kernel.
    v_scale : Optional[float]
        Scale factor for values. Must currently be None — scaling is not yet
        supported by the fused rope kernel.
    cache_loc : Optional[torch.Tensor]
        Cache location tensor, used for indexing kv cache. Must be int64.
    """

    # Shapes above are assumed from the consuming kernel's docstring; the
    # None/int64 constraints are asserted in apply_rope_with_cos_sin_cache_inplace.
    value: torch.Tensor
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    k_scale: Optional[float]
    v_scale: Optional[float]
    cache_loc: torch.Tensor
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is inplace applied to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        Per-head dimension; used to view the flat ``(nnz, heads * head_size)``
        tensors as 3-D ``(nnz, heads, head_size)`` for the kernel.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
        Must be float32.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
    fused_set_kv_buffer_arg : FusedSetKVBufferArg
        Fuse the set-kv-buffer operation into this kernel. ``k_scale``/``v_scale``
        must be None and ``cache_loc`` must be int64 (asserted below).
    enable_pdl : Optional[bool]
        Whether to enable programmatic dependent launch. If None, it is enabled
        on supported architectures only for the fused-set-kv-buffer path.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    if enable_pdl is None:
        # the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
        enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)

    if (a := fused_set_kv_buffer_arg) is not None:
        assert a.k_scale is None, "k_scale is not yet supported"
        assert a.v_scale is None, "v_scale is not yet supported"
        assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"

    def _view_3d(x):
        # (nnz, heads * head_size) -> (nnz, heads, head_size)
        return x.view(x.shape[0], -1, head_size)

    # Query/key are passed twice: as kernel input and as (inplace) output.
    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        _view_3d(query),
        _view_3d(key),
        _view_3d(query),
        _view_3d(key),
        cos_sin_cache,
        positions.long(),
        (not is_neox),
        enable_pdl,
        get_cuda_stream(),
        (
            _view_3d(fused_set_kv_buffer_arg.value)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            _view_3d(fused_set_kv_buffer_arg.k_buffer)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            _view_3d(fused_set_kv_buffer_arg.v_buffer)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            fused_set_kv_buffer_arg.cache_loc
            if fused_set_kv_buffer_arg is not None
            else None
        ),
    )
|
||||
|
||||
|
||||
def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Downcast key/value tensors to FP8, writing into the KV-cache buffers.

    Thin wrapper over the ``sgl_kernel.downcast_fp8`` CUDA op, launched on the
    current CUDA stream. ``k``/``v`` are written into ``k_out``/``v_out`` at
    positions derived from ``loc``; the scale tensors are passed through to
    the kernel.

    NOTE(review): the exact semantics of ``mult``/``offset`` (presumably the
    destination index is ``loc * mult + offset``) and of the scale tensors are
    defined by the kernel — confirm against its CUDA source.
    """
    torch.ops.sgl_kernel.downcast_fp8(
        k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
    )
|
||||
|
||||
|
||||
def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor):
    """Copy a host list of ints into ``output`` via the ``copy_to_gpu_no_ce`` op.

    NOTE(review): "no CE" presumably means the transfer avoids the DMA copy
    engine (e.g. is performed by a kernel) — confirm against the op's source.
    """
    torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)
|
||||
|
||||
|
||||
def concat_mla_k(
    k: torch.Tensor,
    k_nope: torch.Tensor,
    k_rope: torch.Tensor,
):
    """Combine the MLA no-position and rotary key parts into ``k`` inplace.

    Thin wrapper over the ``sgl_kernel.concat_mla_k`` op. The expected layouts
    are defined by the kernel — presumably ``k`` is the concatenation of
    ``k_nope`` and ``k_rope`` along the head dimension; TODO confirm.
    """
    torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
|
||||
459
sgl-kernel/python/sgl_kernel/flash_attn.py
Normal file
459
sgl-kernel/python/sgl_kernel/flash_attn.py
Normal file
@@ -0,0 +1,459 @@
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# try:
|
||||
# from sgl_kernel import flash_ops
|
||||
# except:
|
||||
# raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
||||
|
||||
try:
|
||||
from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
|
||||
except ImportError:
|
||||
flash_attn_varlen_func_v4 = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
    """Return True if the FA3 kernels built into sgl-kernel can run on *device*.

    sgl-kernel builds FA3 for sm80/sm86/sm89/sm90a, i.e. A100/A*0/L20/L40/
    L40s/4090-class (sm8x) and Hopper (sm90) GPUs, and requires CUDA >= 12.3.

    Note: FA3 can still fail at runtime without enough shared memory for some
    shapes (e.g. higher hidden_dim or some special cases); sm80/sm87 and
    sm86/sm89 differ mainly in shared memory size, see
    https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    """
    # Bug fix: the original compared version *strings* lexicographically
    # (torch.version.cuda >= "12.3"), which is wrong for e.g. "12.10" < "12.3",
    # and raised TypeError on ROCm/CPU builds where torch.version.cuda is None.
    cuda_version = torch.version.cuda
    if cuda_version is None:
        return False
    major, minor = (int(x) for x in cuda_version.split(".")[:2])
    if (major, minor) < (12, 3):
        return False
    return torch.cuda.get_device_capability(device)[0] in (8, 9)
|
||||
|
||||
|
||||
def maybe_contiguous(x):
    """Return *x* with a contiguous last dimension; None passes through.

    Tensors whose last stride is already 1 are returned unchanged (no copy).
    """
    if x is None:
        return None
    return x if x.stride(-1) == 1 else x.contiguous()
|
||||
|
||||
|
||||
# def flash_attn_with_kvcache(
|
||||
# q,
|
||||
# k_cache,
|
||||
# v_cache,
|
||||
# k=None,
|
||||
# v=None,
|
||||
# qv=None,
|
||||
# rotary_cos=None,
|
||||
# rotary_sin=None,
|
||||
# cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
|
||||
# cache_batch_idx: Optional[torch.Tensor] = None,
|
||||
# cache_leftpad: Optional[torch.Tensor] = None,
|
||||
# page_table: Optional[torch.Tensor] = None,
|
||||
# cu_seqlens_q: Optional[torch.Tensor] = None,
|
||||
# cu_seqlens_k_new: Optional[torch.Tensor] = None,
|
||||
# max_seqlen_q: Optional[int] = None,
|
||||
# rotary_seqlens: Optional[torch.Tensor] = None,
|
||||
# q_descale: Optional[torch.Tensor] = None,
|
||||
# k_descale: Optional[torch.Tensor] = None,
|
||||
# v_descale: Optional[torch.Tensor] = None,
|
||||
# softmax_scale=None,
|
||||
# causal=False,
|
||||
# window_size=(-1, -1), # -1 means infinite context window
|
||||
# softcap=0.0, # 0.0 means deactivated
|
||||
# rotary_interleaved=True,
|
||||
# scheduler_metadata=None,
|
||||
# num_splits=0, # Can be tuned for speed
|
||||
# pack_gqa=None, # Can be tuned for speed
|
||||
# sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||
# return_softmax_lse=False,
|
||||
# sinks=None,
|
||||
# ver=3,
|
||||
# ):
|
||||
# """
|
||||
# If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
||||
# k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
|
||||
# the previous step, and update them with the new keys/values from the current step, and do
|
||||
# attention with the updated cache, all in 1 kernel.
|
||||
|
||||
# If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
|
||||
# For example, the KV cache could be pre-allocated with the max sequence length, and you can use
|
||||
# cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
|
||||
|
||||
# Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
|
||||
# rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
||||
# If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
|
||||
# and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
||||
# If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
|
||||
# indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
|
||||
|
||||
# See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
|
||||
|
||||
# Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
# than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
||||
# For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
# 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
# If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
||||
# For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
||||
# 1 1 1 1 0
|
||||
# 1 1 1 1 1
|
||||
# If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
||||
# 0 0
|
||||
# 0 0
|
||||
# 0 0
|
||||
# 1 0
|
||||
# 1 1
|
||||
# If the row of the mask is all zero, the output will be zero.
|
||||
|
||||
# If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
||||
# will only attend to keys between
|
||||
# [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
||||
|
||||
# Note: Does not support backward pass.
|
||||
|
||||
# Arguments:
|
||||
# q: (batch_size, seqlen, nheads, headdim)
|
||||
# k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
|
||||
# or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
|
||||
# page_block_size must be a multiple of 256.
|
||||
# v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
|
||||
# or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
|
||||
# k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
|
||||
# k with k_cache, starting at the indices specified by cache_seqlens.
|
||||
# v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
|
||||
# qv [optional]: (batch_size, seqlen, nheads, headdim_v)
|
||||
# rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
|
||||
# to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
|
||||
# rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
|
||||
# cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
|
||||
# KV cache.
|
||||
# cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
|
||||
# If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
|
||||
# If the indices are not distinct, and k and v are provided, the values updated in the cache
|
||||
# might come from any of the duplicate indices.
|
||||
# cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
|
||||
# page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
|
||||
# softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
# Default to 1 / sqrt(headdim).
|
||||
# causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
# window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
||||
# softcap: float. Anything > 0 activates softcapping attention.
|
||||
# rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
|
||||
# If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
|
||||
# rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
|
||||
# (i.e. GPT-NeoX style).
|
||||
# num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
|
||||
# If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
|
||||
# to automatically determine the number of splits.
|
||||
# Don't change this unless you know what you are doing.
|
||||
# return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
|
||||
|
||||
# Return:
|
||||
# out: (batch_size, seqlen, nheads, headdim).
|
||||
# softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
|
||||
# logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
# normalization factor).
|
||||
# """
|
||||
# if ver == 4:
|
||||
# raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
|
||||
|
||||
# assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||
# assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||
# if softmax_scale is None:
|
||||
# softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
|
||||
# -0.5
|
||||
# )
|
||||
# if cache_seqlens is not None and isinstance(cache_seqlens, int):
|
||||
# cache_seqlens = torch.full(
|
||||
# (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
|
||||
# )
|
||||
# cache_seqlens = maybe_contiguous(cache_seqlens)
|
||||
|
||||
# q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
|
||||
# v_cache = (
|
||||
# v_cache.contiguous()
|
||||
# if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
|
||||
# else v_cache
|
||||
# )
|
||||
# cu_seqlens_q, cu_seqlens_k_new = [
|
||||
# maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
|
||||
# ]
|
||||
# page_table, cache_batch_idx, cache_leftpad = [
|
||||
# maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
|
||||
# ]
|
||||
# rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
|
||||
# rotary_seqlens = maybe_contiguous(rotary_seqlens)
|
||||
|
||||
# if hasattr(torch.version, 'hip') and torch.version.hip is not None:
|
||||
# # HIP环境回退
|
||||
# from flash_attn import flash_attn_with_kvcache as fa_with_kv
|
||||
# out, softmax_lse, *rest = fa_with_kv(
|
||||
# q, k, v, k_cache, v_cache, cache_seqlens, cache_batch_idx,
|
||||
# block_tables, softmax_scale, causal, alibi_slopes, out
|
||||
# )
|
||||
# else:
|
||||
# out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
|
||||
# q,
|
||||
# k_cache,
|
||||
# v_cache,
|
||||
# k,
|
||||
# v,
|
||||
# qv,
|
||||
# None, # out
|
||||
# cu_seqlens_q,
|
||||
# None, # cu_seqlens_k
|
||||
# cu_seqlens_k_new,
|
||||
# None, # seqused_q
|
||||
# cache_seqlens,
|
||||
# max_seqlen_q,
|
||||
# None, # max_seqlen_k
|
||||
# page_table,
|
||||
# cache_batch_idx,
|
||||
# cache_leftpad,
|
||||
# rotary_cos,
|
||||
# rotary_sin,
|
||||
# rotary_seqlens,
|
||||
# q_descale,
|
||||
# k_descale,
|
||||
# v_descale,
|
||||
# softmax_scale,
|
||||
# causal,
|
||||
# window_size[0],
|
||||
# window_size[1],
|
||||
# softcap,
|
||||
# rotary_interleaved,
|
||||
# scheduler_metadata,
|
||||
# num_splits,
|
||||
# pack_gqa,
|
||||
# sm_margin,
|
||||
# sinks,
|
||||
# )
|
||||
# return (out, softmax_lse) if return_softmax_lse else out
|
||||
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    """FlashAttention forward over an (optionally paged) KV cache.

    If ``k``/``v`` are given, ``k_cache``/``v_cache`` are updated *inplace*
    starting at ``cache_seqlens``, and attention is computed against the
    updated cache — all in one kernel. Rotary embedding is applied when
    ``rotary_cos``/``rotary_sin`` are provided. Supports MQA/GQA, paged KV via
    ``page_table``, sliding-window attention via ``window_size``, and softcap.
    Only FA3 (``ver=3``) is implemented; ``ver=4`` raises.

    Returns ``out`` of shape (batch_size, seqlen, nheads, headdim), or
    ``(out, softmax_lse)`` when ``return_softmax_lse`` is True. See the
    commented-out docstring above this function for the full per-argument
    description. Does not support a backward pass.
    """
    if ver == 4:
        raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")

    # HIP environment detection and fallback (currently disabled)
    # if hasattr(torch.version, 'hip') and torch.version.hip is not None:
    #     # Simple PyTorch fallback, handling the actual tensor shapes
    #     # q: [1, 4, 256], k_cache: [411528, 1, 1, 256], v_cache: [411528, 1, 1, 256]

    #     if softmax_scale is None:
    #         softmax_scale = (q.shape[-1]) ** (-0.5)

    #     # Reshape to match the attention computation
    #     q_reshaped = q.unsqueeze(1)  # [1, 1, 4, 256]
    #     k_reshaped = k_cache.squeeze(1).squeeze(1)  # [411528, 256]
    #     v_reshaped = v_cache.squeeze(1).squeeze(1)  # [411528, 256]

    #     # Plain dot-product attention
    #     scores = torch.matmul(q, k_reshaped.T) * softmax_scale  # [1, 4, 411528]
    #     attn_weights = torch.softmax(scores, dim=-1)
    #     out = torch.matmul(attn_weights, v_reshaped)  # [1, 4, 256]

    #     if return_softmax_lse:
    #         softmax_lse = torch.zeros(1, 4, 1, device=q.device)
    #         return out, softmax_lse
    #     return out

    # Original sgl_kernel implementation
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    if softmax_scale is None:
        # Default scale: 1/sqrt(headdim), accounting for the extra qv dim if present.
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)

    q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
    # v_cache may be transposed in either of its two innermost dims; only copy
    # when neither is contiguous.
    v_cache = (
        v_cache.contiguous()
        if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1
        else v_cache
    )
    cu_seqlens_q, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)
    ]
    page_table, cache_batch_idx, cache_leftpad = [
        maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    rotary_seqlens = maybe_contiguous(rotary_seqlens)

    # NOTE: argument order below must exactly match the sgl_kernel.fwd schema.
    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
        k_cache,
        v_cache,
        k,
        v,
        qv,
        None,  # out
        cu_seqlens_q,
        None,  # cu_seqlens_k
        cu_seqlens_k_new,
        None,  # seqused_q
        cache_seqlens,
        max_seqlen_q,
        None,  # max_seqlen_k
        page_table,
        cache_batch_idx,
        cache_leftpad,
        rotary_cos,
        rotary_sin,
        rotary_seqlens,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
        sinks,
    )
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    seqused_q=None,
    seqused_k=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    """Variable-length (ragged batch) FlashAttention forward.

    ``q``/``k``/``v`` are packed along the token dimension; ``cu_seqlens_q``/
    ``cu_seqlens_k`` hold cumulative sequence-length offsets delimiting each
    sequence in the batch. With ``ver=4`` this dispatches to the FA4
    implementation (when importable); otherwise it calls the FA3
    ``sgl_kernel.fwd`` op, which requires a supported GPU per
    ``is_fa3_supported``. ``softmax_scale`` defaults to ``1/sqrt(headdim)``.

    Returns ``out``, or ``(out, softmax_lse, ...)`` when
    ``return_softmax_lse`` is True.
    """
    if ver == 4:
        assert (
            flash_attn_varlen_func_v4 is not None
        ), "FA4 is not available, please check your installation."
        # Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
        if window_size == (-1, -1):
            window_size = (None, None)
        return flash_attn_varlen_func_v4(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            # max_seqlen_q,
            # max_seqlen_k,
            seqused_q=seqused_q,
            seqused_k=seqused_k,
            softmax_scale=softmax_scale,
            causal=causal,
            # qv=qv,
            # q_descale=q_descale,
            # k_descale=k_descale,
            # v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
            # num_splits=num_splits,
            pack_gqa=pack_gqa,
            # sm_margin=sm_margin,
            return_softmax_lse=return_softmax_lse,
            learnable_sink=sinks,
        )

    if not is_fa3_supported():
        # Bug fix: the message previously claimed sm90-only, but
        # is_fa3_supported (and the build) also accept sm80/sm86/sm89.
        raise NotImplementedError(
            "flash_attn at sgl-kernel is only supported on sm80 and above"
        )

    if softmax_scale is None:
        softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (
            -0.5
        )

    # NOTE: argument order below must exactly match the sgl_kernel.fwd schema.
    out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
        q,
        k,
        v,
        None,  # k_new
        None,  # v_new
        qv,  # qv
        None,  # out
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # cu_seqlens_k_new
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        None,  # page_table,
        None,  # kv_batch_idx
        None,  # leftpad_k
        None,  # rotary cos
        None,  # rotary sin
        None,  # seqlens_rotary
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        softcap,
        is_rotary_interleaved=False,
        scheduler_metadata=None,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
        sinks=sinks,
    )

    return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||
225
sgl-kernel/python/sgl_kernel/fused_moe.py
Normal file
225
sgl-kernel/python/sgl_kernel/fused_moe.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
|
||||
def get_scalar_type(num_bits: int, has_zp: bool):
    """Map (num_bits, has_zp) to the marlin scalar type for quantized weights.

    With a zero point only 4-bit is supported (plain uint4); without one, the
    bias-shifted types uint4b8 / uint8b128 are used for 4 / 8 bits.
    """
    from sgl_kernel.scalar_type import scalar_types

    if not has_zp:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
    assert num_bits == 4
    return scalar_types.uint4
|
||||
|
||||
|
||||
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    workspace: Optional[torch.Tensor] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    Marlin-quantized weights, w1 and w2, and a top-k gating mechanism
    (gemm1 -> silu_and_mul -> gemm2 -> weighted sum over the top-k experts).

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights (Marlin-packed).
    - w2 (torch.Tensor): The second set of expert weights (Marlin-packed).
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (torch.Tensor): The output of the gating operation
      (before softmax); only used here to sanity-check the token count.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of top-k elements.
    - global_num_experts (int): Total expert count across EP ranks; -1 means
      use the local expert count w1.shape[0].
    - expert_map (Optional[torch.Tensor]): Expert-parallel mapping; when set,
      the second GEMM's output buffer is zeroed first.
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input permutation.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - workspace (Optional[torch.Tensor]): Scratch buffer; allocated here when None.
    - num_bits (int): The number of bits in expert weights quantization (4 or 8).
    - is_k_full (bool): Whether the full K dimension is present (act_order).
    - inplace (bool): Write the result into hidden_states instead of a new tensor.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Delay the import to avoid circular dependency
    from sglang.srt.layers.moe.fused_moe_triton import (
        moe_align_block_size,
        try_get_optimal_moe_config,
    )

    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    # Marlin packs 16 K-elements per packed column of w1.
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w1.shape[0]  # number of (local) experts
    N = w2.shape[1] * 16  # intermediate size per expert
    topk = topk_ids.shape[1]

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    if global_num_experts == -1:
        global_num_experts = E
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, block_size_m, global_num_experts
    )

    if workspace is None:
        # Workspace size scales with output tiles, capped at 4 ints per SM.
        max_workspace_size = (max(2 * N, K) // 64) * (
            sorted_token_ids.size(0) // block_size_m
        )
        device = hidden_states.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        max_workspace_size = min(max_workspace_size, sms * 4)
        workspace = torch.zeros(
            max_workspace_size, dtype=torch.int, device=device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
    scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)

    # intermediate_cache13 is a single flat buffer shared by the gemm1 output
    # (M*topk, 2N) and the gemm2 output (M*topk, K); they are never live at
    # the same time, so aliasing saves memory.
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache13 = torch.empty(
        (M * topk_ids.shape[1] * max(2 * N, K),),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache1 = intermediate_cache13[: M * topk_ids.shape[1] * 2 * N]
    intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
    intermediate_cache3 = intermediate_cache13[: M * topk_ids.shape[1] * K]
    intermediate_cache3 = intermediate_cache3.view(-1, K)

    use_atomic_add = (
        hidden_states.dtype == torch.half
        or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
    )

    # GEMM 1: hidden_states @ w1 -> (M*topk, 2N), gate and up projections fused.
    intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        hidden_states,
        intermediate_cache1,
        w1,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=topk,
        mul_topk_weights=False,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type1.id,
        size_m=M,
        size_n=2 * N,
        size_k=K,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    )

    silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)

    if expert_map is not None:
        # With expert parallelism some rows may not be written; zero them first.
        intermediate_cache3.zero_()

    # GEMM 2: activations @ w2 -> (M*topk, K), top-k weights applied here.
    intermediate_cache3 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        intermediate_cache2,
        intermediate_cache3,
        w2,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=block_size_m,
        top_k=1,
        mul_topk_weights=True,
        is_ep=expert_map is not None,
        b_q_type_id=scalar_type2.id,
        size_m=M * topk,
        size_n=K,
        size_k=N,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=True,
        is_zp_float=False,
    ).view(-1, topk, K)

    # Reduce over the top-k expert outputs per token.
    output = hidden_states if inplace else torch.empty_like(hidden_states)
    return torch.sum(
        intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=output
    )
|
||||
|
||||
|
||||
def fused_marlin_moe_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
g_idx1: Optional[torch.Tensor] = None,
|
||||
g_idx2: Optional[torch.Tensor] = None,
|
||||
sort_indices1: Optional[torch.Tensor] = None,
|
||||
sort_indices2: Optional[torch.Tensor] = None,
|
||||
w1_zeros: Optional[torch.Tensor] = None,
|
||||
w2_zeros: Optional[torch.Tensor] = None,
|
||||
num_bits: int = 8,
|
||||
is_k_full: bool = True,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
550
sgl-kernel/python/sgl_kernel/gemm.py
Normal file
550
sgl-kernel/python/sgl_kernel/gemm.py
Normal file
@@ -0,0 +1,550 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sgl_kernel.scalar_type import ScalarType
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
    """Dequantize AWQ-packed weights via the sgl_kernel custom op.

    NOTE(review): the declared ByteTensor return annotation mirrors the op
    signature; the actual output dtype is decided by the kernel — confirm.
    """
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """INT8 scaled matmul producing `out_dtype`; thin wrapper over the custom op."""
    return torch.ops.sgl_kernel.int8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """FP8 matmul with blockwise scales; thin wrapper over the custom op."""
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )
|
||||
|
||||
|
||||
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """FP8 scaled matmul with optional bias; thin wrapper over the custom op."""
    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )
|
||||
|
||||
|
||||
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Run the FP8 batched matmul kernel, writing the result into D in place."""
    # The op takes the current cuBLAS handle and CUDA stream explicitly.
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.bmm_fp8.default(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
        get_cuda_stream(),
    )
|
||||
|
||||
|
||||
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched FP8 matmul with per-tensor scales, producing `dtype` output.

    Allocates a (batch, M, N) result when `out` is not supplied and routes
    the work through the cuBLAS-backed kernel with a cached workspace buffer.
    """
    if out is None:
        batch, rows = A.shape[0], A.shape[1]
        cols = B.shape[2]
        out = torch.empty((batch, rows, cols), device=A.device, dtype=dtype)
    # 32 MiB workspace, cached per (name, device) by _get_cache_buf.
    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out
|
||||
|
||||
|
||||
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fused A-projection GEMM (DeepSeek-V3); allocates `output` when absent.

    The kernel writes into `output`, which is also returned.
    """
    if output is None:
        output = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]),
            device=mat_a.device,
            dtype=mat_a.dtype,
        )
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool,
) -> None:
    """Per-token-group FP8 quantization; writes quantized values into
    `output_q` and scales into `output_s` in place."""
    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
    )
|
||||
|
||||
|
||||
def sgl_per_token_group_quant_int8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    int8_min: float,
    int8_max: float,
) -> None:
    """Per-token-group INT8 quantization; fills `output_q`/`output_s` in place."""
    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
        input, output_q, output_s, group_size, eps, int8_min, int8_max
    )
|
||||
|
||||
|
||||
def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    """Per-tensor FP8 quantization (static or dynamic scale); in-place outputs."""
    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        input, output_q, output_s, is_static
    )
|
||||
|
||||
|
||||
def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Per-token FP8 quantization; fills `output_q` and `output_s` in place."""
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)
|
||||
|
||||
|
||||
def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """FP4 matmul with per-block scales via CUTLASS; returns an (m, n) tensor.

    NOTE(review): n is taken from b.shape[0], so b is presumably stored in a
    transposed (n, k)-style layout — confirm against the kernel.
    """
    assert a.ndim == 2 and b.ndim == 2
    m, n = a.shape[0], b.shape[0]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
        out, a, b, block_scale_a, block_scale_b, alpha
    )
    return out
|
||||
|
||||
|
||||
def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
        two values are packed into a uint8, and float8_e4m3 scaling factors
        in a swizzled layout.
    """
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    # Flatten everything but the last dim into rows.
    other_dims = 1 if input.ndim == 1 else -1
    input = input.reshape(other_dims, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
    assert input.dtype in (
        torch.float16,
        torch.bfloat16,
    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

    # Two fp4 values will be packed into an uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    # We use the rounded values to store the swizzled values. Then, the scaling
    # factors in float8_e4m3fn are packed into an int32 for every 4 values.
    # rounded_m: m rounded up to a multiple of 128; rounded_n: scale count
    # rounded up to a multiple of 4.
    rounded_m = ((m + 128 - 1) // 128) * 128
    scale_n = n // block_size
    rounded_n = ((scale_n + 4 - 1) // 4) * 4
    # padded part should be zeroed out
    if rounded_n > scale_n:
        output_scale = torch.zeros(
            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
        )
    else:
        output_scale = torch.empty(
            (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
        )

    torch.ops.sgl_kernel.scaled_fp4_quant.default(
        output, input, output_scale, input_global_scale
    )
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
|
||||
|
||||
|
||||
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 GEMM with per-channel weight scales.

    Allocates an fp16 output of shape (num_tokens, out_features) when
    `out_feats` is not provided; the kernel writes into it in place.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
        in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
    )
    return out_feats
|
||||
|
||||
|
||||
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 GEMM with per-group weight scales/zeros.

    Allocates an fp16 output of shape (num_tokens, out_features) when
    `out_feats` is not provided; the kernel writes into it in place.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats
|
||||
|
||||
|
||||
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Router GEMM for DeepSeek-V3-style MoE gating.

    Allocates a (num_tokens, num_experts) tensor of `out_dtype` and lets the
    kernel fill it from `hidden_states` and `router_weights`.
    """
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device=hidden_states.device,
        dtype=out_dtype,
    )
    # Consistency fix: call the explicit `.default` overload like every other
    # sgl_kernel op wrapper in this file (same dispatch target).
    torch.ops.sgl_kernel.dsv3_router_gemm.default(
        output,
        hidden_states,
        router_weights,
    )
    return output
|
||||
|
||||
|
||||
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """Gather rows of `input_tensor` into a fresh tensor following `dst2src_map`."""
    gathered = torch.empty(
        output_tensor_shape,
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, gathered)
    return gathered
|
||||
|
||||
|
||||
def scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape (l, m, k)
            l is number of groups, m is number of tokens per group, k is number of features.
        input_global_scale: A scalar scaling factor for the entire tensor, with
            shape (l,).
        mask: The mask tensor, with shape (l,).
    Outputs:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
            layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
            an uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
            but the physical layout is (l, rm, rk, 32, 4, 4).
    Note:
        For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
        `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
        required by the NVIDIA Blackwell MMA operations.
    """
    device = input_tensor.device
    l, m, k = input_tensor.shape
    sf_vec_size = 16  # one shared scale per 16 consecutive elements
    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."

    scale_k = k // sf_vec_size
    padded_k = (scale_k + (4 - 1)) // 4 * 4
    padded_k_int32 = padded_k // 4  # 4 fp8 scales are packed per int32
    padded_m = (m + (128 - 1)) // 128 * 128
    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
    output_scales = torch.empty(
        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
    )

    # Same kernel as the silu_and_mul variant, with the activation disabled.
    torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
        output.view(l * m, k // 2),
        output_scales.view(l * padded_m, padded_k_int32),
        input_tensor.view(l * m, k),
        input_global_scale,
        mask,
        use_silu_and_mul=False,
    )
    # The physical layout of the output is (l, m, k // 2), but we want to return a
    # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
    output = output.permute(1, 2, 0)
    # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
    # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logical
    # layout is (32, 4, rm, 4, rk, l).
    output_scales = output_scales.view(torch.float8_e4m3fn).view(
        l, padded_m // 128, padded_k // 4, 32, 4, 4
    )
    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
    return output, output_scales
|
||||
|
||||
|
||||
def silu_and_mul_scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Apply silu_and_mul to the input and quantize the result to FP4, for
    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
    Args:
        input_tensor: The input tensor to be activated and quantized to FP4,
            with shape (l, m, k * 2): l is number of groups, m is number of
            tokens per group, k is number of output features (the 2x factor
            holds the fused gate/up halves consumed by silu_and_mul).
        input_global_scale: A scalar scaling factor for the entire tensor, with
            shape (l,).
        mask: The mask tensor, with shape (l,)
    Outputs:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
            layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
            an uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
            but the physical layout is (l, rm, rk, 32, 4, 4).
    Note:
        For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
        `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
        required by the NVIDIA Blackwell MMA operations.
    """
    device = input_tensor.device
    l, m, k_by_2 = input_tensor.shape
    k = k_by_2 // 2  # silu_and_mul halves the feature dimension
    sf_vec_size = 16  # one shared scale per 16 consecutive elements
    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."

    scale_k = k // sf_vec_size
    padded_k = (scale_k + (4 - 1)) // 4 * 4
    padded_k_int32 = padded_k // 4  # 4 fp8 scales are packed per int32
    padded_m = (m + (128 - 1)) // 128 * 128
    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
    output_scales = torch.empty(
        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
    )

    torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default(
        output.view(l * m, k // 2),
        output_scales.view(l * padded_m, padded_k_int32),
        input_tensor.view(l * m, k_by_2),
        input_global_scale,
        mask,
        use_silu_and_mul=True,
    )
    # The physical layout of the output is (l, m, k // 2), but we want to return a
    # logical layout (m, k // 2, l) required by the flashinfer masked group gemm.
    output = output.permute(1, 2, 0)
    # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a
    # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logical
    # layout is (32, 4, rm, 4, rk, l).
    output_scales = output_scales.view(torch.float8_e4m3fn).view(
        l, padded_m // 128, padded_k // 4, 32, 4, 4
    )
    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
    return output, output_scales
|
||||
|
||||
|
||||
def scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
    expert_map: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    packed MoE Inputs.
    Args:
        input_tensor: The 2-D input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor.
        expert_offsets: The expert offsets tensor
        blockscale_offsets: The blockscale offsets tensor
        topk: Number of experts selected per token.
        expert_map: Optional row-gather map; when given, the input rows are
            first shuffled/expanded to (m * topk, k).
    Outputs:
        output: The quantized tensor in FP4 (two values packed per uint8)
        output_scales: The blockscale tensor in FP8-E4M3
    """
    assert (
        input_tensor.ndim == 2
    ), f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    if expert_map is not None:
        (m, k) = input_tensor.shape
        output_tensor_shape = (m * topk, k)
        input_tensor = shuffle_rows(input_tensor, expert_map, output_tensor_shape)
    m_numtopk, k = input_tensor.shape
    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    import os

    # Bug fix: os.environ.get returns a *string* when the variable is set,
    # which would turn `MAX_TOKENS_PER_EXPERT * topk` into string repetition
    # and break the tensor allocations below; normalize to int.
    MAX_TOKENS_PER_EXPERT = int(
        os.environ.get("MODELOPT_MAX_TOKENS_PER_EXPERT", 65536)
    )
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" MODELOPT_MAX_TOKENS_PER_EXPERT to set this value."
    )
    scales_k = k // 16
    # padded_k counts int32 words: 4 fp8 block scales are packed per int32.
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    # Padded part should be zeroed out whenever scales_k is not a multiple
    # of 4. Bug fix: the previous check `padded_k > scales_k` (int32 count
    # vs. fp8 count) was effectively never true, leaving the padding
    # uninitialized — compare in fp8-scale units like scaled_fp4_quant does.
    if padded_k * 4 > scales_k:
        output_scales = torch.zeros(
            MAX_TOKENS_PER_EXPERT * topk,
            padded_k,
            dtype=torch.int32,
            device=input_tensor.device,
        )
    else:
        output_scales = torch.empty(
            MAX_TOKENS_PER_EXPERT * topk,
            padded_k,
            dtype=torch.int32,
            device=input_tensor.device,
        )
    torch.ops.sgl_kernel.scaled_fp4_experts_quant.default(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales
|
||||
|
||||
|
||||
# GPTQ kernels
|
||||
def gptq_marlin_gemm(
    a: torch.Tensor,
    c: Optional[torch.Tensor],
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    global_scale: Optional[torch.Tensor],
    b_zeros: Optional[torch.Tensor],
    g_idx: Optional[torch.Tensor],
    perm: Optional[torch.Tensor],
    workspace: torch.Tensor,
    b_q_type: ScalarType,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool = True,
    use_atomic_add: bool = False,
    use_fp32_reduce: bool = False,
    is_zp_float: bool = False,
) -> torch.Tensor:
    """GPTQ/Marlin mixed-precision GEMM wrapper.

    Forwards all arguments to the sgl_kernel op; the ScalarType is passed by
    its integer `id` since custom ops take plain scalars.
    """
    return torch.ops.sgl_kernel.gptq_marlin_gemm(
        a,
        c,
        b_q_weight,
        b_scales,
        global_scale,
        b_zeros,
        g_idx,
        perm,
        workspace,
        b_q_type.id,
        size_m,
        size_n,
        size_k,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
        is_zp_float,
    )
|
||||
|
||||
|
||||
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """GPTQ GEMM wrapper; `bit` is the weight quantization bit-width."""
    return torch.ops.sgl_kernel.gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
    )
|
||||
|
||||
|
||||
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Shuffle GPTQ-quantized weights in place according to `q_perm`.

    Fix: the original spelled the op path as ``torch.torch.ops...``, relying
    on the accidental self-referential ``torch.torch`` alias; use the
    canonical ``torch.ops`` path (identical dispatch target).
    """
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
|
||||
15
sgl-kernel/python/sgl_kernel/grammar.py
Normal file
15
sgl-kernel/python/sgl_kernel/grammar.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def apply_token_bitmask_inplace_cuda(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
    """Apply a packed token bitmask to `logits` in place via the CUDA kernel.

    `indices` may be a Python list (converted to an int32 tensor) or a tensor
    (moved to the logits device); None applies the mask to all rows.
    """
    if isinstance(indices, list):
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
    elif indices is not None:
        # Already a tensor — just ensure it lives on the same device as logits.
        indices = indices.to(logits.device)
    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(logits, bitmask, indices)
|
||||
243
sgl-kernel/python/sgl_kernel/kvcacheio.py
Normal file
243
sgl-kernel/python/sgl_kernel/kvcacheio.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_hip() -> bool:
    """Return True when this PyTorch build targets ROCm/HIP."""
    hip_version = torch.version.hip
    return hip_version is not None
|
||||
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
def transfer_kv_per_layer(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Copy one layer's K/V cache entries from src to dst at the given indices.

    `item_size` is the per-entry element count; `block_quota` and
    `num_warps_per_block` tune the kernel launch (HIP uses narrower warps).
    """
    torch.ops.sgl_kernel.transfer_kv_per_layer(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_per_layer_pf_lf(
    src_k: torch.Tensor,
    dst_k: torch.Tensor,
    src_v: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Per-layer K/V copy from a page-first source layout to a layer-first
    destination (pf -> lf); `layer_id` selects the layer inside the source."""
    torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
        src_k,
        dst_k,
        src_v,
        dst_v,
        src_indices,
        dst_indices,
        layer_id,
        item_size,
        src_layout_dim,
        block_quota,
        num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_all_layer(
    src_k_layers: torch.Tensor,
    dst_k_layers: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Copy K/V cache entries across all `num_layers` layers in one launch.

    The *_layers arguments hold per-layer buffer pointers/views.
    """
    torch.ops.sgl_kernel.transfer_kv_all_layer(
        src_k_layers,
        dst_k_layers,
        src_v_layers,
        dst_v_layers,
        src_indices,
        dst_indices,
        item_size,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_all_layer_lf_pf(
    src_k_layers: torch.Tensor,
    dst_k: torch.Tensor,
    src_v_layers: torch.Tensor,
    dst_v: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """All-layer K/V copy from a layer-first source into a page-first
    destination (lf -> pf); `dst_layout_dim` is the destination layer stride."""
    torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
        src_k_layers,
        dst_k,
        src_v_layers,
        dst_v,
        src_indices,
        dst_indices,
        item_size,
        dst_layout_dim,
        num_layers,
        block_quota,
        num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_direct(
    src_layers: List[torch.Tensor],
    dst_layers: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Directly copy KV pages between per-layer tensor lists at page granularity."""
    torch.ops.sgl_kernel.transfer_kv_direct(
        src_layers, dst_layers, src_indices, dst_indices, page_size
    )
|
||||
|
||||
|
||||
def transfer_kv_per_layer_direct_pf_lf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    page_size: int,
):
    """Direct per-layer page copy from a page-first source to a layer-first
    destination; `layer_id` selects the layer within the source pages."""
    torch.ops.sgl_kernel.transfer_kv_per_layer_direct_pf_lf(
        src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size
    )
|
||||
|
||||
|
||||
def transfer_kv_all_layer_direct_lf_pf(
    src_ptrs: List[torch.Tensor],
    dst_ptrs: List[torch.Tensor],
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    page_size: int,
):
    """Direct all-layer page copy from layer-first source buffers into a
    page-first destination."""
    torch.ops.sgl_kernel.transfer_kv_all_layer_direct_lf_pf(
        src_ptrs, dst_ptrs, src_indices, dst_indices, page_size
    )
|
||||
|
||||
|
||||
def transfer_kv_per_layer_mla(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Per-layer cache copy for MLA models (single combined buffer instead of
    separate K and V)."""
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
        src,
        dst,
        src_indices,
        dst_indices,
        item_size,
        block_quota,
        num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_per_layer_mla_pf_lf(
    src: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    layer_id: int,
    item_size: int,
    src_layout_dim: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """Single-layer MLA KV transfer between layouts (pf -> lf per the suffix —
    confirm against the C++ kernel).

    ``block_quota`` and ``num_warps_per_block`` tune the launch configuration;
    the warp default follows the platform (16 on HIP, 32 on CUDA).
    """
    torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
        src, dst, src_indices, dst_indices,
        layer_id, item_size, src_layout_dim,
        block_quota, num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_all_layer_mla(
    src_layers: torch.Tensor,
    dst_layers: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """All-layer MLA KV transfer; thin wrapper over the fused kernel.

    ``block_quota`` and ``num_warps_per_block`` tune the launch configuration;
    the warp default follows the platform (16 on HIP, 32 on CUDA).
    """
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
        src_layers, dst_layers, src_indices, dst_indices,
        item_size, num_layers, block_quota, num_warps_per_block,
    )
|
||||
|
||||
|
||||
def transfer_kv_all_layer_mla_lf_pf(
    src_layers: torch.Tensor,
    dst: torch.Tensor,
    src_indices: torch.Tensor,
    dst_indices: torch.Tensor,
    item_size: int,
    dst_layout_dim: int,
    num_layers: int,
    block_quota: int = 2,
    num_warps_per_block: int = 16 if _is_hip else 32,
):
    """All-layer MLA KV transfer between layouts (lf -> pf per the suffix —
    confirm against the C++ kernel).

    ``block_quota`` and ``num_warps_per_block`` tune the launch configuration;
    the warp default follows the platform (16 on HIP, 32 on CUDA).
    """
    torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
        src_layers, dst, src_indices, dst_indices,
        item_size, dst_layout_dim, num_layers,
        block_quota, num_warps_per_block,
    )
|
||||
50
sgl-kernel/python/sgl_kernel/mamba.py
Normal file
50
sgl-kernel/python/sgl_kernel/mamba.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# mamba
|
||||
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    """Causal 1-D convolution forward (Mamba prefill path).

    Thin wrapper: every argument is forwarded to the fused sgl_kernel op in
    declaration order. Outputs/states are presumably updated in place by the
    kernel (nothing is returned) — confirm against the C++ signature.
    """
    torch.ops.sgl_kernel.causal_conv1d_fwd(
        x, weight, bias_, conv_states,
        query_start_loc, cache_indices, has_initial_state,
        silu_activation, pad_slot_id,
    )
|
||||
|
||||
|
||||
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Causal 1-D convolution single-step update (Mamba decode path).

    Thin wrapper: every argument is forwarded to the fused sgl_kernel op in
    declaration order. ``conv_state`` is presumably updated in place (nothing
    is returned) — confirm against the C++ signature.
    """
    torch.ops.sgl_kernel.causal_conv1d_update(
        x, conv_state, weight, bias_,
        silu_activation, cache_seqlens, conv_state_indices,
        pad_slot_id,
    )
|
||||
44
sgl-kernel/python/sgl_kernel/marlin.py
Normal file
44
sgl-kernel/python/sgl_kernel/marlin.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
|
||||
|
||||
def gptq_marlin_repack(
    b_q_weight,
    perm,
    size_k,
    size_n,
    num_bits,
) -> torch.Tensor:
    """Repack a GPTQ-quantized weight into the Marlin layout.

    Delegates entirely to the sgl_kernel op and returns its result.
    """
    repacked = torch.ops.sgl_kernel.gptq_marlin_repack(
        b_q_weight, perm, size_k, size_n, num_bits
    )
    return repacked
|
||||
|
||||
|
||||
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack an AWQ-quantized weight into the Marlin layout.

    Delegates entirely to the sgl_kernel op and returns its result.
    """
    repacked = torch.ops.sgl_kernel.awq_marlin_repack(
        b_q_weight, size_k, size_n, num_bits
    )
    return repacked
|
||||
|
||||
|
||||
def awq_marlin_moe_repack(
    b_q_weight: torch.Tensor,
    perm: torch.Tensor,
    size_k: int,
    size_n: int,
    num_bits: int,
) -> torch.Tensor:
    """Repack AWQ MoE expert weights into the Marlin layout, expert by expert.

    ``b_q_weight`` is indexed on its first (expert) dimension; each expert is
    repacked independently with the single-weight kernel and written into a
    preallocated output of shape
    ``(num_experts, size_k // 16, size_n * (num_bits // 2))``.
    """
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    # Preallocate in the source dtype/device; filled one expert at a time.
    output = torch.empty(
        (num_experts, size_k // 16, size_n * (num_bits // 2)),
        device=b_q_weight.device,
        dtype=b_q_weight.dtype,
    )
    for expert_id in range(num_experts):
        output[expert_id] = torch.ops.sgl_kernel.awq_marlin_repack(
            b_q_weight[expert_id], size_k, size_n, num_bits
        )
    return output
|
||||
18
sgl-kernel/python/sgl_kernel/memory.py
Normal file
18
sgl-kernel/python/sgl_kernel/memory.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def set_kv_buffer_kernel(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    loc: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    fallback: bool = False,
):
    """Scatter ``k``/``v`` into the KV caches at rows ``loc``.

    Tries the fused ``store_kv_cache`` op first; when ``fallback`` is set, or
    the op raises ``RuntimeError``, falls back to plain tensor indexing.
    """
    if fallback:
        # Forced torch path (mirrors the kernel's failure fallback).
        k_cache[loc] = k
        v_cache[loc] = v
        return
    try:
        torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
    except RuntimeError:
        # Kernel unavailable or rejected the inputs — best-effort torch path.
        k_cache[loc] = k
        v_cache[loc] = v
||||
197
sgl-kernel/python/sgl_kernel/moe.py
Executable file
197
sgl-kernel/python/sgl_kernel/moe.py
Executable file
@@ -0,0 +1,197 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    cumsum_buffer,
    pad_sorted_token_ids=False,
):
    """Align routed tokens to ``block_size`` boundaries for the MoE GEMM.

    The output buffers (``sorted_token_ids``, ``experts_ids``,
    ``num_tokens_post_pad``, ``cumsum_buffer``) are presumably filled in
    place by the kernel — nothing is returned.
    """
    torch.ops.sgl_kernel.moe_align_block_size.default(
        topk_ids, num_experts, block_size,
        sorted_token_ids, experts_ids, num_tokens_post_pad,
        cumsum_buffer, pad_sorted_token_ids,
    )
|
||||
|
||||
|
||||
def topk_softmax(
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    # FIX: was annotated ``float`` — this is the router-logits tensor the
    # kernel consumes, not a scalar.
    gating_output: torch.Tensor,
    renormalize: bool = False,
) -> None:
    """Softmax over the gating output, writing top-k results in place.

    Args:
        topk_weights: output tensor; filled with the top-k probabilities.
        topk_ids: output tensor; filled with the top-k expert indices.
        gating_output: router logits tensor (shape presumably
            ``(num_tokens, num_experts)`` — confirm against the kernel).
        renormalize: if True, the kernel renormalizes the selected weights.
    """
    torch.ops.sgl_kernel.topk_softmax.default(
        topk_weights, topk_ids, gating_output, renormalize
    )
|
||||
|
||||
|
||||
def moe_fused_gate(
    input_tensor,
    bias,
    num_expert_group,
    topk_group,
    topk,
    num_fused_shared_experts=0,
    routed_scaling_factor=0,
    apply_routed_scaling_factor_on_output=False,
):
    """Hierarchical (two-level) top-k expert selection, fused in one kernel.

    Experts are split into ``num_expert_group`` groups. Each group is scored
    by the sum of its top-2 expert weights; the best ``topk_group`` groups are
    kept, then ``topk`` experts are selected within them.

    Kernel limitations: the expert count comes from the input shape and must
    be a power of two and divisible by ``num_expert_group``, with
    experts-per-group <= 32. For unsupported cases use ``biased_grouped_topk``
    in ``sglang.srt.layers.moe.topk``.

    Args:
        num_fused_shared_experts: if > 0, the last experts are treated as
            shared experts; their weights are divided by
            ``routed_scaling_factor`` so the later routed+shared output
            scaling cancels out and shared experts stay unscaled.
        routed_scaling_factor: if > 0, scaling factor used as described above.
        apply_routed_scaling_factor_on_output: if True, the kernel scales the
            output by ``routed_scaling_factor``.
    """
    return torch.ops.sgl_kernel.moe_fused_gate.default(
        input_tensor,
        bias,
        num_expert_group,
        topk_group,
        topk,
        num_fused_shared_experts,
        routed_scaling_factor,
        apply_routed_scaling_factor_on_output,
    )
|
||||
|
||||
|
||||
def fp8_blockwise_scaled_grouped_mm(
    output,
    a_ptrs,
    b_ptrs,
    out_ptrs,
    a_scales_ptrs,
    b_scales_ptrs,
    a,
    b,
    scales_a,
    scales_b,
    stride_a,
    stride_b,
    stride_c,
    layout_sfa,
    layout_sfb,
    problem_sizes,
    expert_offsets,
    workspace,
):
    """FP8 block-scaled grouped GEMM over experts.

    Pure pass-through: every argument is forwarded to the CUDA kernel in
    declaration order; ``output`` is presumably written in place (nothing is
    returned).
    """
    forwarded = (
        output, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs,
        a, b, scales_a, scales_b, stride_a, stride_b, stride_c,
        layout_sfa, layout_sfb, problem_sizes, expert_offsets, workspace,
    )
    torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(*forwarded)
|
||||
|
||||
|
||||
def prepare_moe_input(
    topk_ids,
    expert_offsets,
    problem_sizes1,
    problem_sizes2,
    input_permutation,
    output_permutation,
    num_experts,
    n,
    k,
    blockscale_offsets: Optional[torch.Tensor] = None,
):
    """Compute offsets, per-expert problem sizes and permutations for the
    grouped MoE GEMMs; the output buffers are presumably filled in place by
    the kernel (nothing is returned).

    NOTE: the underlying op takes ``blockscale_offsets`` as its *third*
    argument even though it is the last (optional) Python parameter — do not
    "fix" the call ordering below.
    """
    torch.ops.sgl_kernel.prepare_moe_input.default(
        topk_ids,
        expert_offsets,
        blockscale_offsets,  # intentionally out of Python-parameter order
        problem_sizes1,
        problem_sizes2,
        input_permutation,
        output_permutation,
        num_experts,
        n,
        k,
    )
|
||||
|
||||
|
||||
def apply_shuffle_mul_sum(
    input,
    output,
    permutation,
    factors,
):
    """Fused shuffle/multiply/sum (name-based description — confirm semantics
    against the C++ kernel); ``output`` is presumably written in place.
    """
    forwarded = (input, output, permutation, factors)
    torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(*forwarded)
|
||||
|
||||
|
||||
def cutlass_fp4_group_mm(
    a_fp4,
    b_fp4,
    a_blockscale,
    b_blockscale,
    alphas,
    out_dtype,
    device,
    params: Dict[str, Any],
):
    """
    An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
    the gemms for each combination based on the specified problem sizes.

    This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
    - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
        input and expert weights.
    - a_/b_scales: The blockscales in FP8-E4M3 precision
    - ab_strides/c_strides: Strides for the a/b tensors between rows.
    - expert_offsets/sf_offsets: Indices that mark at which token index
        each expert begins its computation. The number of tokens
        computed with expert E is expert_offsets[E + 1] -
        expert_offsets[E] And the sf_size per expert is
        sf_offset[E+1] - sf_offset[E]
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
        MMs used in the fused MoE operation.
    """
    m_topk = a_fp4.shape[0]
    n = b_fp4.shape[1]
    # Output is allocated directly in the requested dtype on the target device.
    c = torch.empty((m_topk, n), device=device, dtype=out_dtype)
    torch.ops.sgl_kernel.cutlass_fp4_group_mm.default(
        c,
        a_fp4,
        b_fp4,
        a_blockscale,
        b_blockscale,
        alphas,
        params["ab_strides"],
        params["c_strides"],
        params["problem_sizes"],
        params["expert_offsets"],
        params["blockscale_offsets"],
    )
    # FIX: the old ``return c.to(dtype=out_dtype)`` was a no-op — ``c`` is
    # already ``out_dtype`` — so the redundant conversion is dropped.
    return c
|
||||
543
sgl-kernel/python/sgl_kernel/sampling.py
Normal file
543
sgl-kernel/python/sgl_kernel/sampling.py
Normal file
@@ -0,0 +1,543 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
||||
|
||||
|
||||
def _top_k_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Kernel driver for top-k renormalization (float32 in, float32 out)."""
    probs_f32 = probs.float()
    # Per-request thresholds, if provided, must be int32 for the kernel.
    top_k_arr = None if maybe_top_k_arr is None else maybe_top_k_arr.int()
    out = torch.empty_like(probs_f32)
    torch.ops.sgl_kernel.top_k_renorm_probs.default(
        probs_f32, out, top_k_arr, top_k_val
    )
    return out
|
||||
|
||||
|
||||
def top_k_renorm_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for renormalizing probabilities by top-k thresholding.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
        for re-normalizing probabilities, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k probabilities, set the rest to zero, and renormalize the probabilities.

    Returns
    -------
    renorm_probs: torch.Tensor
        Renormalized probabilities, shape ``(batch_size, num_classes)``.

    Note
    ----
    This combination of ``top_k_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
    ``top_k_sampling_from_probs``.
    """
    return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k))
|
||||
|
||||
|
||||
# Backward-compatible alias (singular name) kept for existing callers.
top_k_renorm_prob = top_k_renorm_probs
|
||||
|
||||
|
||||
def _top_p_renorm_probs_internal(
    probs: torch.Tensor,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
) -> torch.Tensor:
    """Kernel driver for top-p renormalization (float32 in, float32 out)."""
    probs_f32 = probs.float()
    # Per-request thresholds, if provided, must be float32 for the kernel.
    top_p_arr = None if maybe_top_p_arr is None else maybe_top_p_arr.float()
    out = torch.empty_like(probs_f32)
    torch.ops.sgl_kernel.top_p_renorm_probs.default(
        probs_f32, out, top_p_arr, top_p_val
    )
    return out
|
||||
|
||||
|
||||
def top_p_renorm_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for renormalizing probabilities by top-p thresholding.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    top_p: Union[torch.Tensor, float]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-p threshold
        for re-normalizing probabilities, should be in ``(0, 1)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We mask out the probabilities less than `threshold` where the cumulative sum
        of ``probs[probs >= threshold]`` is `top_p`, and renormalize the probabilities.

    Returns
    -------
    renorm_probs: torch.Tensor
        Renormalized probabilities, shape ``(batch_size, num_classes)``.

    Note
    ----
    This combination of ``top_p_renorm_probs`` and ``sampling_from_probs`` should be equivalent to
    ``top_p_sampling_from_probs``.

    """
    return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p))
|
||||
|
||||
|
||||
# Backward-compatible alias (singular name) kept for existing callers.
top_p_renorm_prob = top_p_renorm_probs
|
||||
|
||||
|
||||
def _top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor],
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
    generator: Optional[torch.Generator],
) -> torch.Tensor:
    """Kernel driver for top-p sampling; returns int32 sample ids."""
    # Entering the device as a context manager makes it the default device,
    # so the ``torch.empty`` below lands on the same device as ``probs``.
    with probs.device as device:
        probs = probs.float()
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
            probs,
            samples,
            indices,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            generator,
        )
        return samples
|
||||
|
||||
|
||||
def top_p_sampling_from_probs(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    r"""Fused GPU top-p (nucleus) sampling via single-kernel rejection sampling.
    Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py;
    see https://flashinfer.ai/2025/03/10/sampling.html for the algorithm.

    Parameters
    ----------
    probs: torch.Tensor
        ``(batch_size, num_classes)`` probabilities; with ``indices`` given,
        ``(unique_batch_size, num_classes)`` unique distributions instead.
    top_p: Union[torch.Tensor, float]
        Scalar threshold shared by all requests, or a ``(batch_size,)``
        tensor of per-request thresholds.
    indices: Optional[torch.Tensor]
        Optional ``(batch_size,)`` mapping: output ``i`` samples from row
        ``probs[indices[i]]``, letting several outputs reuse one distribution.
    deterministic: bool
        Use the deterministic kernel implementation (default ``True``).
    generator: Optional[torch.Generator]
        Random number generator for the operation.
    check_nan: bool
        Raise ``ValueError`` if ``probs`` contains NaN (default ``False``).

    Returns
    -------
    samples: torch.Tensor
        Sampled categories, shape ``(batch_size,)``, dtype int32.

    Note
    ----
    Expects float32 inputs; output is int32.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _top_p_sampling_from_probs_internal(
        probs, indices, *_to_tensor_scalar_tuple(top_p), deterministic, generator
    )
|
||||
|
||||
|
||||
def _top_k_top_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor],
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
    maybe_top_p_arr: Optional[torch.Tensor],
    top_p_val: float,
    deterministic: bool,
    generator: Optional[torch.Generator],
) -> torch.Tensor:
    """Kernel driver for joint top-k/top-p sampling; returns int32 sample ids."""
    # Entering the device as a context manager makes it the default device,
    # so the ``torch.empty`` below lands on the same device as ``probs``.
    with probs.device as device:
        probs = probs.float()
        # Per-request threshold arrays must match the kernel's dtypes.
        maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
        maybe_top_p_arr = (
            maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
            probs,
            samples,
            indices,
            maybe_top_k_arr,
            top_k_val,
            maybe_top_p_arr,
            top_p_val,
            deterministic,
            generator,
        )
        return samples
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    r"""Fused GPU top-k + top-p sampling via single-kernel rejection sampling.
    Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py;
    see https://flashinfer.ai/2025/03/10/sampling.html for the algorithm.

    Parameters
    ----------
    probs: torch.Tensor
        ``(batch_size, num_classes)`` probabilities; with ``indices`` given,
        ``(unique_batch_size, num_classes)`` unique distributions instead.
    top_k: Union[torch.Tensor, int]
        Scalar or ``(batch_size,)`` tensor of per-request top-k thresholds.
    top_p: Union[torch.Tensor, float]
        Scalar or ``(batch_size,)`` tensor of per-request top-p thresholds.
    indices: Optional[torch.Tensor]
        Optional ``(batch_size,)`` mapping: output ``i`` samples from row
        ``probs[indices[i]]``, letting several outputs reuse one distribution.
    filter_apply_order: str
        ``"top_k_first"`` (default): apply the top-k filter, then top-p
        sampling on the result. ``"joint"``: apply both filters together in
        each rejection round.
    deterministic: bool
        Use the deterministic kernel implementation (default ``True``).
    generator: Optional[torch.Generator]
        Random number generator for the operation.
    check_nan: bool
        Raise ``ValueError`` if ``probs`` contains NaN (default ``False``).

    Returns
    -------
    samples: torch.Tensor
        Sampled categories, shape ``(batch_size,)``, dtype int32.

    Note
    ----
    Expects float32 inputs; output is int32.
    """
    if filter_apply_order == "top_k_first":
        # NaN checking, if requested, is delegated to top_p_sampling_from_probs.
        renorm_probs = top_k_renorm_probs(probs, top_k)
        return top_p_sampling_from_probs(
            renorm_probs,
            top_p,
            indices,
            deterministic,
            check_nan=check_nan,
            generator=generator,
        )
    if filter_apply_order == "joint":
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
        return _top_k_top_p_sampling_from_probs_internal(
            probs,
            indices,
            *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p),
            deterministic,
            generator,
        )
    raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
|
||||
|
||||
|
||||
def _min_p_sampling_from_probs_internal(
    probs: torch.Tensor,
    indices: Optional[torch.Tensor],
    maybe_min_p_arr: Optional[torch.Tensor],
    min_p_val: float,
    deterministic: bool,
    generator: Optional[torch.Generator],
) -> torch.Tensor:
    """Kernel driver for min-p sampling; returns int32 sample ids."""
    # Entering the device as a context manager makes it the default device,
    # so the ``torch.empty`` below lands on the same device as ``probs``.
    with probs.device as device:
        probs = probs.float()
        maybe_min_p_arr = (
            maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
        )
        samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
        torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
            probs,
            samples,
            indices,
            maybe_min_p_arr,
            min_p_val,
            deterministic,
            generator,
        )
        return samples
|
||||
|
||||
|
||||
def min_p_sampling_from_probs(
    probs: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    r"""Fused GPU min-p sampling (https://arxiv.org/abs/2407.01082) via
    single-kernel rejection sampling. Adapted from
    https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py;
    see https://flashinfer.ai/2025/03/10/sampling.html for the algorithm.

    Parameters
    ----------
    probs: torch.Tensor
        ``(batch_size, num_classes)`` probabilities; with ``indices`` given,
        ``(unique_batch_size, num_classes)`` unique distributions instead.
    min_p: Union[torch.Tensor, float]
        Scalar threshold shared by all requests, or a ``(batch_size,)``
        tensor of per-request thresholds.
    indices: Optional[torch.Tensor]
        Optional ``(batch_size,)`` mapping: output ``i`` samples from row
        ``probs[indices[i]]``, letting several outputs reuse one distribution.
    deterministic: bool
        Use the deterministic kernel implementation (default ``True``).
    generator: Optional[torch.Generator]
        Random number generator for the operation.
    check_nan: bool
        Raise ``ValueError`` if ``probs`` contains NaN (default ``False``).

    Returns
    -------
    samples: torch.Tensor
        Sampled categories, shape ``(batch_size,)``, dtype int32.

    Note
    ----
    Expects float32 inputs; output is int32.
    """
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _min_p_sampling_from_probs_internal(
        probs, indices, *_to_tensor_scalar_tuple(min_p), deterministic, generator
    )
|
||||
|
||||
|
||||
def _top_k_mask_logits_internal(
    logits: torch.Tensor,
    maybe_top_k_arr: Optional[torch.Tensor],
    top_k_val: int,
) -> torch.Tensor:
    """Kernel driver for top-k logit masking (float32 in, float32 out)."""
    logits_f32 = logits.float()
    # Per-request thresholds, if provided, must be int32 for the kernel.
    top_k_arr = None if maybe_top_k_arr is None else maybe_top_k_arr.int()
    masked = torch.empty_like(logits_f32)
    torch.ops.sgl_kernel.top_k_mask_logits.default(
        logits_f32, masked, top_k_arr, top_k_val
    )
    return masked
|
||||
|
||||
|
||||
def top_k_mask_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    r"""Adapt from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
    Fused GPU kernel for masking logits by top-k thresholding.

    Parameters
    ----------
    logits: torch.Tensor
        Logits before softmax, shape ``(batch_size, num_classes)``.
    top_k: Union[torch.Tensor, int]
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the top-k threshold
        for masking logits, should be in ``(0, num_classes)``.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
        We keep the top-k logits, set the rest to negative infinity.

    Returns
    -------
    masked_logits: torch.Tensor
        Masked logits, shape ``(batch_size, num_classes)``.

    Examples
    --------

    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> top_k = 3
    >>> logits = torch.randn(batch_size, vocab_size).to(0)
    >>> logits
    tensor([[ 1.9269,  1.4873,  0.9007, -2.1055, -0.7581],
            [ 1.0783,  0.8008,  1.6806,  0.3559, -0.6866],
            [-0.4934,  0.2415, -0.2316,  0.0418, -0.2516],
            [ 0.8599, -0.3097, -0.3957,  0.8034, -0.6216]], device='cuda:0')
    >>> masked_logits = flashinfer.sampling.top_k_mask_logits(logits, top_k)
    >>> masked_logits
    tensor([[ 1.9269,  1.4873,  0.9007,    -inf,    -inf],
            [ 1.0783,  0.8008,  1.6806,    -inf,    -inf],
            [   -inf,  0.2415, -0.2316,  0.0418,    -inf],
            [ 0.8599, -0.3097,    -inf,  0.8034,    -inf]], device='cuda:0')

    Note
    ----
    The combination of ``top_k_mask_logits`` and ``softmax`` should be equivalent to ``top_k_renorm_probs``.

    See Also
    --------
    top_k_renorm_probs
    """
    return _top_k_mask_logits_internal(logits, *_to_tensor_scalar_tuple(top_k))
|
||||
|
||||
|
||||
def top_k_top_p_sampling_from_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    indices: Optional[torch.Tensor] = None,
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    generator: Optional[torch.Generator] = None,
    check_nan: bool = False,
) -> torch.Tensor:
    r"""Fused GPU top-k + top-p sampling directly from pre-softmax logits.

    Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py.
    Implements GPU-based rejection sampling without explicit sorting; the
    rejection rounds run inside a single CUDA kernel (see the blog post
    https://flashinfer.ai/2025/03/10/sampling.html for details).

    Parameters
    ----------
    logits: torch.Tensor
        Pre-softmax logits. Shape ``(batch_size, num_classes)`` when
        ``indices`` is None, otherwise ``(unique_batch_size, num_classes)``
        where ``unique_batch_size`` is the number of unique distributions.
    top_k: Union[torch.Tensor, int]
        Scalar or ``(batch_size,)`` tensor with the per-request top-k
        threshold.
    top_p: Union[torch.Tensor, float]
        Scalar or ``(batch_size,)`` tensor with the per-request top-p
        threshold.
    indices: Optional[torch.Tensor]
        Optional ``(batch_size,)`` map from output slot to a row of the
        probability matrix, allowing one distribution to serve several
        outputs. When absent, output ``i`` samples from row ``i``.
    filter_apply_order: str
        ``"top_k_first"`` applies the top-k mask, renormalizes, then runs
        top-p sampling; ``"joint"`` applies both filters in each rejection
        round. Default ``"top_k_first"``.
    deterministic: bool
        Use the deterministic kernel implementation. Default ``True``.
    generator: Optional[torch.Generator]
        Random number generator for the operation.
    check_nan: bool
        Whether to raise if the softmax-ed probabilities contain NaN.
        Default ``False``.

    Returns
    -------
    samples: torch.Tensor
        Sampled category indices, shape ``(batch_size,)``.

    Note
    ----
    This function expects float32 inputs, and the output is int32.
    """
    # Validate the mode up front; neither branch below has run yet, so this
    # matches the original behavior of raising before any computation.
    if filter_apply_order not in ("top_k_first", "joint"):
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")

    if filter_apply_order == "top_k_first":
        # Mask everything outside the top-k to -inf, renormalize, then
        # delegate the top-p stage to the probability-based sampler.
        renormed_probs = torch.softmax(top_k_mask_logits(logits, top_k), dim=-1)
        return top_p_sampling_from_probs(
            renormed_probs,
            top_p,
            indices,
            deterministic,
            check_nan=check_nan,
            generator=generator,
        )

    # "joint": both filters are applied simultaneously in each rejection round.
    probs = torch.softmax(logits, dim=-1)
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("Input probs contains NaN.")
    return _top_k_top_p_sampling_from_probs_internal(
        probs,
        indices,
        *_to_tensor_scalar_tuple(top_k),
        *_to_tensor_scalar_tuple(top_p),
        deterministic,
        generator,
    )
|
||||
352
sgl-kernel/python/sgl_kernel/scalar_type.py
Normal file
352
sgl-kernel/python/sgl_kernel/scalar_type.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
# Registry mapping a ScalarType's packed integer `id` back to the instance.
# Populated as a side effect of evaluating `ScalarType.id` and read by
# `ScalarType.from_id`.
_SCALAR_TYPES_ID_MAP = {}
|
||||
|
||||
|
||||
class NanRepr(Enum):
    """How a scalar type encodes NaN values.

    Mirrors the enum in `core/scalar_type.hpp`; the numeric values must stay
    in sync with the C++ side.
    """

    # NaNs are not representable at all.
    NONE = 0
    # Standard IEEE 754 encoding: exponent all 1s, mantissa not all 0s.
    IEEE_754 = 1
    # Extended-range encoding: exponent all 1s, mantissa all 1s.
    EXTD_RANGE_MAX_MIN = 2
|
||||
|
||||
|
||||
# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number bits representing an integer excluding the sign bit if
    this an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: if infs are supported, used `has_infs()` instead.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represent in this scalar type, returns NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        """Return the bit pattern (as an int) of an IEEE-754 double whose
        value equals this type's largest finite value."""
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            # In this encoding the all-ones mantissa is reserved for NaN.
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
            # The all-ones exponent is available for finite values because
            # IEEE inf/NaN do not occupy it in these encodings.
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        """Largest finite value of this floating point type, as a Python float."""
        double_raw = self._floating_point_max_int()
        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        """Maximum storable value, before the bias is subtracted."""
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (
                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
            ), "Cannot represent max as an int"
            # For integer types `mantissa` is the number of value bits
            # (sign bit excluded), so this is the usual 2^bits - 1.
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        """Minimum storable value, before the bias is subtracted."""
        if self.is_floating_point():
            assert (
                self.is_signed()
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            # Most negative finite double = max finite double with sign bit set.
            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
        else:
            assert (
                not self.is_signed() or self.size_bits <= 64
            ), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            # Pack `member` into `val` at the current bit offset and advance.
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"

        # Register this instance so `from_id` can recover it later.
        _SCALAR_TYPES_ID_MAP[val] = self

        return val

    @property
    def size_bits(self) -> int:
        # Total storage width: value bits plus the sign bit if present.
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type can represent NaN values"
        return self.nan_repr != NanRepr.NONE

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
          `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
        flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means its zero
        """
        if self.is_floating_point():
            ret = (
                "float"
                + str(self.size_bits)
                + "_e"
                + str(self.exponent)
                + "m"
                + str(self.mantissa)
            )

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        """Create a unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(
        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
    ) -> "ScalarType":
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        assert nan_repr != NanRepr.IEEE_754, (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions"
        )
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def from_id(cls, scalar_type_id: int) -> "ScalarType":
        """Look up a previously-constructed ScalarType by its packed `id`."""
        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
            raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
        return _SCALAR_TYPES_ID_MAP[scalar_type_id]
|
||||
|
||||
|
||||
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
# for floating point types (leading f) the scheme is:
|
||||
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
# flags:
|
||||
# - no-flags: means it follows IEEE 754 conventions
|
||||
# - f: means finite values only (no infinities)
|
||||
# - n: means nans are supported (non-standard encoding)
|
||||
# for integer types the scheme is:
|
||||
# `[u]int<size_bits>[b<bias>]`
|
||||
# - if bias is not present it means its zero
|
||||
|
||||
|
||||
class scalar_types:
    """Namespace of pre-constructed ScalarType instances used by the kernels."""

    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types (unsigned ints stored with a bias of 2^(bits-1))
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
|
||||
293
sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
Normal file
293
sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
Normal file
@@ -0,0 +1,293 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def maybe_contiguous(x):
    """Return *x* as-is when it is None or already has unit stride on its
    last dimension; otherwise return a contiguous copy."""
    if x is None:
        return x
    if x.stride(-1) == 1:
        return x
    return x.contiguous()
|
||||
|
||||
|
||||
# Sparse attention utils
def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert per-head vertical/slash sparse indexes into block-sparse
    layout metadata for the sparse attention kernels.

    Allocates the four output tensors (zero-initialized, same dtype/device
    as `q_seqlens`) and fills them in place via the
    `sgl_kernel.convert_vertical_slash_indexes` CUDA op.

    Returns:
        block_count:  [BATCH, N_HEADS, NUM_ROWS]
        block_offset: [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
        column_count: [BATCH, N_HEADS, NUM_ROWS]
        column_index: [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
        where NUM_ROWS = ceil(context_size / block_size_M).
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    # Number of query-side row blocks: ceil(context_size / block_size_M).
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    # The CUDA op writes its results into the four freshly-allocated tensors.
    torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS] : different head use different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Merge-head variant of `convert_vertical_slash_indexes`: each head may
    use a different number of vertical/slash indices, given by
    `vertical_indices_count` / `slash_indices_count`.

    Allocates the four output tensors and fills them in place via the
    `sgl_kernel.convert_vertical_slash_indexes_mergehead` CUDA op.

    NOTE(review): outputs are allocated with `torch.empty` (uninitialized),
    unlike the zero-initialized buffers in `convert_vertical_slash_indexes`
    -- presumably the kernel overwrites every element here; confirm.

    Returns:
        block_count:  [BATCH, N_HEADS, NUM_ROWS]
        block_offset: [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
        column_count: [BATCH, N_HEADS, NUM_ROWS]
        column_index: [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
        where NUM_ROWS = ceil(context_size / block_size_M).
    """
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    # Number of query-side row blocks: ceil(context_size / block_size_M).
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    # The CUDA op writes its results into the four freshly-allocated tensors.
    torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        vertical_indices_count,
        slash_indices_count,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.
    Most Arguments are the same with the flash_attn_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        return_softmax_lse: bool (keyword-only). Whether to also return the softmax log-sum-exp.
        out: optional pre-allocated output tensor, passed through to the kernel.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        # Standard attention scaling: 1 / sqrt(head_dim).
        softmax_scale = q.shape[-1] ** (-0.5)

    # Ensure the innermost dimension has unit stride before calling the kernel.
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        # Attention probs only exist (and are only returned) when dropout is on.
        return_attn_probs and dropout_p > 0,
        None,  # optional trailing arg -- TODO confirm against fwd_sparse op schema
    )
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
|
||||
|
||||
def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.
    Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args:
    block_count and block_offset for slash sparsity patterns, and
    column_count and column_index for vertical sparsity patterns.
    For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
        return_softmax_lse: bool (keyword-only). Whether to also return the softmax log-sum-exp.
        out: optional pre-allocated output tensor, passed through to the kernel.
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        # Standard attention scaling: 1 / sqrt(head_dim).
        softmax_scale = q.shape[-1] ** (-0.5)

    # Ensure the innermost dimension has unit stride before calling the kernel.
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,  # optional arg unused here -- TODO confirm against varlen_fwd_sparse op schema
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,  # boolean flag passed as constant -- TODO confirm meaning in op schema
        causal,
        softcap,
        # Attention probs only exist (and are only returned) when dropout is on.
        return_attn_probs and dropout_p > 0,
        None,  # optional trailing arg -- TODO confirm against op schema
    )
    return (out, softmax_lse) if return_softmax_lse else out
|
||||
63
sgl-kernel/python/sgl_kernel/spatial.py
Normal file
63
sgl-kernel/python/sgl_kernel/spatial.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from torch.cuda.streams import ExternalStream
|
||||
|
||||
try:
    from . import spatial_ops  # triggers TORCH extension registration
except Exception as _e:
    # Deliberately broad catch: record the failure instead of raising so that
    # merely importing this module never crashes. The error is re-raised with
    # its original cause on first actual use (see the public functions below).
    _spatial_import_error = _e
else:
    _spatial_import_error = None

# Raised (chained to the recorded cause) by the public functions when the
# spatial_ops extension could not be loaded.
_IMPORT_ERROR = ImportError(
    "Failed to load sgl_kernel.spatial_ops extension. Ensure CUDA Driver >= 12.4"
)
|
||||
|
||||
|
||||
def create_greenctx_stream_by_value(
    SM_a: int, SM_b: int, device_id: int = None
) -> tuple[ExternalStream, ExternalStream]:
    """
    Create two streams for greenctx (green-context SM partitioning).

    Args:
        SM_a (int): SM budget for stream A (exact semantics defined by the
            underlying sgl_kernel op -- confirm against the kernel).
        SM_b (int): SM budget for stream B.
        device_id (int, optional): The device id; defaults to the current
            CUDA device.
    Returns:
        tuple[ExternalStream, ExternalStream]: The two streams (A, B).
    Raises:
        ImportError: If the spatial_ops extension failed to load.
    """
    if _spatial_import_error is not None:
        raise _IMPORT_ERROR from _spatial_import_error
    if device_id is None:
        device_id = torch.cuda.current_device()

    # The op returns the raw stream pointers for the two partitions.
    res = torch.ops.sgl_kernel.create_greenctx_stream_by_value(SM_a, SM_b, device_id)

    # Wrap the raw pointers so torch treats them as externally-managed streams.
    stream_a = ExternalStream(
        stream_ptr=res[0], device=torch.device(f"cuda:{device_id}")
    )
    stream_b = ExternalStream(
        stream_ptr=res[1], device=torch.device(f"cuda:{device_id}")
    )

    return stream_a, stream_b
|
||||
|
||||
|
||||
def get_sm_available(device_id: int = None) -> int:
    """
    Return the number of streaming multiprocessors (SMs) on a device.

    Args:
        device_id (int, optional): The device id; defaults to the current
            CUDA device.
    Returns:
        int: The SM count reported by the CUDA runtime.
    Raises:
        ImportError: If the spatial_ops extension failed to load.
    """
    if _spatial_import_error is not None:
        raise _IMPORT_ERROR from _spatial_import_error
    if device_id is None:
        device_id = torch.cuda.current_device()

    # multi_processor_count is the SM count in the device properties.
    return torch.cuda.get_device_properties(device_id).multi_processor_count
|
||||
107
sgl-kernel/python/sgl_kernel/speculative.py
Normal file
107
sgl-kernel/python/sgl_kernel/speculative.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream
|
||||
|
||||
|
||||
def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    uniform_samples_for_final_sampling: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    threshold_single: float = 1.0,
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
    """Wrapper for the `sgl_kernel.tree_speculative_sampling_target_only`
    CUDA op (tree-based speculative sampling, per the op name).

    Results are written in place into `predicts`, `accept_index`, and
    `accept_token_num` (marked "mutable" above); nothing is returned. The op
    is launched on the stream returned by `get_cuda_stream()`.
    """
    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        uniform_samples_for_final_sampling,
        target_probs,
        draft_probs,
        threshold_single,
        threshold_acc,
        deterministic,
        get_cuda_stream(),
    )
|
||||
|
||||
|
||||
def verify_tree_greedy(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    """Wrapper for the `sgl_kernel.verify_tree_greedy` CUDA op (greedy
    verification of a speculative token tree, per the op name).

    Results are written in place into `predicts`, `accept_index`, and
    `accept_token_num` (marked "mutable" above); nothing is returned. The op
    is launched on the stream returned by `get_cuda_stream()`.
    """
    torch.ops.sgl_kernel.verify_tree_greedy.default(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        target_predict,
        get_cuda_stream(),
    )
|
||||
|
||||
|
||||
def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int,
) -> None:
    """Build the speculative-decoding draft tree structures on the GPU.

    Thin wrapper over the ``sgl_kernel`` CUDA custom op. The op returns
    nothing; which of the tensor arguments are outputs versus inputs is
    determined by the kernel (NOTE(review): presumably ``tree_mask``,
    ``positions`` and the ``retrive_*`` tensors are filled in from
    ``parent_list``/``selected_index`` — confirm against the kernel
    source). ``topk``, ``depth``, ``draft_token_num`` and
    ``tree_mask_mode`` are scalar configuration values forwarded verbatim.
    Note this op takes no explicit stream argument, unlike the other
    wrappers in this module.
    """
    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
        tree_mask_mode,
    )
|
||||
|
||||
|
||||
def segment_packbits(
    x: torch.Tensor,
    input_indptr: torch.Tensor,
    output_indptr: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
) -> None:
    """Pack bits of ``x`` segment-wise into ``y`` in place via the CUDA op.

    Thin wrapper over the ``sgl_kernel`` custom op; the op returns nothing
    and writes its result into ``y``. ``input_indptr``/``output_indptr``
    delimit the per-segment ranges in ``x`` and ``y`` (CSR-style index
    pointers — presumed from the naming; confirm against the kernel).
    Launched on the current CUDA stream.
    """
    torch.ops.sgl_kernel.segment_packbits.default(
        x,
        input_indptr,
        output_indptr,
        y,
        batch_size,
        # Use the shared helper for the raw stream handle, consistent with
        # the other op wrappers in this module.
        get_cuda_stream(),
    )
|
||||
0
sgl-kernel/python/sgl_kernel/testing/__init__.py
Normal file
0
sgl-kernel/python/sgl_kernel/testing/__init__.py
Normal file
217
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
Normal file
217
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
|
||||
# vLLM torch native
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
    """Torch-native rotary positional embedding, used as a test reference.

    Precomputes a concatenated cos/sin cache for positions
    ``0..max_position_embeddings-1`` at construction time and applies it to
    query/key tensors in :meth:`forward_native`.
    """

    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        # Only the first rotary_dim lanes of each head are rotated; the rest
        # pass through unchanged.
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        # Output dtype of forward_native; the rotation itself runs in float32.
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        # Non-persistent: moves with the module but is excluded from
        # state_dict, since it is recomputable from the constructor args.
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the [rotary_dim // 2] inverse-frequency vector 1/base^(2i/d)."""
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns a [max_position_embeddings, rotary_dim] tensor holding
        cos(pos * inv_freq) in the first half of the last dim and
        sin(pos * inv_freq) in the second half.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        # Outer product: freqs[pos, i] = pos * inv_freq[i].
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward().

        ``query``/``key`` have a last dim that is a multiple of head_size
        (they are viewed as [num_tokens, -1, head_size] internally).
        Returns new rotated tensors in ``self.dtype``; the inputs are not
        mutated in place.
        """
        if offsets is not None:
            positions = positions + offsets

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        # Per-token rows of the precomputed [*, rotary_dim] cache.
        cos_sin = self.cos_sin_cache.index_select(0, positions)

        # Modification: float32 is required for the rotary embedding to work correctly
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        cos, sin = cos_sin.chunk(2, dim=-1)

        # Rotate only the first rotary_dim lanes of each query head.
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        # Same treatment for the key heads.
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)

        # Modification: convert to the correct dtype
        query = query.to(self.dtype)
        key = key.to(self.dtype)
        return query, key
|
||||
|
||||
|
||||
class FlashInferRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding variant that dispatches to the fused sgl-kernel op."""

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to ``query``/``key`` in place via the CUDA kernel.

        Unlike :meth:`forward_native`, the rotation is done in place and the
        (mutated) input tensors are returned. ``fused_set_kv_buffer_arg``
        optionally fuses the KV-cache write into the same kernel launch.

        NOTE(review): ``offsets`` is accepted for signature parity with
        ``forward_native`` but is never forwarded to the kernel — confirm
        this is intentional.
        """

        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )

        return query, key
|
||||
|
||||
|
||||
class MHATokenToKVPool:
    """Minimal token-to-KV-cache pool for the rotary-embedding tests.

    Holds a single layer of zero-initialized bfloat16 K/V buffers on the
    GPU and supports scatter-writes into arbitrary slot locations.
    """

    # Total number of token slots in the pool.
    KV_POOL_SIZE = 16384

    def __init__(
        self,
        head_num: int,
        head_dim: int,
    ):
        self.head_num = head_num
        self.head_dim = head_dim
        self.size = MHATokenToKVPool.KV_POOL_SIZE
        self.page_size = 1
        self.store_dtype = torch.bfloat16
        self.device = "cuda"
        self.layer_num = 1
        self.start_layer = 0
        self._create_buffers()

    def _create_buffers(self):
        # K and V pools are identically shaped: one
        # (slots + page_size, heads, head_dim) tensor per layer.
        def alloc():
            return [
                torch.zeros(
                    (self.size + self.page_size, self.head_num, self.head_dim),
                    dtype=self.store_dtype,
                    device=self.device,
                )
                for _ in range(self.layer_num)
            ]

        self.k_buffer = alloc()
        self.v_buffer = alloc()

    def set_kv_buffer(
        self,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Scatter ``cache_k``/``cache_v`` into the layer-0 buffers at ``loc``."""
        layer_id = 0  # single-layer pool: only layer 0 exists
        idx = layer_id - self.start_layer
        self.k_buffer[idx][loc] = cache_k
        self.v_buffer[idx][loc] = cache_v
|
||||
|
||||
|
||||
def create_inputs(
    head_size: int,
    batch_size: int,
    seq_len: int,
    device,
    dtype: torch.dtype,
    num_q_heads: int,
    num_kv_heads: int,
):
    """Generate random Q/K/V projections, position ids and cache slots.

    Returns a dict with:
        pos_ids:       [batch_size * seq_len], positions 0..seq_len-1
                       repeated per batch element
        query:         [batch_size * seq_len, num_q_heads * head_size]
        key, value:    [batch_size * seq_len, num_kv_heads * head_size]
        out_cache_loc: unique random slot indices into MHATokenToKVPool
    """
    num_tokens = batch_size * seq_len
    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)

    def _randn(width: int) -> torch.Tensor:
        return torch.randn(num_tokens, width, dtype=dtype, device=device)

    query = _randn(num_q_heads * head_size)
    key = _randn(num_kv_heads * head_size)
    value = _randn(num_kv_heads * head_size)

    # Distinct cache slots for every token, drawn without replacement.
    out_cache_loc = torch.randperm(
        MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
    )[:num_tokens].clone()

    return dict(
        pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
    )
|
||||
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
    """Return the top-``topk`` (values, indices) of ``values`` along ``dim``.

    For ``topk == 1`` a single ``max`` reduction is used instead of a full
    top-k, since it also yields both the value and the index.
    """
    if topk != 1:
        # General case.
        # TODO: implement faster cuda kernels for large vocab sizes
        return torch.topk(values, topk, dim=dim)
    # k == 1: max with keepdim matches topk's output shape.
    return torch.max(values, dim=dim, keepdim=True)
|
||||
50
sgl-kernel/python/sgl_kernel/utils.py
Normal file
50
sgl-kernel/python/sgl_kernel/utils.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import functools
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_cuda_stream() -> int:
    """Return the raw handle of the current CUDA stream as an integer."""
    current = torch.cuda.current_stream()
    return current.cuda_stream
|
||||
|
||||
|
||||
# Process-wide scratch-buffer cache keyed by (name, device); populated
# lazily by _get_cache_buf below.
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
|
||||
|
||||
|
||||
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
    """Return a cached uint8 scratch tensor for ``(name, device)``.

    Allocates ``bytes`` bytes on first use and reuses the same tensor on
    later calls with the same key. NOTE(review): a later call with a
    different ``bytes`` for the same key returns the originally sized
    buffer — callers are presumed to use a fixed size per name.
    (``bytes`` shadows the builtin, kept for signature compatibility.)
    """
    key = (name, device)
    cached = _cache_buf.get(key)
    if cached is not None:
        return cached
    fresh = torch.empty(bytes, dtype=torch.uint8, device=device)
    _cache_buf[key] = fresh
    return fresh
|
||||
|
||||
|
||||
def _to_tensor_scalar_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x, 0)
|
||||
else:
|
||||
return (None, x)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
    """Return True if the current CUDA device supports PDL.

    Programmatic dependent launch requires compute capability 9.0+
    (Hopper and newer); the minor version is irrelevant. Cached after the
    first call, so a later change of current device is not observed.
    """
    device = torch.cuda.current_device()
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9
|
||||
1
sgl-kernel/python/sgl_kernel/version.py
Normal file
1
sgl-kernel/python/sgl_kernel/version.py
Normal file
@@ -0,0 +1 @@
|
||||
# sgl-kernel package version string.
__version__ = "0.3.9.post2"
|
||||
Reference in New Issue
Block a user